annotate mlp_factory_approach.py @ 286:2ee53bae9ee0

renamed _nnet_ops.py to _test_nnet_opt.py to be used with autotest
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Fri, 06 Jun 2008 13:55:59 -0400
parents ae0a8345869b
children eded3cb54930
rev   line source
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
1 import copy, sys, os
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
2 import numpy
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
3
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 import theano
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
5 from theano import tensor as T
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
7 from pylearn import dataset, nnet_ops, stopper, LookupList, filetensor
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
8
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
9
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
10 class AbstractFunction (Exception): pass
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
11
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
12 class AutoName(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
13 """
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
14 By inheriting from this class, class variables which have a name attribute
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
15 will have that name attribute set to the class variable name.
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
16 """
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
17 class __metaclass__(type):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
18 def __init__(cls, name, bases, dct):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
19 type.__init__(name, bases, dct)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
20 for key, val in dct.items():
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
21 assert type(key) is str
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
22 if hasattr(val, 'name'):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
23 val.name = key
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
24
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
25 class GraphLearner(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
26 class Model(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
27 def __init__(self, algo, params):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
28 self.algo = algo
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
29 self.params = params
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
30 graph = self.algo.graph
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
31 self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
32 [graph.nll] + graph.new_params)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
33 self._fn_cache = {}
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
34
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
35 def __copy__(self):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
36 raise Exception('why not called?')
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
37 return GraphLearner.Model(self.algo, [copy.copy(p) for p in params])
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
38
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
39 def __eq__(self,other,tolerance=0.) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
40 """ Only compares weights of matrices and bias vector. """
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
41 if not isinstance(other,GraphLearner.Model) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
42 return False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
43 for p in range(4) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
44 if self.params[p].shape != other.params[p].shape :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
45 return False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
46 if not numpy.all( numpy.abs(self.params[p] - other.params[p]) <= tolerance ) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
47 return False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
48 return True
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
49
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
50 def _cache(self, key, valfn):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
51 d = self._fn_cache
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
52 if key not in d:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
53 d[key] = valfn()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
54 return d[key]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
55
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
56 def update_minibatch(self, minibatch):
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
57 #assert isinstance(minibatch, LookupList) # why false???
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
58 self.update_fn(minibatch['input'], minibatch['target'], *self.params)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
59
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
60 def update(self, dataset,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
61 default_minibatch_size=32):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
62 """Update this model from more training data."""
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
63 params = self.params
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
64 minibatch_size = min(default_minibatch_size, len(dataset))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
65 for mb in dataset.minibatches(['input', 'target'], minibatch_size=minibatch_size):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
66 self.update_minibatch(mb)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
67
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
68 def save(self, f):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
69 self.algo.graph.save(f, self)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
70
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
71 def __call__(self, testset, fieldnames=['output_class']):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
72 """Apply this model (as a function) to new data.
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
73
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
74 @param testset: DataSet, whose fields feed Result terms in self.algo.g
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
75 @type testset: DataSet
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
76
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
77 @param fieldnames: names of results in self.algo.g to compute.
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
78 @type fieldnames: list of strings
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
79
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
80 @return: DataSet with fields from fieldnames, computed from testset by
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
81 this model.
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
82 @rtype: ApplyFunctionDataSet instance
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
83
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
84 """
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
85 graph = self.algo.graph
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
86 def getresult(name):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
87 r = getattr(graph, name)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
88 if not isinstance(r, theano.Result):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
89 raise TypeError('string does not name a theano.Result', (name, r))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
90 return r
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
91
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
92 provided = [getresult(name) for name in testset.fieldNames()]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
93 wanted = [getresult(name) for name in fieldnames]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
94 inputs = provided + graph.params
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
95
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
96 theano_fn = self._cache((tuple(inputs), tuple(wanted)),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
97 lambda: self.algo._fn(inputs, wanted))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
98 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
99 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
100
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
101 class Graph(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
102 class Opt(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
103 merge = theano.gof.MergeOptimizer()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
104 gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
105 sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
106 (T.mul,'x', 'x'),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
107 (T.sqr, 'x')))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
108
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
109 def __init__(self, do_sqr=True):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
110 self.do_sqr = do_sqr
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
111
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
112 def __call__(self, env):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
113 self.merge(env)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
114 self.gemm_opt_1(env)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
115 if self.do_sqr:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
116 self.sqr_opt_0(env)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
117 self.merge(env)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
118
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
119 def linker(self):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
120 return theano.gof.PerformLinker()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
121
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
122 def early_stopper(self):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
123 stopper.NStages(10,1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
124
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
125 def train_iter(self, trainset):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
126 raise AbstractFunction
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
127 optimizer = Opt()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
128
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
129 def load(self,f) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
130 raise AbstractFunction
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
131
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
132 def save(self,f,model) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
133 raise AbstractFunction
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
134
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
135
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
136 def __init__(self, graph):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
137 self.graph = graph
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
138
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
139 def _fn(self, inputs, outputs):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
140 # Caching here would hamper multi-threaded apps
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
141 # prefer caching in Model.__call__
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
142 return theano.function(inputs, outputs,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
143 unpack_single=False,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
144 optimizer=self.graph.optimizer,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
145 linker=self.graph.linker() if hasattr(self.graph, 'linker')
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
146 else 'c&py')
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
147
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
148 def __call__(self,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
149 trainset=None,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
150 validset=None,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
151 iparams=None):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
152 """Allocate and optionally train a model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
153
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
154 @param trainset: Data for minimizing the cost function
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
155 @type trainset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
156
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
157 @param validset: Data for early stopping
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
158 @type validset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
159
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
160 @param input: name of field to use as input
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
161 @type input: string
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
162
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
163 @param target: name of field to use as target
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
164 @type target: string
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
165
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
166 @return: model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
167 @rtype: GraphLearner.Model instance
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
168
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
169 """
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
170
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
171 iparams = self.graph.iparams() if iparams is None else iparams
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
172
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
173 # if we load, type(trainset) == 'str'
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
174 if isinstance(trainset,str) or isinstance(trainset,file):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
175 #loadmodel = GraphLearner.Model(self, iparams)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
176 loadmodel = self.graph.load(self,trainset)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
177 return loadmodel
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
178
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
179 curmodel = GraphLearner.Model(self, iparams)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
180 best = curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
181
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
182 if trainset is not None:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
183 #do some training by calling Model.update_minibatch()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
184 stp = self.graph.early_stopper()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
185 for mb in self.graph.train_iter(trainset):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
186 curmodel.update_minibatch(mb)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
187 if stp.set_score:
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
188 if validset:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
189 stp.score = curmodel(validset, ['validset_score'])
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
190 if (stp.score < stp.best_score):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
191 best = copy.copy(curmodel)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
192 else:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
193 stp.score = 0.0
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
194 stp.next()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
195 if validset:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
196 curmodel = best
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
197 return curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
198
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
199
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
200 def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
201
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
202
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
203 def wrapper(i, node, thunk):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
204 if 0:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
205 print i, node
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
206 print thunk.inputs
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
207 print thunk.outputs
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
208 if node.op == nnet_ops.crossentropy_softmax_1hot_with_bias:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
209 print 'here is the nll op'
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
210 thunk() #actually compute this piece of the graph
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
211
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
212 class G(GraphLearner.Graph, AutoName):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
213
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
214 lr = T.constant(lr_val)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
215 assert l2coef_val == 0.0
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
216 l2coef = T.constant(l2coef_val)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
217 input = T.matrix() # n_examples x n_inputs
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
218 target = T.ivector() # len: n_examples
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
219 W2, b2 = T.matrix(), T.vector()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
220
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
221 W1, b1 = T.matrix(), T.vector()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
222 hid = T.tanh(b1 + T.dot(input, W1))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
223 hid_regularization = l2coef * T.sum(W1*W1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
224
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
225 params = [W1, b1, W2, b2]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
226 activations = b2 + T.dot(hid, W2)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
227 nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
228 regularization = l2coef * T.sum(W2*W2) + hid_regularization
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
229 output_class = T.argmax(activations,1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
230 loss_01 = T.neq(output_class, target)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
231 #g_params = T.grad(nll + regularization, params)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
232 g_params = T.grad(nll, params)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
233 new_params = [T.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
234
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
235
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
236 def __eq__(self,other) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
237 print 'G.__eq__ from graphMLP(), not implemented yet'
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
238 return NotImplemented
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
239
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
240
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
241 def load(self, algo, f):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
242 """ Load from file the 2 matrices and bias vectors """
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
243 cloase_at_end = False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
244 if isinstance(f,str) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
245 f = open(f,'r')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
246 close_at_end = True
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
247 params = []
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
248 for i in xrange(4):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
249 params.append(filetensor.read(f))
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
250 if close_at_end :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
251 f.close()
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
252 return GraphLearner.Model(algo, params)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
253
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
254 def save(self, f, model):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
255 """ Save params to file, so 2 matrices and 2 bias vectors. Same order as iparams. """
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
256 cloase_at_end = False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
257 if isinstance(f,str) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
258 f = open(f,'w')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
259 close_at_end = True
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
260 for p in model.params:
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
261 filetensor.write(f,p)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
262 if close_at_end :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
263 f.close()
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
265
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
266 def iparams(self):
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
267 """ init params. """
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
268 def randsmall(*shape):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
269 return (numpy.random.rand(*shape) -0.5) * 0.001
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
270 return [randsmall(ninputs, nhid)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
271 , randsmall(nhid)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
272 , randsmall(nhid, nclass)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
273 , randsmall(nclass)]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
274
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
275 def train_iter(self, trainset):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
276 return trainset.minibatches(['input', 'target'],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
277 minibatch_size=min(len(trainset), 32), n_batches=300)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
278 def early_stopper(self):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
279 return stopper.NStages(300,1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
280
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
281 return G()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
282
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
283
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
284 import unittest
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
285
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
286 class TestMLP(unittest.TestCase):
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
287 def blah(self, g):
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
288 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
289 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
290 [1, 0, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
291 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
292 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
293 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
294 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
295 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
296 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
297 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
298 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
299 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
300 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
301 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
302 {'input':slice(2)})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
303
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
304 learn_algo = GraphLearner(g)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
305
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
306 model1 = learn_algo(training_set1)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
307
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
308 model2 = learn_algo(training_set2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
309
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
310 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
311 model2(test_data))]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
312
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
313 n_match = sum(omatch)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
314
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
315 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] ==
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
316 training_set2.fields()['target'])), omatch)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
317
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
318 model1.save('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
319
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
320 #denoising_aa = GraphLearner(denoising_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
321 #model1 = denoising_aa(trainset)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
322 #hidset = model(trainset, fieldnames=['hidden'])
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
323 #model2 = denoising_aa(hidset)
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
324
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
325 #f = open('blah', 'w')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
326 #for m in model:
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
327 # m.save(f)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
328 #filetensor.write(f, initial_classification_weights)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
329 #f.flush()
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
330
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
331 #deep_sigmoid_net = GraphLearner(deepnetwork_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
332 #deep_model = deep_sigmoid_net.load('blah')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
333 #deep_model.update(trainset) #do some fine tuning
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
334
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
335 model1_dup = learn_algo('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
336
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
337
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
338 def equiv(self, g0, g1):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
339 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
340 [0, 1, 1],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
341 [1, 0, 1],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
342 [1, 1, 1]]),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
343 {'input':slice(2),'target':2})
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
344 learn_algo_0 = GraphLearner(g0)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
345 learn_algo_1 = GraphLearner(g1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
346
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
347 model_0 = learn_algo_0(training_set1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
348 model_1 = learn_algo_1(training_set1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
349
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
350 print '----'
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
351 for p in zip(model_0.params, model_1.params):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
352 abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p[0], p[1])
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
353 max_abs_rel_err = numpy.max(abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
354 if max_abs_rel_err > 1.0e-7:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
355 print 'p0', p[0]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
356 print 'p1', p[1]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
357 #self.failUnless(max_abs_rel_err < 1.0e-7, max_abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
358
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
359
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
360 def test0(self): self.blah(graphMLP(2, 10, 2, .1))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
361 def test1(self): self.blah(graphMLP(2, 3, 2, .1))
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
362
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
363 if __name__ == '__main__':
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
364 unittest.main()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
365
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
366