Mercurial > pylearn
annotate mlp_factory_approach.py @ 286:2ee53bae9ee0
renamed _nnet_ops.py to _test_nnet_opt.py to be used with autotest
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 13:55:59 -0400 |
parents | ae0a8345869b |
children | eded3cb54930 |
rev | line source |
---|---|
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
1 import copy, sys, os |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
2 import numpy |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
3 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 import theano |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
5 from theano import tensor as T |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
6 |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
7 from pylearn import dataset, nnet_ops, stopper, LookupList, filetensor |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
8 |
211
bd728c83faff
in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
208
diff
changeset
|
9 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class AbstractFunction(Exception):
    """Raised by placeholder methods that a subclass is expected to override."""
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
11 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class AutoName(object):
    """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.

    The hook is installed through a nested ``__metaclass__`` class, i.e. the
    Python 2 way of declaring a metaclass.
    """
    class __metaclass__(type):
        """Metaclass that names class attributes after the variable holding them."""

        def __init__(cls, name, bases, dct):
            """Finish class creation, then set ``val.name = key`` where applicable.

            @param name: name of the class being created
            @param bases: tuple of base classes
            @param dct: class namespace; every value that has a ``name``
                attribute gets it overwritten with its key
            """
            # type.__init__ is looked up on the type itself, so the class
            # being initialised has to be passed explicitly as first argument.
            type.__init__(cls, name, bases, dct)
            for key, val in dct.items():
                assert type(key) is str
                if hasattr(val, 'name'):
                    val.name = key
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
24 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
25 class GraphLearner(object): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class Model(object):
    """A list of parameter values bound to the learner that produced them.

    The model compiles its training-step function at construction time and
    lazily compiles (and caches) the functions needed to apply it to data.
    """

    def __init__(self, algo, params):
        """Bind `params` to `algo` and compile the training-step function.

        @param algo: learner exposing `graph` and `_fn(inputs, outputs)`
        @param params: parameter values, passed positionally after the
            minibatch fields, i.e. in the order of `algo.graph.params`
        @type params: list
        """
        self.algo = algo
        self.params = params
        graph = self.algo.graph
        self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                                  [graph.nll] + graph.new_params)
        # compiled functions used by __call__, keyed by (inputs, outputs)
        self._fn_cache = {}

    def __copy__(self):
        """Return a new Model on the same algo, with each parameter copied.

        Parameters are copied with `copy.copy`; `algo` is shared.
        """
        return GraphLearner.Model(self.algo,
                                  [copy.copy(p) for p in self.params])

    def __eq__(self, other, tolerance=0.):
        """ Only compares weights of matrices and bias vector.

        @param other: object to compare with
        @param tolerance: largest absolute element-wise difference that
            still counts as equal
        @return: True iff `other` is a Model with the same number of
            parameters, all of matching shape and within `tolerance`
        """
        if not isinstance(other, GraphLearner.Model):
            return False
        # compare every parameter, however many the graph declares
        if len(self.params) != len(other.params):
            return False
        for mine, theirs in zip(self.params, other.params):
            if mine.shape != theirs.shape:
                return False
            if not numpy.all(numpy.abs(mine - theirs) <= tolerance):
                return False
        return True

    def _cache(self, key, valfn):
        """Return the cached value for `key`, calling `valfn()` on first use."""
        d = self._fn_cache
        if key not in d:
            d[key] = valfn()
        return d[key]

    def update_minibatch(self, minibatch):
        """Run one training step on `minibatch`.

        @param minibatch: mapping with 'input' and 'target' entries
        """
        #assert isinstance(minibatch, LookupList) # why false???
        # NOTE(review): the outputs of update_fn are discarded here, so the
        # compiled function presumably updates the parameters in place --
        # confirm against the graph's new_params definition.
        self.update_fn(minibatch['input'], minibatch['target'], *self.params)

    def update(self, dataset,
               default_minibatch_size=32):
        """Update this model from more training data.

        @param dataset: data supporting `len()` and
            `minibatches(fieldnames, minibatch_size=...)`
        @param default_minibatch_size: minibatch size, capped at
            `len(dataset)`
        """
        minibatch_size = min(default_minibatch_size, len(dataset))
        for mb in dataset.minibatches(['input', 'target'],
                                      minibatch_size=minibatch_size):
            self.update_minibatch(mb)

    def save(self, f):
        """Save this model by delegating to `self.algo.graph.save(f, self)`."""
        self.algo.graph.save(f, self)

    def __call__(self, testset, fieldnames=None):
        """Apply this model (as a function) to new data.

        @param testset: DataSet, whose fields feed Result terms in
            self.algo.graph
        @type testset: DataSet

        @param fieldnames: names of results in self.algo.graph to compute;
            defaults to ['output_class']
        @type fieldnames: list of strings

        @return: DataSet with fields from fieldnames, computed from testset by
        this model.
        @rtype: ApplyFunctionDataSet instance

        @raise TypeError: if a field name does not designate a theano.Result
            attribute of the graph
        """
        if fieldnames is None:
            # fresh list per call instead of a shared mutable default
            fieldnames = ['output_class']
        graph = self.algo.graph

        def getresult(name):
            r = getattr(graph, name)
            if not isinstance(r, theano.Result):
                raise TypeError('string does not name a theano.Result', (name, r))
            return r

        provided = [getresult(name) for name in testset.fieldNames()]
        wanted = [getresult(name) for name in fieldnames]
        inputs = provided + graph.params

        # compile once per (inputs, outputs) combination
        theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                                lambda: self.algo._fn(inputs, wanted))
        # self.params is read at call time, so later updates are picked up
        lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
        return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
100 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class Graph(object):
    """Base class describing the graph a learner compiles functions from.

    This base supplies the graph optimizer, the linker and the default
    early stopper; `train_iter`, `load` and `save` are placeholders that
    raise AbstractFunction and are meant to be overridden.
    """

    class Opt(object):
        """Graph optimizer: merge, gemm pattern, optional sqr pattern, merge."""
        merge = theano.gof.MergeOptimizer()
        gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
        # pattern substitution replacing mul(x, x) by sqr(x)
        sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
            (T.mul, 'x', 'x'),
            (T.sqr, 'x')))

        def __init__(self, do_sqr=True):
            """
            @param do_sqr: whether __call__ also applies `sqr_opt_0`
            """
            self.do_sqr = do_sqr

        def __call__(self, env):
            """Apply the optimizations to `env`, in a fixed order."""
            self.merge(env)
            self.gemm_opt_1(env)
            if self.do_sqr:
                self.sqr_opt_0(env)
            # merge again after the pattern substitutions
            self.merge(env)

    def linker(self):
        """Return the linker used to compile functions for this graph."""
        return theano.gof.PerformLinker()

    def early_stopper(self):
        """Return the default stopping criterion, `stopper.NStages(10, 1)`."""
        # the stopper must be returned, otherwise callers receive None
        return stopper.NStages(10, 1)

    def train_iter(self, trainset):
        """Placeholder for the training-data iterator; must be overridden.

        @raise AbstractFunction: always
        """
        raise AbstractFunction

    # optimizer handed to theano.function by the learner
    optimizer = Opt()

    def load(self, f):
        """Placeholder for loading a model from `f`; must be overridden.

        @raise AbstractFunction: always
        """
        raise AbstractFunction

    def save(self, f, model):
        """Placeholder for saving `model` to `f`; must be overridden.

        @raise AbstractFunction: always
        """
        raise AbstractFunction
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
134 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
135 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def __init__(self, graph):
    """Store the graph description this learner compiles functions from.

    @param graph: object kept as `self.graph`
    """
    self.graph = graph
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
138 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
def _fn(self, inputs, outputs):
    """Compile a theano function from `inputs` to `outputs`.

    The graph's `optimizer` is always used; its `linker()` is used when the
    graph defines one, otherwise the 'c&py' linker is requested.
    """
    # Caching here would hamper multi-threaded apps
    # prefer caching in Model.__call__
    graph = self.graph
    chosen_optimizer = graph.optimizer
    if hasattr(graph, 'linker'):
        chosen_linker = graph.linker()
    else:
        chosen_linker = 'c&py'
    return theano.function(inputs, outputs,
                           unpack_single=False,
                           optimizer=chosen_optimizer,
                           linker=chosen_linker)
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
147 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
148 def __call__(self, |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
149 trainset=None, |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
150 validset=None, |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
151 iparams=None): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
152 """Allocate and optionally train a model |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
153 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
154 @param trainset: Data for minimizing the cost function |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
155 @type trainset: None or Dataset |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
156 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
157 @param validset: Data for early stopping |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
158 @type validset: None or Dataset |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
159 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
160 @param input: name of field to use as input |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
161 @type input: string |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
162 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
163 @param target: name of field to use as target |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
164 @type target: string |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
165 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
166 @return: model |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
167 @rtype: GraphLearner.Model instance |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
168 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
169 """ |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
170 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
171 iparams = self.graph.iparams() if iparams is None else iparams |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
172 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
173 # if we load, type(trainset) == 'str' |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
174 if isinstance(trainset,str) or isinstance(trainset,file): |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
175 #loadmodel = GraphLearner.Model(self, iparams) |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
176 loadmodel = self.graph.load(self,trainset) |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
177 return loadmodel |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
178 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
179 curmodel = GraphLearner.Model(self, iparams) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
180 best = curmodel |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
181 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
182 if trainset is not None: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
183 #do some training by calling Model.update_minibatch() |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
184 stp = self.graph.early_stopper() |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
185 for mb in self.graph.train_iter(trainset): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
186 curmodel.update_minibatch(mb) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
187 if stp.set_score: |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
188 if validset: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
189 stp.score = curmodel(validset, ['validset_score']) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
190 if (stp.score < stp.best_score): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
191 best = copy.copy(curmodel) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
192 else: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
193 stp.score = 0.0 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
194 stp.next() |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
195 if validset: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
196 curmodel = best |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
197 return curmodel |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
198 |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
199 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
    """Build the graph description of a one-hidden-layer MLP classifier.

    A GraphLearner.Graph subclass is defined locally (closing over the
    arguments of this function) and one instance of it is returned.

    @param ninputs: number of input features; W1 is initialised with
        shape (ninputs, nhid)
    @type ninputs: int

    @param nhid: number of hidden units; b1 has length nhid and W2 is
        initialised with shape (nhid, nclass)
    @type nhid: int

    @param nclass: number of output classes; b2 has length nclass
    @type nclass: int

    @param lr_val: learning rate, the factor applied to each gradient in
        the in-place parameter update
    @type lr_val: float

    @param l2coef_val: L2 regularization coefficient.  Only 0.0 is accepted
        (see the assert in the class body).
    @type l2coef_val: float

    @return: graph
    @rtype: instance of the local class G (a GraphLearner.Graph)
    """

    def wrapper(i, node, thunk, debug=False):
        """Execute one thunk of the graph, optionally tracing it first.

        This helper is not referenced anywhere else in graphMLP; it is a
        debugging aid.  The trace used to sit behind a hard-coded `if 0:`;
        it is now enabled by passing debug=True.

        @param i: position of the node (only used in the trace output)
        @param node: graph node being computed (only used in the trace)
        @param thunk: zero-argument callable that computes the node
        @param debug: when true, print the node and the thunk's inputs and
            outputs before running it
        """
        if debug:
            print('%s %s' % (i, node))
            print(thunk.inputs)
            print(thunk.outputs)
            if node.op == nnet_ops.crossentropy_softmax_1hot_with_bias:
                print('here is the nll op')
        thunk()  # actually compute this piece of the graph

    class G(GraphLearner.Graph, AutoName):
        """Symbolic graph of the MLP plus the hooks GraphLearner calls.

        The class attributes below are the symbolic variables of the model:
        hid = tanh(b1 + input . W1), activations = b2 + hid . W2, and the
        cost is the cross-entropy of the softmax of the activations.
        """

        lr = T.constant(lr_val)
        # The regularization terms are built below but are NOT part of the
        # differentiated cost, so a non-zero coefficient would be silently
        # ignored; refuse it instead.
        assert l2coef_val == 0.0
        l2coef = T.constant(l2coef_val)
        input = T.matrix()  # n_examples x n_inputs
        target = T.ivector()  # len: n_examples
        W2, b2 = T.matrix(), T.vector()

        W1, b1 = T.matrix(), T.vector()
        hid = T.tanh(b1 + T.dot(input, W1))
        hid_regularization = l2coef * T.sum(W1 * W1)

        # Order matters: iparams(), load() and save() all use this order.
        params = [W1, b1, W2, b2]
        activations = b2 + T.dot(hid, W2)
        nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
        regularization = l2coef * T.sum(W2 * W2) + hid_regularization
        output_class = T.argmax(activations, 1)
        loss_01 = T.neq(output_class, target)
        # The gradient is taken of nll alone; `nll + regularization` was
        # disabled in the original code (hence the assert above).
        g_params = T.grad(nll, params)
        # One in-place gradient step per parameter: p <- p - lr * gp
        new_params = [T.sub_inplace(p, lr * gp) for p, gp in zip(params, g_params)]

        def __eq__(self, other):
            """Graph comparison is not implemented; say so and defer."""
            print('G.__eq__ from graphMLP(), not implemented yet')
            return NotImplemented

        def load(self, algo, f):
            """Load from file the 2 matrices and 2 bias vectors.

            @param algo: learner handed to GraphLearner.Model
            @param f: file name, or an already opened file object; a file
                object passed in by the caller is left open
            @type f: string or file

            @return: model built from the four tensors read from f
            @rtype: GraphLearner.Model instance
            """
            # Was misspelled `cloase_at_end`, so passing an open file object
            # raised NameError at the `if close_at_end` test.
            close_at_end = False
            if isinstance(f, str):
                # NOTE(review): if filetensor is a binary format, 'rb' would
                # be the portable mode -- confirm before changing.
                f = open(f, 'r')
                close_at_end = True
            try:
                # Four tensors, in the same order as `params` / iparams().
                params = [filetensor.read(f) for _ in xrange(4)]
            finally:
                # Only close what we opened, even if a read fails.
                if close_at_end:
                    f.close()
            return GraphLearner.Model(algo, params)

        def save(self, f, model):
            """Save params to file, so 2 matrices and 2 bias vectors.

            The tensors are written in the order of model.params, which is
            the same order as iparams().

            @param f: file name, or an already opened file object; a file
                object passed in by the caller is left open
            @type f: string or file

            @param model: object whose `params` sequence is written
            """
            # Was misspelled `cloase_at_end`, so passing an open file object
            # raised NameError at the `if close_at_end` test.
            close_at_end = False
            if isinstance(f, str):
                # NOTE(review): if filetensor is a binary format, 'wb' would
                # be the portable mode -- confirm before changing.
                f = open(f, 'w')
                close_at_end = True
            try:
                for p in model.params:
                    filetensor.write(f, p)
            finally:
                # Only close what we opened, even if a write fails.
                if close_at_end:
                    f.close()

        def iparams(self):
            """Return freshly drawn initial parameters [W1, b1, W2, b2].

            Every entry is uniform in [-0.0005, 0.0005).
            """
            def randsmall(*shape):
                # numpy.random.rand is uniform on [0, 1): centre and shrink.
                return (numpy.random.rand(*shape) - 0.5) * 0.001
            return [randsmall(ninputs, nhid),
                    randsmall(nhid),
                    randsmall(nhid, nclass),
                    randsmall(nclass)]

        def train_iter(self, trainset):
            """Return the minibatch iterator used for training.

            Asks trainset for 300 minibatches of the 'input' and 'target'
            fields, each of at most 32 examples (fewer if the dataset itself
            is smaller).
            """
            return trainset.minibatches(['input', 'target'],
                                        minibatch_size=min(len(trainset), 32),
                                        n_batches=300)

        def early_stopper(self):
            """Return the stopper driving the training loop."""
            return stopper.NStages(300, 1)

    return G()
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
282 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
283 |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
284 import unittest |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
285 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
286 class TestMLP(unittest.TestCase): |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
287 def blah(self, g): |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
288 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
289 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
290 [1, 0, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
291 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
292 {'input':slice(2),'target':2}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
293 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
294 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
295 [1, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
296 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
297 {'input':slice(2),'target':2}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
298 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
299 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
300 [1, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
301 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
302 {'input':slice(2)}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
303 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
304 learn_algo = GraphLearner(g) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
305 |
232 | 306 model1 = learn_algo(training_set1) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
307 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
308 model2 = learn_algo(training_set2) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
309 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
310 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data), |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
311 model2(test_data))] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
312 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
313 n_match = sum(omatch) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
314 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
315 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] == |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
316 training_set2.fields()['target'])), omatch) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
317 |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
318 model1.save('/tmp/model1') |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
319 |
265
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
320 #denoising_aa = GraphLearner(denoising_g) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
321 #model1 = denoising_aa(trainset) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
322 #hidset = model(trainset, fieldnames=['hidden']) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
323 #model2 = denoising_aa(hidset) |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
324 |
265
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
325 #f = open('blah', 'w') |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
326 #for m in model: |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
327 # m.save(f) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
328 #filetensor.write(f, initial_classification_weights) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
329 #f.flush() |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
330 |
265
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
331 #deep_sigmoid_net = GraphLearner(deepnetwork_g) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
332 #deep_model = deep_sigmoid_net.load('blah') |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
333 #deep_model.update(trainset) #do some fine tuning |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
334 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
335 model1_dup = learn_algo('/tmp/model1') |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
336 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
337 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
338 def equiv(self, g0, g1): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
339 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
340 [0, 1, 1], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
341 [1, 0, 1], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
342 [1, 1, 1]]), |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
343 {'input':slice(2),'target':2}) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
344 learn_algo_0 = GraphLearner(g0) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
345 learn_algo_1 = GraphLearner(g1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
346 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
347 model_0 = learn_algo_0(training_set1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
348 model_1 = learn_algo_1(training_set1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
349 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
350 print '----' |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
351 for p in zip(model_0.params, model_1.params): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
352 abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p[0], p[1]) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
353 max_abs_rel_err = numpy.max(abs_rel_err) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
354 if max_abs_rel_err > 1.0e-7: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
355 print 'p0', p[0] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
356 print 'p1', p[1] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
357 #self.failUnless(max_abs_rel_err < 1.0e-7, max_abs_rel_err) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
358 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
359 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
360 def test0(self): self.blah(graphMLP(2, 10, 2, .1)) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
361 def test1(self): self.blah(graphMLP(2, 3, 2, .1)) |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
362 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
363 if __name__ == '__main__': |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
364 unittest.main() |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
365 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
366 |