annotate mlp_factory_approach.py @ 251:7e6edee187e3

optimization of CachedDataSet__getitem__
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 12:25:53 -0400
parents 3156a9976183
children a1793a5e9523
rev   line source
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
1 import copy, sys
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
2 import numpy
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
3
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 import theano
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
5 from theano import tensor as T
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
7 from pylearn import dataset, nnet_ops, stopper, LookupList
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
8
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AbstractFunction(Exception):
    """Raised by a method that a subclass is expected to override."""
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
10
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class _AutoNameMeta(type):
    """Metaclass behind AutoName: names class variables after themselves.

    For every entry of a newly created class's namespace whose value has a
    `name` attribute, that attribute is set to the entry's key.
    """

    def __init__(cls, name, bases, dct):
        # type.__init__ must receive the class object itself as its first
        # argument; passing only (name, bases, dct) binds the name string.
        type.__init__(cls, name, bases, dct)
        for key, val in dct.items():
            assert type(key) is str
            if hasattr(val, 'name'):
                val.name = key


# Instantiating the metaclass directly (rather than relying on a nested
# __metaclass__ attribute, which Python 3 ignores) gives AutoName and every
# subclass this metaclass on both Python 2 and Python 3.
AutoName = _AutoNameMeta('AutoName', (object,), {
    '__doc__': """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.
    """,
})
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
23
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class GraphLearner(object):
    """Learning algorithm built around a theano graph description.

    `graph` is expected to expose the theano Results `input`, `target`,
    `params`, `new_params` and `nll`, plus `optimizer`, `early_stopper()`,
    `train_iter(trainset)`, optionally `linker()`, and `iparams()` (used when
    no initial parameters are passed).  Calling the learner allocates, and
    optionally trains, a GraphLearner.Model.
    """

    class Model(object):
        """Parameter values for a GraphLearner plus compiled theano functions."""

        def __init__(self, algo, params):
            self.algo = algo
            self.params = params
            graph = self.algo.graph
            # Compiled with inputs (input, target, *params) and outputs
            # [nll] + new_params.
            self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                                      [graph.nll] + graph.new_params)
            # (inputs tuple, outputs tuple) -> compiled function; see _cache.
            self._fn_cache = {}

        def __copy__(self):
            """Return a new Model sharing `algo`, with a shallow copy of each param.

            Constructing the new Model recompiles its update_fn.
            """
            return GraphLearner.Model(self.algo,
                                      [copy.copy(p) for p in self.params])

        def _cache(self, key, valfn):
            """Return self._fn_cache[key], computing it with valfn() on a miss."""
            d = self._fn_cache
            if key not in d:
                d[key] = valfn()
            return d[key]

        def update_minibatch(self, minibatch):
            """Run update_fn on a LookupList minibatch with 'input' and 'target' fields."""
            assert isinstance(minibatch, LookupList)
            # The outputs ([nll] + new_params) are discarded here.
            # NOTE(review): this presumably relies on graph.new_params updating
            # the parameter storage in place -- confirm against the graphs used.
            self.update_fn(minibatch['input'], minibatch['target'], *self.params)

        def update(self, dataset,
                   default_minibatch_size=32):
            """Update this model from more training data.

            @param dataset: DataSet providing 'input' and 'target' fields
            @param default_minibatch_size: largest minibatch size to use; it is
                clipped to len(dataset)
            """
            minibatch_size = min(default_minibatch_size, len(dataset))
            for mb in dataset.minibatches(['input', 'target'],
                                          minibatch_size=minibatch_size):
                self.update_minibatch(mb)

        def __call__(self, testset, fieldnames=None):
            """Apply this model (as a function) to new data.

            @param testset: DataSet, whose fields feed Result terms in
                self.algo.graph
            @type testset: DataSet

            @param fieldnames: names of results in self.algo.graph to compute
                (default: ['output_class']).
            @type fieldnames: list of strings

            @return: DataSet with fields from fieldnames, computed from testset by
                this model.
            @rtype: ApplyFunctionDataSet instance

            @raise TypeError: if a graph attribute named by a testset field or by
                fieldnames is not a theano.Result.
            """
            # None sentinel instead of a shared mutable default list.
            if fieldnames is None:
                fieldnames = ['output_class']
            graph = self.algo.graph

            def getresult(name):
                r = getattr(graph, name)
                if not isinstance(r, theano.Result):
                    raise TypeError('string does not name a theano.Result', (name, r))
                return r

            provided = [getresult(name) for name in testset.fieldNames()]
            wanted = [getresult(name) for name in fieldnames]
            inputs = provided + graph.params

            theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                                    lambda: self.algo._fn(inputs, wanted))
            # self.params is read each time lambda_fn is called, not when it
            # is built.
            lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
            return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)

    class Graph(object):
        """Base class for graph descriptions consumed by GraphLearner.

        Subclasses must override train_iter() and provide the Results and
        iparams() that GraphLearner reads.
        """

        class Opt(object):
            """Optimizer: merge, gemm substitution, optional x*x -> sqr(x)."""
            merge = theano.gof.MergeOptimizer()
            gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
            sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
                (T.mul, 'x', 'x'),
                (T.sqr, 'x')))

            def __init__(self, do_sqr=True):
                self.do_sqr = do_sqr

            def __call__(self, env):
                self.merge(env)
                self.gemm_opt_1(env)
                if self.do_sqr:
                    self.sqr_opt_0(env)
                # Final merge pass over nodes created by the substitutions.
                self.merge(env)

        def linker(self):
            """Return the linker used to compile this graph's functions."""
            return theano.gof.PerformLinker()

        def early_stopper(self):
            """Return the early-stopping controller used by GraphLearner.__call__."""
            # The original built the stopper but did not return it, so callers
            # received None.
            return stopper.NStages(10, 1)

        def train_iter(self, trainset):
            """Yield training minibatches from trainset; must be overridden."""
            raise AbstractFunction

        optimizer = Opt()

    def __init__(self, graph):
        self.graph = graph

    def _fn(self, inputs, outputs):
        """Compile a theano function from inputs to outputs.

        Uses self.graph.optimizer, and self.graph.linker() when the graph
        defines one ('c&py' otherwise).  Outputs are always returned as a list
        (unpack_single=False).
        """
        # Caching here would hamper multi-threaded apps
        # prefer caching in Model.__call__
        return theano.function(inputs, outputs,
                               unpack_single=False,
                               optimizer=self.graph.optimizer,
                               linker=self.graph.linker() if hasattr(self.graph, 'linker')
                               else 'c&py')

    def __call__(self,
                 trainset=None,
                 validset=None,
                 iparams=None):
        """Allocate and optionally train a model

        @param trainset: Data for minimizing the cost function
        @type trainset: None or Dataset

        @param validset: Data for early stopping
        @type validset: None or Dataset

        @param iparams: initial parameter values; self.graph.iparams() is used
            when None
        @type iparams: None or list

        @return: model
        @rtype: GraphLearner.Model instance

        """
        iparams = self.graph.iparams() if iparams is None else iparams
        curmodel = GraphLearner.Model(self, iparams)
        best = curmodel

        if trainset is not None:
            #do some training by calling Model.update_minibatch()
            stp = self.graph.early_stopper()
            for mb in self.graph.train_iter(trainset):
                curmodel.update_minibatch(mb)
                if stp.set_score:
                    if validset:
                        # NOTE(review): Model.__call__ returns a DataSet, so
                        # stp.score is a DataSet rather than a number here;
                        # confirm what the stopper expects before relying on
                        # the comparison below.
                        stp.score = curmodel(validset, ['validset_score'])
                        if (stp.score < stp.best_score):
                            best = copy.copy(curmodel)
                    else:
                        stp.score = 0.0
                stp.next()
            if validset:
                curmodel = best
        return curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
168
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
    """Return a GraphLearner graph for a one-hidden-layer tanh MLP classifier.

    The model computes ``hid = tanh(b1 + input . W1)`` and
    ``activations = b2 + hid . W2``.  ``nll`` is the softmax cross-entropy
    of ``activations`` against the integer ``target`` vector, and
    ``new_params`` holds the in-place SGD step ``p - lr * d(nll)/dp`` for
    every parameter in ``params``.

    A new class ``G`` (subclassing ``GraphLearner.Graph`` and ``AutoName``)
    is created on every call and an instance of it is returned; the theano
    variables of the model are class attributes of ``G``.

    :param ninputs: number of input features (rows of ``W1``).
    :param nhid: number of hidden tanh units.
    :param nclass: number of output classes (columns of ``W2``).
    :param lr_val: learning rate, wrapped in a theano constant.
    :param l2coef_val: L2 weight-decay coefficient.  Only ``0.0`` is
        supported: the regularization expression is built, but the gradient
        is taken of ``nll`` alone.
    :raises ValueError: if ``l2coef_val`` is non-zero.
    """
    # Explicit check instead of ``assert`` so it is not stripped by
    # ``python -O``: a non-zero coefficient would otherwise be silently
    # ignored, because the L2 term is not part of ``g_params`` below.
    if l2coef_val != 0.0:
        raise ValueError('graphMLP: only l2coef_val == 0.0 is supported, '
                         'got %r' % (l2coef_val,))

    class G(GraphLearner.Graph, AutoName):

        lr = T.constant(lr_val)
        l2coef = T.constant(l2coef_val)
        input = T.matrix()  # n_examples x n_inputs
        target = T.ivector()  # len: n_examples
        W2, b2 = T.matrix(), T.vector()

        W1, b1 = T.matrix(), T.vector()
        hid = T.tanh(b1 + T.dot(input, W1))
        hid_regularization = l2coef * T.sum(W1 * W1)

        params = [W1, b1, W2, b2]
        activations = b2 + T.dot(hid, W2)
        nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
        regularization = l2coef * T.sum(W2 * W2) + hid_regularization
        output_class = T.argmax(activations, 1)
        loss_01 = T.neq(output_class, target)
        # Only nll is differentiated; ``regularization`` is not included
        # (T.grad(nll + regularization, params) would add the L2 term).
        g_params = T.grad(nll, params)
        # NOTE: relies on Python 2 scoping -- this comprehension runs in the
        # class body, so it can read ``lr`` (and leaks ``p``/``gp`` as class
        # attributes).  Under Python 3 ``lr`` would not be visible here.
        new_params = [T.sub_inplace(p, lr * gp) for p, gp in zip(params, g_params)]

        def iparams(self):
            """Return small random initial values, in the order of ``params``."""
            def randsmall(*shape):
                # Uniform samples in [-5e-4, 5e-4).
                return (numpy.random.rand(*shape) - 0.5) * 0.001
            return [randsmall(ninputs, nhid),
                    randsmall(nhid),
                    randsmall(nhid, nclass),
                    randsmall(nclass)]

        def train_iter(self, trainset):
            """Return minibatches of ('input', 'target') with at most 32 rows
            each, requesting 300 batches."""
            return trainset.minibatches(['input', 'target'],
                                        minibatch_size=min(len(trainset), 32),
                                        n_batches=300)

        def early_stopper(self):
            """Return the stopping criterion used while training."""
            return stopper.NStages(300, 1)

    return G()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
217
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
218
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
219 import unittest
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
220
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
class TestMLP(unittest.TestCase):
    """Smoke tests that train ``graphMLP`` graphs through ``GraphLearner``."""

    # Each row is (x0, x1, target): columns 0-1 form the 'input' field and
    # column 2 is the 'target' field.
    # Here target == x0 OR x1.
    _ROWS_OR = ((0, 0, 0),
                (0, 1, 1),
                (1, 0, 1),
                (1, 1, 1))
    # Here target == x1; differs from _ROWS_OR only on the (1, 0) row.
    _ROWS_X1 = ((0, 0, 0),
                (0, 1, 1),
                (1, 0, 0),
                (1, 1, 1))

    def _dataset(self, rows, with_target=True):
        """Wrap ``rows`` in an ArrayDataSet exposing 'input' (columns 0-1)
        and, when ``with_target`` is true, 'target' (column 2)."""
        fields = {'input': slice(2)}
        if with_target:
            fields['target'] = 2
        return dataset.ArrayDataSet(numpy.array(rows), fields)

    def blah(self, g):
        """Train ``g`` on two datasets whose targets differ on one row and
        check that the number of matching predictions on the shared inputs
        equals the number of rows on which the training targets agree."""
        training_set1 = self._dataset(self._ROWS_OR)
        training_set2 = self._dataset(self._ROWS_X1)
        # No 'target' field: only inputs are fed to the trained models.
        test_data = self._dataset(self._ROWS_X1, with_target=False)

        learn_algo = GraphLearner(g)

        model1 = learn_algo(training_set1)
        model2 = learn_algo(training_set2)

        # NOTE(review): assumes iterating a model's output yields one value
        # per test row -- confirm against GraphLearner.
        omatch = [o1 == o2 for o1, o2 in zip(model1(test_data),
                                              model2(test_data))]
        n_match = sum(omatch)

        n_same_target = numpy.sum(training_set1.fields()['target'] ==
                                  training_set2.fields()['target'])
        self.assertTrue(n_match == n_same_target, omatch)

    def equiv(self, g0, g1):
        """Train ``g0`` and ``g1`` on the same data and print every parameter
        pair whose maximum abs-relative error exceeds 1e-7.

        Diagnostic only: the tolerance assertion is disabled, so this never
        fails on a mismatch.
        """
        training_set1 = self._dataset(self._ROWS_OR)
        learn_algo_0 = GraphLearner(g0)
        learn_algo_1 = GraphLearner(g1)

        model_0 = learn_algo_0(training_set1)
        model_1 = learn_algo_1(training_set1)

        print('----')
        for p0, p1 in zip(model_0.params, model_1.params):
            abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p0, p1)
            max_abs_rel_err = numpy.max(abs_rel_err)
            if max_abs_rel_err > 1.0e-7:
                print('p0 %s' % (p0,))
                print('p1 %s' % (p1,))
            # TODO: enable once the graphs are expected to match exactly:
            # self.assertTrue(max_abs_rel_err < 1.0e-7, max_abs_rel_err)

    def test0(self):
        self.blah(graphMLP(2, 10, 2, .1))

    def test1(self):
        self.blah(graphMLP(2, 3, 2, .1))
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
277
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
if __name__ == '__main__':
    # Run as a script: execute the TestCase classes defined in this module.
    unittest.main(module='__main__')
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
280
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
281