Mercurial > pylearn
annotate mlp_factory_approach.py @ 248:82ba488b2c24
polished filetensor a little
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:14:45 -0400 |
parents | 3156a9976183 |
children | a1793a5e9523 |
rev | line source |
---|---|
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
1 import copy, sys |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
2 import numpy |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
3 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 import theano |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
5 from theano import tensor as T |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
6 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
7 from pylearn import dataset, nnet_ops, stopper, LookupList |
211
bd728c83faff
in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents:
208
diff
changeset
|
8 |
class AbstractFunction(Exception):
    """Signal that a method must be supplied by a subclass.

    ``GraphLearner.Graph.train_iter`` raises this instead of doing any work,
    so concrete graphs are expected to override it.
    """
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
10 |
class _AutoNameMeta(type):
    """Metaclass behind AutoName.

    After a class is created, every object bound in its class body that has a
    ``name`` attribute gets that attribute set to the variable name it was
    bound to.
    """
    def __init__(cls, name, bases, dct):
        # ``cls`` must be forwarded to type.__init__; calling it with only
        # (name, bases, dct) raises TypeError because ``name`` is not a type.
        super(_AutoNameMeta, cls).__init__(name, bases, dct)
        for key, val in dct.items():
            assert type(key) is str
            if hasattr(val, 'name'):
                val.name = key


# AutoName is built by calling the metaclass directly: this applies
# _AutoNameMeta on both Python 2 and Python 3, whereas a nested
# ``__metaclass__`` attribute only takes effect on Python 2.  The
# ``__metaclass__`` entry is kept so ``AutoName.__metaclass__`` still resolves.
AutoName = _AutoNameMeta('AutoName', (object,), {
    '__doc__': """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.
    """,
    '__metaclass__': _AutoNameMeta,
})
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
23 |
class GraphLearner(object):
    """Learning algorithm built around a symbolic theano graph.

    An instance wraps a ``GraphLearner.Graph``; calling the instance allocates
    a ``GraphLearner.Model`` (numeric parameter values for that graph) and,
    when a training set is given, trains it minibatch by minibatch.
    """

    class Model(object):
        """Parameter values bound to the GraphLearner that created them.

        @ivar algo: the GraphLearner whose graph this model evaluates
        @ivar params: list of parameter values, passed to compiled functions
            in the same positions as ``algo.graph.params``
        """

        def __init__(self, algo, params):
            self.algo = algo
            self.params = params
            graph = self.algo.graph
            # Maps (input, target, *params) -> [nll] + new_params.
            self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                    [graph.nll] + graph.new_params)
            # Compiled functions used by __call__, keyed by (inputs, outputs).
            self._fn_cache = {}

        def __copy__(self):
            """Return a new Model for the same algo with each param copied.

            Used by GraphLearner.__call__ to snapshot the best model seen so
            far.  Each parameter is copied with copy.copy (for numpy arrays
            this yields an independent buffer; other types: confirm).
            """
            return GraphLearner.Model(self.algo,
                    [copy.copy(p) for p in self.params])

        def _cache(self, key, valfn):
            """Return self._fn_cache[key], computing it with valfn() on a miss."""
            d = self._fn_cache
            if key not in d:
                d[key] = valfn()
            return d[key]

        def update_minibatch(self, minibatch):
            """Run one call of the compiled update function on `minibatch`.

            @param minibatch: must provide 'input' and 'target' fields
            @type minibatch: LookupList
            """
            assert isinstance(minibatch, LookupList)
            # NOTE(review): the outputs ([nll] + new_params) are discarded, so
            # self.params only change if the graph updates them in place --
            # confirm against the concrete Graph.
            self.update_fn(minibatch['input'], minibatch['target'], *self.params)

        def update(self, dataset,
                default_minibatch_size=32):
            """Update this model from more training data.

            @param dataset: DataSet providing 'input' and 'target' fields
            @param default_minibatch_size: requested minibatch size, capped at
                len(dataset)
            """
            minibatch_size = min(default_minibatch_size, len(dataset))
            for mb in dataset.minibatches(['input', 'target'], minibatch_size=minibatch_size):
                self.update_minibatch(mb)

        def __call__(self, testset, fieldnames=None):
            """Apply this model (as a function) to new data.

            @param testset: DataSet, whose fields feed Result terms in self.algo.graph
            @type testset: DataSet

            @param fieldnames: names of results in self.algo.graph to compute;
                defaults to ['output_class']
            @type fieldnames: list of strings

            @return: DataSet with fields from fieldnames, computed from testset by
            this model.
            @rtype: ApplyFunctionDataSet instance

            @raise TypeError: if a field name refers to a graph attribute that
                is not a theano.Result
            """
            if fieldnames is None:
                fieldnames = ['output_class']
            graph = self.algo.graph

            def getresult(name):
                r = getattr(graph, name)
                if not isinstance(r, theano.Result):
                    raise TypeError('string does not name a theano.Result', (name, r))
                return r

            provided = [getresult(name) for name in testset.fieldNames()]
            wanted = [getresult(name) for name in fieldnames]
            inputs = provided + graph.params

            theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                    lambda: self.algo._fn(inputs, wanted))
            # The compiled function takes the parameters as trailing inputs,
            # so this model's current values are appended on every call.
            lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
            return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)

    class Graph(object):
        """Base class for the symbolic graph compiled by a GraphLearner.

        Concrete subclasses are expected to provide the attributes used by
        GraphLearner and its Model (input, target, nll, params, new_params,
        iparams()) and to implement train_iter.
        """

        class Opt(object):
            """Graph optimizer: merge, gemm rewriting, optional x*x -> sqr(x)."""
            merge = theano.gof.MergeOptimizer()
            gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
            sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
                    (T.mul, 'x', 'x'),
                    (T.sqr, 'x')))

            def __init__(self, do_sqr=True):
                self.do_sqr = do_sqr

            def __call__(self, env):
                self.merge(env)
                self.gemm_opt_1(env)
                if self.do_sqr:
                    self.sqr_opt_0(env)
                self.merge(env)

        def linker(self):
            """Return the theano linker used by GraphLearner._fn."""
            return theano.gof.PerformLinker()

        def early_stopper(self):
            """Return the stopper that drives GraphLearner.__call__'s loop.

            The arguments (10, 1) are passed straight to stopper.NStages; see
            that class for their meaning.
            """
            # The value must be returned: GraphLearner.__call__ reads
            # attributes from it.
            return stopper.NStages(10, 1)

        def train_iter(self, trainset):
            """Yield training minibatches from `trainset` (subclass hook)."""
            raise AbstractFunction

        optimizer = Opt()

    def __init__(self, graph):
        self.graph = graph

    def _fn(self, inputs, outputs):
        """Compile a theano function from `inputs` to `outputs` for self.graph."""
        # Caching here would hamper multi-threaded apps
        # prefer caching in Model.__call__
        return theano.function(inputs, outputs,
                unpack_single=False,
                optimizer=self.graph.optimizer,
                linker=self.graph.linker() if hasattr(self.graph, 'linker')
                else 'c&py')

    def __call__(self,
            trainset=None,
            validset=None,
            iparams=None):
        """Allocate and optionally train a model

        @param trainset: Data for minimizing the cost function
        @type trainset: None or Dataset

        @param validset: Data for early stopping
        @type validset: None or Dataset

        @param iparams: initial parameter values; when None,
            self.graph.iparams() is called to create them
        @type iparams: None or list

        @return: model
        @rtype: GraphLearner.Model instance

        """
        iparams = self.graph.iparams() if iparams is None else iparams
        curmodel = GraphLearner.Model(self, iparams)
        best = curmodel

        if trainset is not None:
            #do some training by calling Model.update_minibatch()
            stp = self.graph.early_stopper()
            for mb in self.graph.train_iter(trainset):
                curmodel.update_minibatch(mb)
                if stp.set_score:
                    if validset is not None:
                        # NOTE(review): Model.__call__ returns an
                        # ApplyFunctionDataSet, not a number, so this score
                        # comparison is probably not what was intended -- confirm.
                        stp.score = curmodel(validset, ['validset_score'])
                        if (stp.score < stp.best_score):
                            best = copy.copy(curmodel)
                    else:
                        stp.score = 0.0
                # Assumes the stopper follows the iterator protocol and raises
                # StopIteration when training should end; uncaught, that
                # exception would escape from this method.
                try:
                    stp.next()
                except StopIteration:
                    break
            if validset is not None:
                curmodel = best
        return curmodel
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
168 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
169 def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
170 def wrapper(i, node, thunk): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
171 if 0: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
172 print i, node |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
173 print thunk.inputs |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
174 print thunk.outputs |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
175 if node.op == nnet_ops.crossentropy_softmax_1hot_with_bias: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
176 print 'here is the nll op' |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
177 thunk() #actually compute this piece of the graph |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
178 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
179 class G(GraphLearner.Graph, AutoName): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
180 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
181 lr = T.constant(lr_val) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
182 assert l2coef_val == 0.0 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
183 l2coef = T.constant(l2coef_val) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
184 input = T.matrix() # n_examples x n_inputs |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
185 target = T.ivector() # len: n_examples |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
186 W2, b2 = T.matrix(), T.vector() |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
187 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
188 W1, b1 = T.matrix(), T.vector() |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
189 hid = T.tanh(b1 + T.dot(input, W1)) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
190 hid_regularization = l2coef * T.sum(W1*W1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
191 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
192 params = [W1, b1, W2, b2] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
193 activations = b2 + T.dot(hid, W2) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
194 nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
195 regularization = l2coef * T.sum(W2*W2) + hid_regularization |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
196 output_class = T.argmax(activations,1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
197 loss_01 = T.neq(output_class, target) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
198 #g_params = T.grad(nll + regularization, params) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
199 g_params = T.grad(nll, params) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
200 new_params = [T.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
201 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
202 def iparams(self): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
203 def randsmall(*shape): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
204 return (numpy.random.rand(*shape) -0.5) * 0.001 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
205 return [randsmall(ninputs, nhid) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
206 , randsmall(nhid) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
207 , randsmall(nhid, nclass) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
208 , randsmall(nclass)] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
209 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
210 def train_iter(self, trainset): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
211 return trainset.minibatches(['input', 'target'], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
212 minibatch_size=min(len(trainset), 32), n_batches=300) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
213 def early_stopper(self): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
214 return stopper.NStages(300,1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
215 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
216 return G() |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
217 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
218 |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
219 import unittest |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
220 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
221 class TestMLP(unittest.TestCase): |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
222 def blah(self, g): |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
223 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
224 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
225 [1, 0, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
226 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
227 {'input':slice(2),'target':2}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
228 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
229 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
230 [1, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
231 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
232 {'input':slice(2),'target':2}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
233 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
234 [0, 1, 1], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
235 [1, 0, 0], |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
236 [1, 1, 1]]), |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
237 {'input':slice(2)}) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
238 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
239 learn_algo = GraphLearner(g) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
240 |
232 | 241 model1 = learn_algo(training_set1) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
242 |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
243 model2 = learn_algo(training_set2) |
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
244 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
245 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data), |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
246 model2(test_data))] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
247 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
248 n_match = sum(omatch) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
249 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
250 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] == |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
251 training_set2.fields()['target'])), omatch) |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
252 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
253 def equiv(self, g0, g1): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
254 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
255 [0, 1, 1], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
256 [1, 0, 1], |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
257 [1, 1, 1]]), |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
258 {'input':slice(2),'target':2}) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
259 learn_algo_0 = GraphLearner(g0) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
260 learn_algo_1 = GraphLearner(g1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
261 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
262 model_0 = learn_algo_0(training_set1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
263 model_1 = learn_algo_1(training_set1) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
264 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
265 print '----' |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
266 for p in zip(model_0.params, model_1.params): |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
267 abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p[0], p[1]) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
268 max_abs_rel_err = numpy.max(abs_rel_err) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
269 if max_abs_rel_err > 1.0e-7: |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
270 print 'p0', p[0] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
271 print 'p1', p[1] |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
272 #self.failUnless(max_abs_rel_err < 1.0e-7, max_abs_rel_err) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
273 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
274 |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
275 def test0(self): self.blah(graphMLP(2, 10, 2, .1)) |
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
276 def test1(self): self.blah(graphMLP(2, 3, 2, .1)) |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
277 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
278 if __name__ == '__main__': |
208
bf320808919f
back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents:
207
diff
changeset
|
279 unittest.main() |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
280 |
244
3156a9976183
mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
232
diff
changeset
|
281 |