annotate mlp_factory_approach.py @ 299:eded3cb54930

small bug fixed
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:58:45 -0400
parents ae0a8345869b
children 6ead65d30f1e
rev   line source
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
1 import copy, sys, os
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
2 import numpy
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
3
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 import theano
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
5 from theano import tensor as T
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6
299
eded3cb54930 small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 265
diff changeset
7 import dataset, nnet_ops, stopper, filetensor
eded3cb54930 small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 265
diff changeset
8 from lookup_list import LookupList
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
9
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
10
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AbstractFunction(Exception):
    """Raised by placeholder methods that a subclass is expected to override."""
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
12
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AutoName(object):
    """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.
    """
    class __metaclass__(type):
        # Python 2 metaclass hook: runs when AutoName and each of its
        # subclasses is created.
        def __init__(cls, name, bases, dct):
            """Initialise the new class, then name its `name`-bearing members.

            @param cls: the class being created.
            @param name: the class name.
            @param bases: tuple of base classes.
            @param dct: the class body namespace; every value that has a
                `name` attribute gets it set to its key in this dict.
            """
            # type.__init__ is accessed unbound here, so the class object
            # must be passed explicitly as its first argument.
            type.__init__(cls, name, bases, dct)
            for key, val in dct.items():
                assert type(key) is str
                if hasattr(val, 'name'):
                    val.name = key
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
25
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
26 class GraphLearner(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class Model(object):
    """A model produced by a GraphLearner.

    Holds the learning algorithm (`algo`) and the current parameter values
    (`params`), and compiles functions from ``algo.graph`` to train and apply
    the model.
    """

    def __init__(self, algo, params):
        """
        @param algo: the learner that owns this model; ``algo.graph`` must
            provide input, target, params, nll and new_params.
        @param params: current parameter values, in the same order as
            ``algo.graph.params``.
        """
        self.algo = algo
        self.params = params
        graph = self.algo.graph
        # Maps (input, target, *params) to [nll] + new_params.
        self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                                  [graph.nll] + graph.new_params)
        # Compiled prediction functions, keyed by (inputs, outputs); see _cache.
        self._fn_cache = {}

    def __copy__(self):
        """Return a new model of the same class, with the same `algo` and a
        copy.copy() of each parameter.

        Note that constructing the new model compiles a fresh update_fn.
        """
        return type(self)(self.algo, [copy.copy(p) for p in self.params])

    def __eq__(self, other, tolerance=0.):
        """Only compares parameter values (weight matrices and bias vectors).

        @param other: object to compare against.
        @param tolerance: maximum allowed absolute element-wise difference.
        @return: True if `other` is a model of the same class with the same
            number of parameters, each of identical shape and all elements
            within `tolerance`; False otherwise.
        """
        if not isinstance(other, type(self)):
            return False
        if len(self.params) != len(other.params):
            return False
        for mine, theirs in zip(self.params, other.params):
            if mine.shape != theirs.shape:
                return False
            if not numpy.all(numpy.abs(mine - theirs) <= tolerance):
                return False
        return True

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep them consistent.
        return not self.__eq__(other)

    def _cache(self, key, valfn):
        """Return the cached value for `key`, computing it with valfn() on
        first use and storing it in self._fn_cache."""
        d = self._fn_cache
        if key not in d:
            d[key] = valfn()
        return d[key]

    def update_minibatch(self, minibatch):
        """Run one training step on `minibatch`.

        @param minibatch: LookupList providing 'input' and 'target' fields.
        """
        assert isinstance(minibatch, LookupList), \
            ('minibatch must be a LookupList', type(minibatch))
        # The returned [nll] + new_params are discarded; presumably the graph
        # updates the parameter storage in place -- TODO confirm.
        self.update_fn(minibatch['input'], minibatch['target'], *self.params)

    def update(self, dataset,
            default_minibatch_size=32):
        """Update this model from more training data.

        Calls update_minibatch on every minibatch of the 'input' and 'target'
        fields yielded by ``dataset.minibatches``.

        @param default_minibatch_size: requested minibatch size, clipped to
            len(dataset).
        """
        minibatch_size = min(default_minibatch_size, len(dataset))
        for mb in dataset.minibatches(['input', 'target'],
                                      minibatch_size=minibatch_size):
            self.update_minibatch(mb)

    def save(self, f):
        """Save this model by delegating to ``self.algo.graph.save(f, self)``."""
        self.algo.graph.save(f, self)

    def __call__(self, testset, fieldnames=None):
        """Apply this model (as a function) to new data.

        @param testset: DataSet whose field names name Result attributes of
            self.algo.graph; those results become the function inputs.
        @type testset: DataSet

        @param fieldnames: names of results in self.algo.graph to compute.
            Defaults to ['output_class'].
        @type fieldnames: list of strings

        @return: DataSet with fields from fieldnames, computed from testset by
            this model.
        @rtype: ApplyFunctionDataSet instance

        @raise TypeError: if a name does not refer to a theano.Result of the
            graph.
        """
        if fieldnames is None:
            fieldnames = ['output_class']
        graph = self.algo.graph

        def getresult(name):
            r = getattr(graph, name)
            if not isinstance(r, theano.Result):
                raise TypeError('string does not name a theano.Result', (name, r))
            return r

        provided = [getresult(name) for name in testset.fieldNames()]
        wanted = [getresult(name) for name in fieldnames]
        inputs = provided + graph.params

        theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                                lambda: self.algo._fn(inputs, wanted))
        # self.params is read at call time, so the function uses whatever
        # parameter values the model holds when it is applied.
        lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
        return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
103
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class Graph(object):
    """Base description of the symbolic graph used by a GraphLearner.

    The learner and its models read attributes such as input, target, params,
    new_params, nll and iparams() from the graph.  Methods that raise
    AbstractFunction are meant to be overridden by subclasses.
    """

    class Opt(object):
        """Default optimizer that GraphLearner._fn passes to theano.function.

        Runs a merge pass, the gemm_pattern_1 rewrite, optionally the
        mul(x, x) -> sqr(x) rewrite, and a final merge pass.
        """
        merge = theano.gof.MergeOptimizer()
        gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
        # Rewrites mul(x, x) into sqr(x).
        sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
                (T.mul, 'x', 'x'),
                (T.sqr, 'x')))

        def __init__(self, do_sqr=True):
            """@param do_sqr: whether to apply the mul(x, x) -> sqr(x) rewrite."""
            self.do_sqr = do_sqr

        def __call__(self, env):
            """Apply the optimization passes to `env`, in order."""
            self.merge(env)
            self.gemm_opt_1(env)
            if self.do_sqr:
                self.sqr_opt_0(env)
            self.merge(env)

    def linker(self):
        """Return the linker used when compiling functions for this graph."""
        return theano.gof.PerformLinker()

    def early_stopper(self):
        """Return the early-stopping criterion for training."""
        return stopper.NStages(10, 1)

    def train_iter(self, trainset):
        """Abstract: subclasses override."""
        raise AbstractFunction

    optimizer = Opt()

    def load(self, f):
        """Abstract: load a model from `f`; subclasses override."""
        raise AbstractFunction

    def save(self, f, model):
        """Abstract: save `model` to `f` (called by Model.save); subclasses
        override."""
        raise AbstractFunction
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
137
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
138
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def __init__(self, graph):
    """Create a learner for `graph`.

    @param graph: the Graph object describing the model; it is stored as
        ``self.graph`` and read by _fn and __call__.
    """
    self.graph = graph
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
141
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
def _fn(self, inputs, outputs):
    """Compile a theano function from `inputs` to `outputs`.

    Uses ``self.graph.optimizer`` and, when the graph defines ``linker``, the
    linker it returns (otherwise the 'c&py' linker).  unpack_single=False is
    passed to theano.function.
    """
    # Caching here would hamper multi-threaded apps;
    # prefer caching in Model.__call__.
    graph = self.graph
    chosen_optimizer = graph.optimizer
    if hasattr(graph, 'linker'):
        chosen_linker = graph.linker()
    else:
        chosen_linker = 'c&py'
    return theano.function(inputs, outputs,
                           unpack_single=False,
                           optimizer=chosen_optimizer,
                           linker=chosen_linker)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
150
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
151 def __call__(self,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
152 trainset=None,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
153 validset=None,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
154 iparams=None):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
155 """Allocate and optionally train a model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
156
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
157 @param trainset: Data for minimizing the cost function
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
158 @type trainset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
159
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
160 @param validset: Data for early stopping
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
161 @type validset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
162
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
163 @param input: name of field to use as input
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
164 @type input: string
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
165
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
166 @param target: name of field to use as target
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
167 @type target: string
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
168
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
169 @return: model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
170 @rtype: GraphLearner.Model instance
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
171
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
172 """
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
173
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
174 iparams = self.graph.iparams() if iparams is None else iparams
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
175
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
176 # if we load, type(trainset) == 'str'
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
177 if isinstance(trainset,str) or isinstance(trainset,file):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
178 #loadmodel = GraphLearner.Model(self, iparams)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
179 loadmodel = self.graph.load(self,trainset)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
180 return loadmodel
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
181
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
182 curmodel = GraphLearner.Model(self, iparams)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
183 best = curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
184
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
185 if trainset is not None:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
186 #do some training by calling Model.update_minibatch()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
187 stp = self.graph.early_stopper()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
188 for mb in self.graph.train_iter(trainset):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
189 curmodel.update_minibatch(mb)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
190 if stp.set_score:
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
191 if validset:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
192 stp.score = curmodel(validset, ['validset_score'])
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
193 if (stp.score < stp.best_score):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
194 best = copy.copy(curmodel)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
195 else:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
196 stp.score = 0.0
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
197 stp.next()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
198 if validset:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
199 curmodel = best
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
200 return curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
201
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
202
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
203 def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
204
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
205
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
206 def wrapper(i, node, thunk):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
207 if 0:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
208 print i, node
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
209 print thunk.inputs
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
210 print thunk.outputs
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
211 if node.op == nnet_ops.crossentropy_softmax_1hot_with_bias:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
212 print 'here is the nll op'
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
213 thunk() #actually compute this piece of the graph
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
214
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
215 class G(GraphLearner.Graph, AutoName):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
216
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
217 lr = T.constant(lr_val)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
218 assert l2coef_val == 0.0
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
219 l2coef = T.constant(l2coef_val)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
220 input = T.matrix() # n_examples x n_inputs
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
221 target = T.ivector() # len: n_examples
299
eded3cb54930 small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 265
diff changeset
222 #target = T.matrix()
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
223 W2, b2 = T.matrix(), T.vector()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
224
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
225 W1, b1 = T.matrix(), T.vector()
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
226 hid = T.tanh(b1 + T.dot(input, W1))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
227 hid_regularization = l2coef * T.sum(W1*W1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
228
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
229 params = [W1, b1, W2, b2]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
230 activations = b2 + T.dot(hid, W2)
299
eded3cb54930 small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 265
diff changeset
231 nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target )
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
232 regularization = l2coef * T.sum(W2*W2) + hid_regularization
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
233 output_class = T.argmax(activations,1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
234 loss_01 = T.neq(output_class, target)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
235 #g_params = T.grad(nll + regularization, params)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
236 g_params = T.grad(nll, params)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
237 new_params = [T.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
238
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
239
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
240 def __eq__(self,other) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
241 print 'G.__eq__ from graphMLP(), not implemented yet'
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
242 return NotImplemented
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
243
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
244
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
245 def load(self, algo, f):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
246 """ Load from file the 2 matrices and bias vectors """
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
247 cloase_at_end = False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
248 if isinstance(f,str) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
249 f = open(f,'r')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
250 close_at_end = True
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
251 params = []
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
252 for i in xrange(4):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
253 params.append(filetensor.read(f))
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
254 if close_at_end :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
255 f.close()
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
256 return GraphLearner.Model(algo, params)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
257
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
258 def save(self, f, model):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
259 """ Save params to file, so 2 matrices and 2 bias vectors. Same order as iparams. """
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
260 cloase_at_end = False
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
261 if isinstance(f,str) :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
262 f = open(f,'w')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
263 close_at_end = True
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
264 for p in model.params:
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
265 filetensor.write(f,p)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
266 if close_at_end :
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
267 f.close()
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
268
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
269
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
270 def iparams(self):
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
271 """ init params. """
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
272 def randsmall(*shape):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
273 return (numpy.random.rand(*shape) -0.5) * 0.001
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
274 return [randsmall(ninputs, nhid)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
275 , randsmall(nhid)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
276 , randsmall(nhid, nclass)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
277 , randsmall(nclass)]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
278
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
279 def train_iter(self, trainset):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
280 return trainset.minibatches(['input', 'target'],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
281 minibatch_size=min(len(trainset), 32), n_batches=300)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
282 def early_stopper(self):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
283 return stopper.NStages(300,1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
284
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
285 return G()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
286
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
287
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
288 import unittest
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
289
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
290 class TestMLP(unittest.TestCase):
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
291 def blah(self, g):
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
292 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
293 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
294 [1, 0, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
295 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
296 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
297 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
298 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
299 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
300 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
301 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
302 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
303 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
304 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
305 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
306 {'input':slice(2)})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
307
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
308 learn_algo = GraphLearner(g)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
309
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
310 model1 = learn_algo(training_set1)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
311
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
312 model2 = learn_algo(training_set2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
313
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
314 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
315 model2(test_data))]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
316
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
317 n_match = sum(omatch)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
318
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
319 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] ==
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
320 training_set2.fields()['target'])), omatch)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
321
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
322 model1.save('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
323
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
324 #denoising_aa = GraphLearner(denoising_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
325 #model1 = denoising_aa(trainset)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
326 #hidset = model(trainset, fieldnames=['hidden'])
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
327 #model2 = denoising_aa(hidset)
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
328
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
329 #f = open('blah', 'w')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
330 #for m in model:
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
331 # m.save(f)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
332 #filetensor.write(f, initial_classification_weights)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
333 #f.flush()
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
334
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
335 #deep_sigmoid_net = GraphLearner(deepnetwork_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
336 #deep_model = deep_sigmoid_net.load('blah')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
337 #deep_model.update(trainset) #do some fine tuning
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
338
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
339 model1_dup = learn_algo('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
340
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
341
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
342 def equiv(self, g0, g1):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
343 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
344 [0, 1, 1],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
345 [1, 0, 1],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
346 [1, 1, 1]]),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
347 {'input':slice(2),'target':2})
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
348 learn_algo_0 = GraphLearner(g0)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
349 learn_algo_1 = GraphLearner(g1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
350
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
351 model_0 = learn_algo_0(training_set1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
352 model_1 = learn_algo_1(training_set1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
353
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
354 print '----'
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
355 for p in zip(model_0.params, model_1.params):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
356 abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p[0], p[1])
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
357 max_abs_rel_err = numpy.max(abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
358 if max_abs_rel_err > 1.0e-7:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
359 print 'p0', p[0]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
360 print 'p1', p[1]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
361 #self.failUnless(max_abs_rel_err < 1.0e-7, max_abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
362
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
363
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
364 def test0(self): self.blah(graphMLP(2, 10, 2, .1))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
365 def test1(self): self.blah(graphMLP(2, 3, 2, .1))
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
366
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
367 if __name__ == '__main__':
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
368 unittest.main()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
369
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
370