annotate mlp_factory_approach.py @ 338:7d4792fc28ae

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 16 Jun 2008 17:17:50 -0400
parents 93280a0c151a
children
rev   line source
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
1 import copy, sys, os
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
2 import numpy
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
3
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 import theano
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
5 from theano import tensor as T
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
6
299
eded3cb54930 small bug fixed
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 265
diff changeset
7 import dataset, nnet_ops, stopper, filetensor
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
8 from pylearn.lookup_list import LookupList
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
9
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
10
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AbstractFunction(Exception):
    """Raised by methods that have no implementation in their base class
    (for example Graph.train_iter, Graph.load and Graph.save)."""
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
12
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class AutoName(object):
    """
    By inheriting from this class, class variables which have a name attribute
    will have that name attribute set to the class variable name.
    """
    class __metaclass__(type):
        def __init__(cls, name, bases, dct):
            # type.__init__ is an unbound slot: the class being initialised
            # must be passed explicitly as its first argument, otherwise the
            # descriptor receives the class *name* (a str) and raises.
            type.__init__(cls, name, bases, dct)
            for key, val in dct.items():
                assert type(key) is str
                # Only objects that already expose a `name` attribute are
                # renamed; everything else in the class body is left alone.
                if hasattr(val, 'name'):
                    val.name = key
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
25
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
26 class GraphLearner(object):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class Model(object):
    """A model produced by a GraphLearner.

    Holds the learning algorithm (`algo`, whose `graph` describes the
    symbolic computation) and the current parameter values (`params`).
    The training-step function is compiled once, at construction time.
    """

    def __init__(self, algo, params):
        self.algo = algo
        self.params = params
        graph = self.algo.graph
        # Training step: inputs are (input, target, *params); outputs are
        # the nll followed by graph.new_params.
        self.update_fn = algo._fn([graph.input, graph.target] + graph.params,
                [graph.nll] + graph.new_params)
        # Compiled functions used by __call__, keyed on (inputs, outputs).
        self._fn_cache = {}

    def __copy__(self):
        """Return a new model sharing `algo`, with each param shallow-copied."""
        return self.__class__(self.algo, [copy.copy(p) for p in self.params])

    def __eq__(self, other, tolerance=0.):
        """Compare parameter values element-wise.

        Returns True when `other` is a GraphLearner.Model holding the same
        number of params, and every corresponding pair has the same shape
        and differs by at most `tolerance` everywhere. `algo` is not compared.
        """
        if not isinstance(other, GraphLearner.Model):
            return False
        if len(self.params) != len(other.params):
            return False
        for mine, theirs in zip(self.params, other.params):
            if mine.shape != theirs.shape:
                return False
            if not numpy.all(numpy.abs(mine - theirs) <= tolerance):
                return False
        return True

    def __ne__(self, other):
        # Python 2 does not derive `!=` from __eq__; keep them consistent.
        return not self.__eq__(other)

    def _cache(self, key, valfn):
        """Return the cached value for `key`, computing it with `valfn()` once."""
        d = self._fn_cache
        if key not in d:
            d[key] = valfn()
        return d[key]

    def update_minibatch(self, minibatch):
        """Run one training step on `minibatch`.

        `minibatch` must be a LookupList providing 'input' and 'target'.
        """
        assert isinstance(minibatch, LookupList), \
            'expected a LookupList, got %s' % type(minibatch)
        # NOTE(review): the outputs of update_fn (nll and new params) are
        # discarded; presumably graph.new_params updates self.params in
        # place — confirm.
        self.update_fn(minibatch['input'], minibatch['target'], *self.params)

    def update(self, dataset,
            default_minibatch_size=32):
        """
        Update this model from more training data. Uses all the data once,
        cut into minibatches of at most `default_minibatch_size` examples.
        No early stopper here. An empty dataset is a no-op.
        """
        n_examples = len(dataset)
        if n_examples == 0:
            # Nothing to learn from; avoid requesting zero-sized minibatches.
            return
        minibatch_size = min(default_minibatch_size, n_examples)
        for mb in dataset.minibatches(['input', 'target'],
                minibatch_size=minibatch_size):
            self.update_minibatch(mb)

    def save(self, f):
        """Save this model to `f` via the graph's save method."""
        self.algo.graph.save(f, self)

    def __call__(self, testset, fieldnames=None):
        """Apply this model (as a function) to new data.

        @param testset: DataSet, whose fields feed Result terms in self.algo.graph
        @type testset: DataSet

        @param fieldnames: names of results in self.algo.graph to compute.
            Defaults to ['output_class'].
        @type fieldnames: list of strings

        @return: DataSet with fields from fieldnames, computed from testset by
            this model.
        @rtype: ApplyFunctionDataSet instance

        """
        if fieldnames is None:
            fieldnames = ['output_class']
        graph = self.algo.graph

        def getresult(name):
            r = getattr(graph, name)
            if not isinstance(r, theano.Result):
                raise TypeError('string does not name a theano.Result', (name, r))
            return r

        provided = [getresult(name) for name in testset.fieldNames()]
        wanted = [getresult(name) for name in fieldnames]
        inputs = provided + graph.params

        theano_fn = self._cache((tuple(inputs), tuple(wanted)),
                lambda: self.algo._fn(inputs, wanted))
        # self.params is read at call time, so later updates are reflected.
        lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
        return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
106
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
class Graph(object):
    """Base description of the symbolic graph a GraphLearner trains.

    Provides a default optimizer, linker and early stopper. train_iter,
    load and save have no implementation here and raise AbstractFunction.
    """

    class Opt(object):
        """Graph optimizer applied when compiling functions.

        Runs a merge pass, the gemm_pattern_1 substitution, optionally a
        `x*x -> sqr(x)` substitution, and a final merge pass.
        """
        merge = theano.gof.MergeOptimizer()
        gemm_opt_1 = theano.gof.TopoOptimizer(theano.tensor_opt.gemm_pattern_1)
        sqr_opt_0 = theano.gof.TopoOptimizer(theano.gof.PatternSub(
                (T.mul,'x', 'x'),
                (T.sqr, 'x')))

        def __init__(self, do_sqr=True):
            self.do_sqr = do_sqr

        def __call__(self, env):
            self.merge(env)
            self.gemm_opt_1(env)
            if self.do_sqr:
                self.sqr_opt_0(env)
            # Merge again; presumably the substitutions above can introduce
            # duplicate sub-graphs.
            self.merge(env)

    def linker(self):
        """Return the theano linker used to compile this graph's functions."""
        return theano.gof.PerformLinker()

    def early_stopper(self):
        """Return the default early stopper for training."""
        # The stopper must be returned: constructing it alone has no effect
        # for the caller.
        return stopper.NStages(300, 1)

    def train_iter(self, trainset):
        raise AbstractFunction
    optimizer = Opt()

    def load(self, f):
        raise AbstractFunction

    def save(self, f, model):
        raise AbstractFunction
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
140
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
141
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def __init__(self, graph):
    """Wrap `graph`, the Graph object describing what this learner trains."""
    self.graph = graph
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
144
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
def _fn(self, inputs, outputs):
    """Compile a theano function from `inputs` to `outputs` for this graph.

    Outputs are always returned as a list (unpack_single=False). The graph's
    own linker is used when it defines one, otherwise 'c|py'.
    """
    # No caching here: caching would hamper multi-threaded apps, so callers
    # such as Model.__call__ do their own caching instead.
    graph = self.graph
    opt = graph.optimizer
    if hasattr(graph, 'linker'):
        chosen_linker = graph.linker()
    else:
        chosen_linker = 'c|py'
    return theano.function(inputs, outputs,
            unpack_single=False,
            optimizer=opt,
            linker=chosen_linker)
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
153
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
154 def __call__(self,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
155 trainset=None,
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
156 validset=None,
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
157 iparams=None,
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
158 stp=None):
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
159 """Allocate and optionally train a model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
160
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
161 @param trainset: Data for minimizing the cost function
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
162 @type trainset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
163
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
164 @param validset: Data for early stopping
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
165 @type validset: None or Dataset
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
166
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
167 @param input: name of field to use as input
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
168 @type input: string
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
169
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
170 @param target: name of field to use as target
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
171 @type target: string
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
172
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
173 @param stp: early stopper, if None use default in graphMLP.G
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
174 @type stp: None or early stopper
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
175
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
176 @return: model
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
177 @rtype: GraphLearner.Model instance
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
178
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
179 """
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
180
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
181 iparams = self.graph.iparams() if iparams is None else iparams
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
182
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
183 # if we load, type(trainset) == 'str'
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
184 if isinstance(trainset,str) or isinstance(trainset,file):
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
185 #loadmodel = GraphLearner.Model(self, iparams)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
186 loadmodel = self.graph.load(self,trainset)
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
187 return loadmodel
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
188
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
189 curmodel = GraphLearner.Model(self, iparams)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
190 best = curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
191
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
192 if trainset is not None:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
193 #do some training by calling Model.update_minibatch()
304
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
194 if stp == None :
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
195 stp = self.graph.early_stopper()
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
196 try :
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
197 countiter = 0
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
198 for mb in self.graph.train_iter(trainset):
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
199 curmodel.update_minibatch(mb)
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
200 if stp.set_score:
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
201 if validset:
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
202 stp.score = curmodel(validset, ['validset_score'])
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
203 if (stp.score < stp.best_score):
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
204 best = copy.copy(curmodel)
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
205 else:
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
206 stp.score = 0.0
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
207 countiter +=1
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
208 stp.next()
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
209 except StopIteration :
6ead65d30f1e while learning using __call__, we can now set the early stopper
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 299
diff changeset
210 print 'Iterations stopped after ', countiter,' iterations'
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
211 if validset:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
212 curmodel = best
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
213 return curmodel
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
214
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
215
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
def graphMLP(ninputs, nhid, nclass, lr_val, l2coef_val=0.0):
    """Build the graph of a one-hidden-layer MLP classifier.

    The network computes hid = tanh(b1 + input . W1) and
    activations = b2 + hid . W2. Each update subtracts lr * grad(nll) from
    every parameter, where nll is the softmax cross-entropy of
    `activations` against integer class targets.

    @param ninputs: number of input features (rows of W1)
    @param nhid: number of hidden units
    @param nclass: number of output classes
    @param lr_val: learning rate used in the parameter update
    @param l2coef_val: L2 weight-decay coefficient; must be 0.0 because the
        regularization term is not part of the gradient (see g_params below)
    @return: an instance of the graph class G, e.g. for GraphLearner(g)
    """

    class G(GraphLearner.Graph, AutoName):

        lr = T.constant(lr_val)
        # `regularization` is built below but is NOT included in g_params,
        # so a non-zero coefficient would be silently ignored; reject it.
        assert l2coef_val == 0.0, 'graphMLP: l2coef_val != 0.0 is not supported'
        l2coef = T.constant(l2coef_val)
        input = T.matrix()  # n_examples x n_inputs
        target = T.ivector()  # len: n_examples (integer class labels)
        #target = T.matrix()
        W2, b2 = T.matrix(), T.vector()

        W1, b1 = T.matrix(), T.vector()
        hid = T.tanh(b1 + T.dot(input, W1))
        hid_regularization = l2coef * T.sum(W1 * W1)

        # Same order as the arrays returned by iparams() and read by load().
        params = [W1, b1, W2, b2]
        activations = b2 + T.dot(hid, W2)
        nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
        regularization = l2coef * T.sum(W2 * W2) + hid_regularization
        output_class = T.argmax(activations, 1)
        loss_01 = T.neq(output_class, target)
        #g_params = T.grad(nll + regularization, params)
        g_params = T.grad(nll, params)
        # Gradient step, expressed with theano's in-place subtraction op.
        new_params = [T.sub_inplace(p, lr * gp) for p, gp in zip(params, g_params)]

        def __eq__(self, other):
            # Structural comparison of graphs is not implemented; returning
            # NotImplemented lets Python fall back to its default comparison.
            print('G.__eq__ from graphMLP(), not implemented yet')
            return NotImplemented

        def load(self, algo, f):
            """Load a model saved by `save`: 2 weight matrices and 2 bias vectors.

            @param algo: learner passed to GraphLearner.Model as its owner
            @param f: file name, or an already-open binary file object (a file
                object passed in is left open for the caller to close)
            @return: GraphLearner.Model built from the 4 arrays read from f
            """
            close_at_end = isinstance(f, str)
            if close_at_end:
                # filetensor data is binary: avoid text-mode translation.
                f = open(f, 'rb')
            try:
                # W1, b1, W2, b2 -- written in this order by save().
                params = [filetensor.read(f) for _ in range(4)]
            finally:
                # Only close what we opened ourselves, even if reading fails.
                if close_at_end:
                    f.close()
            return GraphLearner.Model(algo, params)

        def save(self, f, model):
            """Save model.params to f, so 2 matrices and 2 bias vectors.

            Same order as iparams.

            @param f: file name, or an already-open binary file object (a file
                object passed in is left open for the caller to close)
            @param model: model whose `params` arrays are written with
                filetensor.write, one after another
            """
            close_at_end = isinstance(f, str)
            if close_at_end:
                # filetensor data is binary: avoid text-mode translation.
                f = open(f, 'wb')
            try:
                for p in model.params:
                    filetensor.write(f, p)
            finally:
                # Only close what we opened ourselves, even if writing fails.
                if close_at_end:
                    f.close()

        def iparams(self):
            """Return freshly initialized parameters [W1, b1, W2, b2].

            Every entry is drawn uniformly from [-0.0005, 0.0005).
            """
            def randsmall(*shape):
                return (numpy.random.rand(*shape) - 0.5) * 0.001
            return [randsmall(ninputs, nhid),
                    randsmall(nhid),
                    randsmall(nhid, nclass),
                    randsmall(nclass)]

        def train_iter(self, trainset):
            """Return the training minibatch iterator over 'input' and 'target'.

            Requests minibatches of size min(len(trainset), 32) with
            n_batches=2000 from trainset.minibatches.
            """
            return trainset.minibatches(['input', 'target'],
                    minibatch_size=min(len(trainset), 32), n_batches=2000)

        def early_stopper(self):
            """Return the default early stopper, stopper.NStages(300, 1).

            GraphLearner.__call__ uses this when its `stp` argument is None.
            """
            return stopper.NStages(300, 1)

    return G()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
300
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
301
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
302 import unittest
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
303
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
304 class TestMLP(unittest.TestCase):
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
305 def blah(self, g):
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
306 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
307 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
308 [1, 0, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
309 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
310 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
311 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
312 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
313 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
314 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
315 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
316 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
317 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
318 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
319 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
320 {'input':slice(2)})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
321
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
322 learn_algo = GraphLearner(g)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
323
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
324 model1 = learn_algo(training_set1)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
325
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
326 model2 = learn_algo(training_set2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
327
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
328 omatch = [o1 == o2 for o1, o2 in zip(model1(test_data),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
329 model2(test_data))]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
330
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
331 n_match = sum(omatch)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
332
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
333 self.failUnless(n_match == (numpy.sum(training_set1.fields()['target'] ==
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
334 training_set2.fields()['target'])), omatch)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
335
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
336 model1.save('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
337
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
338 #denoising_aa = GraphLearner(denoising_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
339 #model1 = denoising_aa(trainset)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
340 #hidset = model(trainset, fieldnames=['hidden'])
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
341 #model2 = denoising_aa(hidset)
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
342
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
343 #f = open('blah', 'w')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
344 #for m in model:
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
345 # m.save(f)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
346 #filetensor.write(f, initial_classification_weights)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
347 #f.flush()
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
348
265
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
349 #deep_sigmoid_net = GraphLearner(deepnetwork_g)
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
350 #deep_model = deep_sigmoid_net.load('blah')
ae0a8345869b commented junk in the default test (main function) of mlp_factory_approach so the test still works
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 264
diff changeset
351 #deep_model.update(trainset) #do some fine tuning
264
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
352
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
353 model1_dup = learn_algo('/tmp/model1')
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
354
a1793a5e9523 we can now load and save in a file, see test class in the file for an example, but basically it's model1.save(filename) or learn_algo(filename) to load
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 244
diff changeset
355
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
356 def equiv(self, g0, g1):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
357 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
358 [0, 1, 1],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
359 [1, 0, 1],
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
360 [1, 1, 1]]),
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
361 {'input':slice(2),'target':2})
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
362 learn_algo_0 = GraphLearner(g0)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
363 learn_algo_1 = GraphLearner(g1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
364
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
365 model_0 = learn_algo_0(training_set1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
366 model_1 = learn_algo_1(training_set1)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
367
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
368 print '----'
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
369 for p in zip(model_0.params, model_1.params):
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
370 abs_rel_err = theano.gradient.numeric_grad.abs_rel_err(p[0], p[1])
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
371 max_abs_rel_err = numpy.max(abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
372 if max_abs_rel_err > 1.0e-7:
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
373 print 'p0', p[0]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
374 print 'p1', p[1]
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
375 #self.failUnless(max_abs_rel_err < 1.0e-7, max_abs_rel_err)
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
376
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
377
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
378 def test0(self): self.blah(graphMLP(2, 10, 2, .1))
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
379 def test1(self): self.blah(graphMLP(2, 3, 2, .1))
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
380
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
381 if __name__ == '__main__':
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
382 unittest.main()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
383
244
3156a9976183 mlp_factory_approach.py, updated and un-deprecated by popular demand
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 232
diff changeset
384