annotate mlp_factory_approach.py @ 304:6ead65d30f1e

while learning using __call__, we can now set the early stopper
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Tue, 10 Jun 2008 17:16:49 -0400
parents eded3cb54930
children 93280a0c151a
rev   line source
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
1 import copy, sys, os
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
2 import numpy
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
3
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 import theano
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
5 from theano import tensor as T
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6
299
eded3cb54930 small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 265
diff changeset
7 import dataset, nnet_ops, stopper, filetensor
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
8 from pylearn.lookup_list import LookupList
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
9
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
10
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AbstractFunction(Exception):
    """Exception raised by methods that are declared but not implemented."""
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
12
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AutoName(object):
    """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.

    NOTE: this relies on the Python 2 convention of declaring the metaclass
    as a ``__metaclass__`` attribute in the class body.
    """
    class __metaclass__(type):
        def __init__(cls, name, bases, dct):
            """Finish creating `cls`, then name its class-level attributes.

            @param name: name of the class being created
            @param bases: tuple of base classes
            @param dct: class namespace (attribute name -> value)
            """
            # `cls` must be passed explicitly: this is an unbound call on
            # `type`, so without it `name` would be taken as the instance.
            type.__init__(cls, name, bases, dct)
            for key, val in dct.items():
                assert type(key) is str
                # Only objects that already expose a `name` attribute are
                # renamed; everything else is left untouched.
                if hasattr(val, 'name'):
                    val.name = key
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
25
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
26 class GraphLearner(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class Model(object):
    """Parameter values paired with the learning algorithm that uses them.

    Referred to in this file as ``GraphLearner.Model``.  ``algo`` supplies
    the symbolic ``graph`` and the ``_fn`` compiler; ``params`` holds the
    values that are passed positionally, after the data inputs, to every
    compiled function.
    """

    def __init__(self, algo, params):
        """Bind the model to `algo` and compile its update function.

        @param algo: object providing ``graph`` and ``_fn(inputs, outputs)``
        @param params: values passed in place of ``algo.graph.params``
        """
        self.algo = algo
        self.params = params
        graph = self.algo.graph
        # Inputs: input, target, then the graph's params.
        # Outputs: graph.nll followed by graph.new_params.
        self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                                  [graph.nll] + graph.new_params)
        # Functions compiled by __call__, keyed on (inputs, outputs).
        self._fn_cache = {}

    def __copy__(self):
        # NOTE(review): the raise is kept on purpose -- it looks like a
        # debugging trap left by the author, and removing it would change
        # what copy.copy(model) does.  The return below is therefore still
        # unreachable; it now reads self.params instead of the undefined
        # name `params`, so it is correct if the raise is ever removed.
        raise Exception('why not called?')
        return GraphLearner.Model(self.algo,
                                  [copy.copy(p) for p in self.params])

    def __eq__(self, other, tolerance=0.):
        """Compare the parameter arrays of two models element-wise.

        @param other: object to compare against
        @param tolerance: largest absolute difference still considered equal
        @return: True when `other` is a GraphLearner.Model with the same
            number of params, each of matching shape and within `tolerance`
        """
        if not isinstance(other, GraphLearner.Model):
            return False
        # Compare every parameter rather than assuming there are exactly 4.
        if len(self.params) != len(other.params):
            return False
        for mine, theirs in zip(self.params, other.params):
            if mine.shape != theirs.shape:
                return False
            if not numpy.all(numpy.abs(mine - theirs) <= tolerance):
                return False
        return True

    def _cache(self, key, valfn):
        """Return the cached value for `key`, computing it once via `valfn`.

        @param key: hashable cache key
        @param valfn: zero-argument callable producing the value on a miss
        """
        d = self._fn_cache
        if key not in d:
            d[key] = valfn()
        return d[key]

    def update_minibatch(self, minibatch):
        """Run the compiled update function on a single minibatch.

        @param minibatch: LookupList with 'input' and 'target' fields
        """
        if not isinstance(minibatch, LookupList):
            # Show the offending type before the assertion fires.
            print(type(minibatch))
        assert isinstance(minibatch, LookupList)
        # The function's return value is discarded here.
        self.update_fn(minibatch['input'], minibatch['target'], *self.params)

    def update(self, dataset,
               default_minibatch_size=32):
        """Update this model from more training data.

        @param dataset: object supporting ``len()`` and
            ``minibatches(fieldnames, minibatch_size=...)``
        @param default_minibatch_size: upper bound on the minibatch size;
            the dataset length is used instead when it is smaller
        """
        minibatch_size = min(default_minibatch_size, len(dataset))
        for mb in dataset.minibatches(['input', 'target'],
                                      minibatch_size=minibatch_size):
            self.update_minibatch(mb)

    def save(self, f):
        """Delegate saving of this model to ``self.algo.graph.save``.

        @param f: forwarded unchanged to the graph's ``save(f, model)``
        """
        self.algo.graph.save(f, self)

    def __call__(self, testset, fieldnames=None):
        """Apply this model (as a function) to new data.

        @param testset: DataSet, whose fields feed Result terms in
            self.algo.graph
        @type testset: DataSet

        @param fieldnames: names of results in self.algo.graph to compute;
            None (the default) means ['output_class']
        @type fieldnames: None or list of strings

        @return: DataSet with fields from fieldnames, computed from testset
            by this model.
        @rtype: ApplyFunctionDataSet instance

        @raise TypeError: if a field name does not name a theano.Result
            attribute of the graph
        """
        # Build the default per call so no list is shared between calls.
        if fieldnames is None:
            fieldnames = ['output_class']
        graph = self.algo.graph

        def getresult(name):
            r = getattr(graph, name)
            if not isinstance(r, theano.Result):
                raise TypeError('string does not name a theano.Result',
                                (name, r))
            return r

        provided = [getresult(name) for name in testset.fieldNames()]
        wanted = [getresult(name) for name in fieldnames]
        inputs = provided + graph.params

        # Compile once per distinct (inputs, outputs) combination.
        theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                                lambda: self.algo._fn(inputs, wanted))
        # The compiled function expects the params after the data fields.
        lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
        return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
103
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class Graph(object):
    """Base description of a learning graph and the tools used to compile it.

    Provides a default optimizer, linker and early stopper.  ``train_iter``,
    ``load`` and ``save`` raise AbstractFunction here.
    """

    class Opt(object):
        """Callable that applies a fixed sequence of optimizers to an env."""

        merge = theano.gof.MergeOptimizer()
        gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
        # Pattern substitution: mul(x, x) -> sqr(x).
        sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
            (T.mul, 'x', 'x'),
            (T.sqr, 'x')))

        def __init__(self, do_sqr=True):
            """@param do_sqr: whether __call__ applies ``sqr_opt_0``."""
            self.do_sqr = do_sqr

        def __call__(self, env):
            """Apply merge, gemm, optionally sqr, then merge again to `env`."""
            self.merge(env)
            self.gemm_opt_1(env)
            if self.do_sqr:
                self.sqr_opt_0(env)
            self.merge(env)

    # Default optimizer instance, read by the learner as `graph.optimizer`.
    optimizer = Opt()

    def linker(self):
        """Return a new ``theano.gof.PerformLinker``."""
        return theano.gof.PerformLinker()

    def early_stopper(self):
        """Return the default early stopper, ``stopper.NStages(300, 1)``."""
        # The stopper was previously built but never returned, so callers
        # received None.
        return stopper.NStages(300, 1)

    def train_iter(self, trainset):
        """Not implemented here; raises AbstractFunction."""
        raise AbstractFunction

    def load(self, f):
        """Not implemented here; raises AbstractFunction."""
        raise AbstractFunction

    def save(self, f, model):
        """Not implemented here; raises AbstractFunction."""
        raise AbstractFunction
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
137
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
138
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def __init__(self, graph):
    """Store the graph description that this learner compiles and trains.

    @param graph: object kept as ``self.graph``
    """
    self.graph = graph
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
141
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
def _fn(self, inputs, outputs):
    """Compile a theano function from `inputs` to `outputs`.

    The optimizer is taken from ``self.graph.optimizer``.  The linker comes
    from ``self.graph.linker()`` when the graph has a ``linker`` attribute,
    and is the string 'c|py' otherwise.

    @param inputs: list of symbolic inputs
    @param outputs: list of symbolic outputs
    @return: the result of ``theano.function``
    """
    # Caching here would hamper multi-threaded apps
    # prefer caching in Model.__call__
    graph_optimizer = self.graph.optimizer
    if hasattr(self.graph, 'linker'):
        graph_linker = self.graph.linker()
    else:
        graph_linker = 'c|py'
    return theano.function(inputs, outputs,
                           unpack_single=False,
                           optimizer=graph_optimizer,
                           linker=graph_linker)
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
150
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
151 def __call__(self,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
152 trainset=None,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
153 validset=None,
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
154 iparams=None,
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
155 stp=None):
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
156 """Allocate and optionally train a model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
157
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
158 @param trainset: Data for minimizing the cost function
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
159 @type trainset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
160
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
161 @param validset: Data for early stopping
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
162 @type validset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
163
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
164 @param input: name of field to use as input
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
165 @type input: string
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
166
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
167 @param target: name of field to use as target
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
168 @type target: string
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
169
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
170 @param stp: early stopper, if None use default in graphMLP.G
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
171 @type stp: None or early stopper
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
172
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
173 @return: model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
174 @rtype: GraphLearner.Model instance
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
175
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
176 """
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
177
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
178 iparams = self.graph.iparams() if iparams is None else iparams
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
179
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
180 # if we load, type(trainset) == 'str'
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
181 if isinstance(trainset,str) or isinstance(trainset,file):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
182 #loadmodel = GraphLearner.Model(self, iparams)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
183 loadmodel = self.graph.load(self,trainset)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
184 return loadmodel
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
185
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
186 curmodel = GraphLearner.Model(self, iparams)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
187 best = curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
188
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
189 if trainset is not None:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
190 #do some training by calling Model.update_minibatch()
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
191 if stp == None :
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
192 stp = self.graph.early_stopper()
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
193 try :
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
194 countiter = 0
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
195 for mb in self.graph.train_iter(trainset):
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
196 curmodel.update_minibatch(mb)
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
197 if stp.set_score:
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
198 if validset:
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
199 stp.score = curmodel(validset, ['validset_score'])
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
200 if (stp.score < stp.best_score):
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
201 best = copy.copy(curmodel)
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
202 else:
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
203 stp.score = 0.0
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
204 countiter +=1
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
205 stp.next()
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
206 except StopIteration :
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
207 print 'Iterations stopped after ', countiter,' iterations'
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
208 if validset:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
209 curmodel = best
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
210 return curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
211
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
212
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
    """Build the graph description of a one-hidden-layer MLP classifier.

    The returned object declares the symbolic variables of the network
    (tanh hidden layer, softmax/cross-entropy output) as class attributes
    and provides the hooks used for training: initial parameters,
    minibatch iteration, early stopping, and load/save of the parameters.

    @param ninputs: number of input units (rows of W1)
    @type ninputs: int

    @param nhid: number of hidden units
    @type nhid: int

    @param nclass: number of output classes
    @type nclass: int

    @param lr_val: learning rate applied to the gradient of the nll
    @type lr_val: float

    @param l2coef_val: L2 coefficient; must currently be 0.0 because the
        regularization term is built but not included in the gradient
    @type l2coef_val: float

    @return: graph description
    @rtype: instance of a GraphLearner.Graph subclass
    """

    def wrapper(i, node, thunk):
        # Debugging hook: runs one piece of the graph, optionally tracing it.
        # It is not referenced anywhere else inside graphMLP.
        if 0:
            # disabled trace output; flip the condition to inspect each node
            print('%s %s' % (i, node))
            print(thunk.inputs)
            print(thunk.outputs)
            if node.op == nnet_ops.crossentropy_softmax_1hot_with_bias:
                print('here is the nll op')
        thunk() #actually compute this piece of the graph

    class G(GraphLearner.Graph, AutoName):

        lr = T.constant(lr_val)
        # The gradient below is taken on nll only, so a non-zero L2
        # coefficient would be silently ignored; refuse it instead.
        assert l2coef_val == 0.0
        l2coef = T.constant(l2coef_val)
        input = T.matrix() # n_examples x n_inputs
        target = T.ivector() # len: n_examples
        #target = T.matrix()
        W2, b2 = T.matrix(), T.vector()

        W1, b1 = T.matrix(), T.vector()
        hid = T.tanh(b1 + T.dot(input, W1))
        hid_regularization = l2coef * T.sum(W1*W1)

        params = [W1, b1, W2, b2]
        activations = b2 + T.dot(hid, W2)
        nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target )
        regularization = l2coef * T.sum(W2*W2) + hid_regularization
        output_class = T.argmax(activations,1)
        loss_01 = T.neq(output_class, target)
        #g_params = T.grad(nll + regularization, params)
        g_params = T.grad(nll, params)
        # one in-place gradient step per parameter: p <- p - lr * gp
        new_params = [T.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]

        def __eq__(self, other):
            """ Comparison of graphs is not implemented yet. """
            print('G.__eq__ from graphMLP(), not implemented yet')
            return NotImplemented

        def load(self, algo, f):
            """ Load from file the 2 matrices and bias vectors

            @param algo: learner passed through to GraphLearner.Model
            @param f: file name, or an already opened file object
            @return: GraphLearner.Model built from the 4 tensors read
            """
            close_at_end = False
            if isinstance(f, str):
                # we opened the file ourselves, so we are responsible for
                # closing it; a caller-supplied file object is left open
                f = open(f, 'r')
                close_at_end = True
            try:
                # 4 tensors, in the order they were written by save()
                params = [filetensor.read(f) for _ in range(4)]
            finally:
                if close_at_end:
                    f.close()
            return GraphLearner.Model(algo, params)

        def save(self, f, model):
            """ Save params to file, so 2 matrices and 2 bias vectors. Same order as iparams.

            @param f: file name, or an already opened file object
            @param model: object whose params attribute holds the tensors
            """
            close_at_end = False
            if isinstance(f, str):
                # only close what we opened ourselves
                f = open(f, 'w')
                close_at_end = True
            try:
                for p in model.params:
                    filetensor.write(f, p)
            finally:
                if close_at_end:
                    f.close()

        def iparams(self):
            """ init params: small random values in [-0.0005, 0.0005). """
            def randsmall(*shape):
                return (numpy.random.rand(*shape) -0.5) * 0.001
            # same order as the class-level params list: W1, b1, W2, b2
            return [randsmall(ninputs, nhid)
                    , randsmall(nhid)
                    , randsmall(nhid, nclass)
                    , randsmall(nclass)]

        def train_iter(self, trainset):
            """ Iterate over 2000 minibatches of at most 32 examples. """
            return trainset.minibatches(['input', 'target'],
                    minibatch_size=min(len(trainset), 32), n_batches=2000)

        def early_stopper(self):
            """ overwrites GraphLearner.graph function """
            return stopper.NStages(300,1)

    return G()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
297
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
298
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
299 import unittest
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
300
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
301 class TestMLP(unittest.TestCase):
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
302 def blah(self, g):
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
303 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
304 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
305 [1, 0, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
306 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
307 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
308 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
309 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
310 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
311 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
312 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
313 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
314 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
315 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
316 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
317 {'input':slice(2)})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
318
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
319 learn_algo = GraphLearner(g)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
320
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
321 model1 = learn_algo(training_set1)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
322
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
323 model2 = learn_algo(training_set2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
324
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
325 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
326 model2(test_data))]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
327
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
328 n_match = sum(omatch)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
329
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
330 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] ==
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
331 training_set2.fields()['target'])), omatch)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
332
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
333 model1.save('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
334
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
335 #denoising_aa = GraphLearner(denoising_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
336 #model1 = denoising_aa(trainset)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
337 #hidset = model(trainset, fieldnames=['hidden'])
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
338 #model2 = denoising_aa(hidset)
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
339
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
340 #f = open('blah', 'w')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
341 #for m in model:
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
342 # m.save(f)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
343 #filetensor.write(f, initial_classification_weights)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
344 #f.flush()
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
345
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
346 #deep_sigmoid_net = GraphLearner(deepnetwork_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
347 #deep_model = deep_sigmoid_net.load('blah')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
348 #deep_model.update(trainset) #do some fine tuning
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
349
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
350 model1_dup = learn_algo('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
351
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
352
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def equiv(self, g0, g1):
    """Train one model per graph on identical data and report parameter drift.

    A ``GraphLearner`` is built for ``g0`` and another for ``g1``; both are
    trained on the same four-row dataset (columns 0-1 exposed as 'input',
    column 2 as 'target').  The two models' ``params`` are then compared
    pairwise, and every pair whose largest ``abs_rel_err`` exceeds 1.0e-7
    is printed.

    NOTE(review): the hard assertion at the bottom of the loop is commented
    out, so this helper only reports differences and never fails the test
    case -- confirm whether that is still intended.
    """
    tolerance = 1.0e-7

    training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                      [0, 1, 1],
                                                      [1, 0, 1],
                                                      [1, 1, 1]]),
                                         {'input': slice(2), 'target': 2})
    learn_algo_0 = GraphLearner(g0)
    learn_algo_1 = GraphLearner(g1)

    model_0 = learn_algo_0(training_set1)
    model_1 = learn_algo_1(training_set1)

    # Single-argument, parenthesised prints produce the same output under
    # Python 2 as the former print statements and also parse under Python 3.
    print('----')
    for p0, p1 in zip(model_0.params, model_1.params):
        abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p0, p1)
        max_abs_rel_err = numpy.max(abs_rel_err)
        if max_abs_rel_err > tolerance:
            print('p0 %s' % (p0,))
            print('p1 %s' % (p1,))
        #self.failUnless(max_abs_rel_err < 1.0e-7, max_abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
373
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
374
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def test0(self):
    # Run the shared ``blah`` scenario on the graph built by
    # graphMLP(2, 10, 2, .1).
    graph = graphMLP(2, 10, 2, .1)
    self.blah(graph)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def test1(self):
    # Run the shared ``blah`` scenario on the graph built by
    # graphMLP(2, 3, 2, .1).
    graph = graphMLP(2, 3, 2, .1)
    self.blah(graph)
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
377
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
if __name__ == "__main__":
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the test cases defined in this module.
    unittest.main()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
380
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
381