Mercurial > pylearn
annotate mlp_factory_approach.py @ 436:d7ed780364b3
image_tools
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Wed, 06 Aug 2008 19:39:14 -0400 |
parents | 93280a0c151a |
children |
rev | line source |
---|---|
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
1 import copy, sys, os |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
2 import numpy |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
3 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 import theano |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
5 from theano import tensor as T |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
6 |
299
eded3cb54930
small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
265
diff
changeset
|
7 import dataset, nnet_ops, stopper, filetensor |
304
6ead65d30f1e
while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
299
diff
changeset
|
8 from pylearn.lookup_list import LookupList |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
9 |
211
bd728c83faff
in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
208
diff
changeset
|
10 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class AbstractFunction(Exception):
    """Raised by placeholder methods that a concrete subclass must override."""
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
12 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class _AutoNameMeta(type):
    """Metaclass behind AutoName: names class attributes after their keys.

    For every entry of a new class's namespace whose value has a ``name``
    attribute, that attribute is overwritten with the entry's key.
    """

    def __init__(cls, name, bases, dct):
        # The original called ``type.__init__(name, bases, dct)``, omitting
        # ``cls``; that raises TypeError whenever a class is created.
        super(_AutoNameMeta, cls).__init__(name, bases, dct)
        for key, val in dct.items():
            assert type(key) is str
            if hasattr(val, 'name'):
                val.name = key


# Build the base class by calling the metaclass directly. This way the naming
# hook is active both under the Python 2 ``__metaclass__`` convention and under
# Python 3, which ignores a ``__metaclass__`` class attribute. Subclasses inherit
# the metaclass from this base in either version.
AutoName = _AutoNameMeta('AutoName', (object,), {
    '__doc__': """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.
    """,
    # Kept so that ``AutoName.__metaclass__`` still resolves as before.
    '__metaclass__': _AutoNameMeta,
})
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
25 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
26 class GraphLearner(object): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class Model(object):
    """A learning algorithm paired with concrete parameter values.

    ``algo`` is the learner that created this model (its ``graph`` holds the
    symbolic description). ``params`` holds the numeric values, passed
    positionally wherever ``algo.graph.params`` are the symbolic inputs.
    """

    def __init__(self, algo, params):
        self.algo = algo
        self.params = params
        graph = self.algo.graph
        # Compiled once per model: (input, target, *params) -> [nll] + new_params.
        self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                [graph.nll] + graph.new_params)
        # Compiled functions used by __call__, keyed on (inputs, outputs).
        self._fn_cache = {}

    def __copy__(self):
        """Return a new Model sharing ``algo`` with shallow-copied params."""
        # The original raised a debugging Exception('why not called?') and its
        # unreachable return referred to an undefined global ``params``.
        return GraphLearner.Model(self.algo, [copy.copy(p) for p in self.params])

    def __eq__(self, other, tolerance=0.):
        """Compare parameter values only (weight matrices and bias vectors).

        Two models are equal when they have the same number of parameters,
        and every pair has the same shape and differs elementwise by at most
        ``tolerance``. The learning algorithm itself is not compared.
        """
        if not isinstance(other, GraphLearner.Model):
            return False
        # The original compared only params[0..3]. It raised IndexError for
        # models with fewer parameters and ignored any beyond the fourth.
        if len(self.params) != len(other.params):
            return False
        for mine, theirs in zip(self.params, other.params):
            if mine.shape != theirs.shape:
                return False
            if not numpy.all(numpy.abs(mine - theirs) <= tolerance):
                return False
        return True

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``. Without this,
        # ``a != b`` would fall back to identity comparison.
        return not self.__eq__(other)

    def _cache(self, key, valfn):
        """Return the cached value for ``key``, computing ``valfn()`` on first use."""
        d = self._fn_cache
        if key not in d:
            d[key] = valfn()
        return d[key]

    def update_minibatch(self, minibatch):
        """Run one update step on a single minibatch.

        @param minibatch: provides the 'input' and 'target' fields.
        @type minibatch: LookupList
        """
        # Put the offending type in the assertion message. The original used a
        # Python-2-only ``print`` statement to stdout instead.
        assert isinstance(minibatch, LookupList), type(minibatch)
        # NOTE(review): the outputs of update_fn ([nll] + new_params) are
        # discarded here. Parameter updates presumably happen in place inside
        # the compiled function -- confirm.
        self.update_fn(minibatch['input'], minibatch['target'], *self.params)

    def update(self, dataset,
            default_minibatch_size=32):
        """
        Update this model from more training data. Uses all the data once, cut
        into minibatches. No early stopper here.

        @param dataset: DataSet with 'input' and 'target' fields.
        @param default_minibatch_size: minibatch size, capped at len(dataset).
        """
        n_examples = len(dataset)
        if n_examples == 0:
            # Nothing to learn from. This also avoids asking for zero-sized minibatches.
            return
        minibatch_size = min(default_minibatch_size, n_examples)
        for mb in dataset.minibatches(['input', 'target'], minibatch_size=minibatch_size):
            self.update_minibatch(mb)

    def save(self, f):
        """Save this model to ``f`` by delegating to ``self.algo.graph.save``."""
        self.algo.graph.save(f, self)

    def __call__(self, testset, fieldnames=None):
        """Apply this model (as a function) to new data.

        @param testset: DataSet whose field names name theano.Result
            attributes of self.algo.graph (the values fed to those Results)
        @type testset: DataSet

        @param fieldnames: names of Results in self.algo.graph to compute;
            defaults to ['output_class']
        @type fieldnames: list of strings

        @return: DataSet with fields from fieldnames, computed from testset by
            this model.
        @rtype: ApplyFunctionDataSet instance

        @raise TypeError: if a field name does not name a theano.Result.
        """
        if fieldnames is None:
            # Use a fresh list per call. The original shared one mutable default
            # list, which is handed to ApplyFunctionDataSet below.
            fieldnames = ['output_class']
        graph = self.algo.graph

        def getresult(name):
            r = getattr(graph, name)
            if not isinstance(r, theano.Result):
                raise TypeError('string does not name a theano.Result', (name, r))
            return r

        provided = [getresult(name) for name in testset.fieldNames()]
        wanted = [getresult(name) for name in fieldnames]
        inputs = provided + graph.params

        theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                lambda: self.algo._fn(inputs, wanted))
        # Current parameter values are appended after the data fields on every call.
        lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
        return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
106 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
class Graph(object):
    """Symbolic description of a learner, plus its compilation settings.

    Model reads ``input``, ``target``, ``params``, ``new_params`` and ``nll``
    from a Graph. ``train_iter``, ``load`` and ``save`` raise AbstractFunction
    here and are meant to be overridden.
    """

    class Opt(object):
        """Optimizer: merge, GEMM pattern rewrite, and optionally x*x -> sqr(x)."""
        merge = theano.gof.MergeOptimizer()
        gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
        sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
                (T.mul,'x', 'x'),
                (T.sqr, 'x')))

        def __init__(self, do_sqr=True):
            self.do_sqr = do_sqr

        def __call__(self, env):
            self.merge(env)
            self.gemm_opt_1(env)
            if self.do_sqr:
                self.sqr_opt_0(env)
            # Merge again, presumably to fold nodes made identical by the rewrites.
            self.merge(env)

    def linker(self):
        return theano.gof.PerformLinker()

    def early_stopper(self):
        """Return a fresh default early stopper, stopper.NStages(300, 1)."""
        # The original built the stopper but never returned it, so callers got None.
        return stopper.NStages(300,1)

    def train_iter(self, trainset):
        raise AbstractFunction
    optimizer = Opt()

    def load(self,f) :
        raise AbstractFunction

    def save(self,f,model) :
        raise AbstractFunction
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
140 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
141 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def __init__(self, graph):
    """Keep ``graph``, the Graph this learner compiles functions from."""
    self.graph = graph
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
144 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
def _fn(self, inputs, outputs):
    """Compile a theano function from ``inputs`` to ``outputs``.

    Uses this learner's graph optimizer. The linker comes from the graph's
    ``linker()`` method when it has one, otherwise 'c|py'. Nothing is cached
    here, because caching would hamper multi-threaded apps; Model.__call__
    does its own caching instead.
    """
    g = self.graph
    chosen_optimizer = g.optimizer
    if hasattr(g, 'linker'):
        chosen_linker = g.linker()
    else:
        chosen_linker = 'c|py'
    return theano.function(inputs, outputs,
            unpack_single=False,
            optimizer=chosen_optimizer,
            linker=chosen_linker)
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
153 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def __call__(self,
        trainset=None,
        validset=None,
        iparams=None,
        stp=None):
    """Allocate and optionally train a model

    @param trainset: Data for minimizing the cost function.  If it is a
        filename (str) or an open file object, no training happens: the
        model is loaded from it via self.graph.load(self, trainset).
        The minibatches drawn from it are chosen by self.graph.train_iter.
    @type trainset: None, Dataset, str or file

    @param validset: Data for early stopping
    @type validset: None or Dataset

    @param iparams: initial parameters; if None, self.graph.iparams() is used
    @type iparams: None or list of arrays

    @param stp: early stopper, if None use self.graph.early_stopper()
    @type stp: None or early stopper

    @return: model (the best one seen on validset when validset is given)
    @rtype: GraphLearner.Model instance

    """
    iparams = self.graph.iparams() if iparams is None else iparams

    # A filename or an open file object means "load a saved model".
    if isinstance(trainset, (str, file)):
        return self.graph.load(self, trainset)

    curmodel = GraphLearner.Model(self, iparams)
    best = curmodel

    if trainset is not None:
        # do some training by calling Model.update_minibatch()
        if stp is None:
            stp = self.graph.early_stopper()
        countiter = 0
        for mb in self.graph.train_iter(trainset):
            curmodel.update_minibatch(mb)
            if stp.set_score:
                if validset:
                    stp.score = curmodel(validset, ['validset_score'])
                    if stp.score < stp.best_score:
                        # NOTE(review): copy.copy is shallow; unless Model
                        # defines __copy__, this snapshot may share the
                        # parameter arrays that training updates in place
                        # (graphMLP uses T.sub_inplace) -- verify.
                        best = copy.copy(curmodel)
                else:
                    stp.score = 0.0
            countiter += 1
            # Only the stopper's StopIteration means "stop training"; a
            # StopIteration raised by the update or validation code above
            # must not be mistaken for early stopping.
            try:
                stp.next()
            except StopIteration:
                print('Iterations stopped after  %s  iterations' % countiter)
                break
        if validset:
            curmodel = best
    return curmodel
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
214 |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
215 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
    """Build the Graph of a one-hidden-layer MLP for GraphLearner.

    The network computes tanh(input . W1 + b1) . W2 + b2 followed by a
    softmax / cross-entropy cost against integer class targets.  Each
    update is a plain gradient step with fixed learning rate lr_val.

    @param ninputs: number of input features
    @param nhid: number of hidden (tanh) units
    @param nclass: number of output classes
    @param lr_val: learning rate
    @param l2coef_val: l2 weight-decay coefficient; only 0.0 is supported
        (the regularization term is built but not used in the gradient)
    @return: an instance of the graph class G
    """

    class G(GraphLearner.Graph, AutoName):

        lr = T.constant(lr_val)
        # The regularization term below is not part of the gradient, so a
        # nonzero coefficient would be silently ignored; refuse it instead.
        assert l2coef_val == 0.0
        l2coef = T.constant(l2coef_val)
        input = T.matrix()   # n_examples x n_inputs
        target = T.ivector() # len: n_examples
        W2, b2 = T.matrix(), T.vector()

        W1, b1 = T.matrix(), T.vector()
        hid = T.tanh(b1 + T.dot(input, W1))
        hid_regularization = l2coef * T.sum(W1*W1)

        params = [W1, b1, W2, b2]
        activations = b2 + T.dot(hid, W2)
        nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
        regularization = l2coef * T.sum(W2*W2) + hid_regularization
        output_class = T.argmax(activations, 1)
        loss_01 = T.neq(output_class, target)
        # To enable weight decay the gradient would be taken of
        # nll + regularization; currently only nll is differentiated.
        g_params = T.grad(nll, params)
        # One SGD step, applied in place on the parameter inputs.
        new_params = [T.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]


        def __eq__(self,other) :
            print('G.__eq__ from graphMLP(), not implemented yet')
            return NotImplemented


        def load(self, algo, f):
            """ Load from file the 2 matrices and bias vectors.

            @param algo: the learner passed to GraphLearner.Model
            @param f: filename (opened and closed here) or open binary file
            @return: GraphLearner.Model(algo, [W1, b1, W2, b2])
            """
            close_at_end = False
            if isinstance(f, str):
                # filetensor data is binary; text mode corrupts it on
                # platforms that translate newlines.
                f = open(f, 'rb')
                close_at_end = True
            try:
                params = [filetensor.read(f) for _ in range(4)]
            finally:
                if close_at_end:
                    f.close()
            return GraphLearner.Model(algo, params)

        def save(self, f, model):
            """ Save params to file, so 2 matrices and 2 bias vectors. Same order as iparams.

            @param f: filename (opened and closed here) or open binary file
            @param model: model whose .params are written in order
            """
            close_at_end = False
            if isinstance(f, str):
                f = open(f, 'wb')
                close_at_end = True
            try:
                for p in model.params:
                    filetensor.write(f, p)
            finally:
                if close_at_end:
                    f.close()


        def iparams(self):
            """ init params: [W1, b1, W2, b2], uniform in [-0.0005, 0.0005). """
            def randsmall(*shape):
                return (numpy.random.rand(*shape) -0.5) * 0.001
            return [randsmall(ninputs, nhid)
                    , randsmall(nhid)
                    , randsmall(nhid, nclass)
                    , randsmall(nclass)]

        def train_iter(self, trainset):
            """ Minibatches of ('input', 'target') of size min(len(trainset), 32), 2000 batches. """
            return trainset.minibatches(['input', 'target'],
                    minibatch_size=min(len(trainset), 32), n_batches=2000)

        def early_stopper(self):
            """ Override of the Graph early-stopper hook.

            Returns stopper.NStages(300, 1); see the stopper module for the
            meaning of these arguments.
            """
            return stopper.NStages(300,1)

    return G()
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
300 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
301 |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
302 import unittest |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
303 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
304 class TestMLP(unittest.TestCase): |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
305 def blah(self, g): |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
306 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
307 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
308 [1, 0, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
309 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
310 {'input':slice(2),'target':2}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
311 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
312 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
313 [1, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
314 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
315 {'input':slice(2),'target':2}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
316 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
317 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
318 [1, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
319 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
320 {'input':slice(2)}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
321 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
322 learn_algo = GraphLearner(g) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
323 |
232 | 324 model1 = learn_algo(training_set1) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
325 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
326 model2 = learn_algo(training_set2) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
327 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
328 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data), |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
329 model2(test_data))] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
330 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
331 n_match = sum(omatch) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
332 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
333 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] == |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
334 training_set2.fields()['target'])), omatch) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
335 |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
336 model1.save('/tmp/model1') |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
337 |
265
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
338 #denoising_aa = GraphLearner(denoising_g) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
339 #model1 = denoising_aa(trainset) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
340 #hidset = model(trainset, fieldnames=['hidden']) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
341 #model2 = denoising_aa(hidset) |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
342 |
265
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
343 #f = open('blah', 'w') |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
344 #for m in model: |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
345 # m.save(f) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
346 #filetensor.write(f, initial_classification_weights) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
347 #f.flush() |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
348 |
265
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
349 #deep_sigmoid_net = GraphLearner(deepnetwork_g) |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
350 #deep_model = deep_sigmoid_net.load('blah') |
ae0a8345869b
commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
264
diff
changeset
|
351 #deep_model.update(trainset) #do some fine tuning |
264
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
352 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
353 model1_dup = learn_algo('/tmp/model1') |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
354 |
a1793a5e9523
we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
244
diff
changeset
|
355 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def equiv(self, g0, g1, tol=1.0e-7, strict=False):
    """Train one GraphLearner per graph on identical data and compare params.

    Both graphs are trained on the same 4-row dataset. Columns 0-1 are the
    'input' field and column 2 is the 'target' field; in this data the target
    is the logical OR of the two inputs. The resulting ``params`` of the two
    models are then compared pairwise using Theano's absolute-relative-error
    helper.

    :param g0: graph passed to the first ``GraphLearner``.
    :param g1: graph passed to the second ``GraphLearner``.
    :param tol: maximum allowed absolute relative error per parameter pair.
        The default 1.0e-7 is the threshold the original code hard-coded.
    :param strict: when True, fail the test (via ``self.failUnless``) if any
        parameter pair exceeds ``tol``. Defaults to False, which reproduces
        the historical behaviour: mismatches are only printed.
    """
    training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                      [0, 1, 1],
                                                      [1, 0, 1],
                                                      [1, 1, 1]]),
                                         {'input': slice(2), 'target': 2})
    learn_algo_0 = GraphLearner(g0)
    learn_algo_1 = GraphLearner(g1)

    model_0 = learn_algo_0(training_set1)
    model_1 = learn_algo_1(training_set1)

    # Single-argument print(...) yields the same output under Python 2's
    # print statement and Python 3's print function.
    print('----')
    for p0, p1 in zip(model_0.params, model_1.params):
        abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p0, p1)
        max_abs_rel_err = numpy.max(abs_rel_err)
        if max_abs_rel_err > tol:
            # Same text as the old `print 'p0', p[0]` (str() of the value).
            print('p0 %s' % (p0,))
            print('p1 %s' % (p1,))
        # NOTE(review): this check was commented out historically, presumably
        # because the two trainings do not always agree to 1e-7 (reason not
        # visible here). It is kept opt-in so existing callers are unaffected.
        if strict:
            self.failUnless(max_abs_rel_err < tol, max_abs_rel_err)
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
376 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
377 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def test0(self):
    """Run the shared ``blah`` check against ``graphMLP(2, 10, 2, .1)``."""
    graph = graphMLP(2, 10, 2, .1)
    self.blah(graph)
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
def test1(self):
    """Run the shared ``blah`` check against ``graphMLP(2, 3, 2, .1)``."""
    graph = graphMLP(2, 3, 2, .1)
    self.blah(graph)
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
380 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
if __name__ == '__main__':
    # Only when executed as a script (not on import): hand control to
    # unittest's command-line runner for the test cases in this module.
    unittest.main()
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
383 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
384 |