annotate mlp_factory_approach.py @ 235:a70f2c973ea5

re-enabling old ArrayDataSet indexing
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 28 May 2008 14:09:19 -0400
parents c047238e5b3f
children 3156a9976183
rev   line source
225
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
1 """
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
2
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
3
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
4
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
5 This file is deprecated. I'm continuing development in hpu/models.py.
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
6
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
7 Get that project like this: hg clone ssh://user@lgcm/../bergstrj/hpu
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
8
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
9
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
10
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
11
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
12
8bc16220b29a deprecation note for mlp
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 215
diff changeset
13 """
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
14 import copy, sys
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
15 import numpy
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
16
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
17 import theano
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
18 from theano import tensor as t
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
19
218
df3fae88ab46 small debugging
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 215
diff changeset
20 from pylearn import dataset, nnet_ops, stopper
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
21
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
22
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
23 def _randshape(*shape):
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
24 return (numpy.random.rand(*shape) -0.5) * 0.001
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
25
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
26 def _cache(d, key, valfn):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
27 #valfn() is only evaluated if key isn't in dictionary d
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
28 if key not in d:
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
29 d[key] = valfn()
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
30 return d[key]
190
aa7a3ecbcc90 progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 189
diff changeset
31
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
32 class _Model(object):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
33 def __init__(self, algo, params):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
34 self.algo = algo
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
35 self.params = params
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
36 v = algo.v
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
37 self.update_fn = algo._fn([v.input, v.target] + v.params, [v.nll] + v.new_params)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
38 self._fn_cache = {}
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
39
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
40 def __copy__(self):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
41 return _Model(self.algo, [copy.copy(p) for p in params])
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
42
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
43 def update(self, input_target):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
44 """Update this model from more training data."""
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
45 params = self.params
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
46 #TODO: why should we have to unpack target like this?
218
df3fae88ab46 small debugging
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 215
diff changeset
47 # tbm : creates problem...
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
48 for input, target in input_target:
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
49 rval= self.update_fn(input, target, *params)
212
9b57ea8c767f previous commit was supposed to concern only one file, dataset.py, try to undo my other changes with this commit (nothing was broken though, just useless debugging prints)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 211
diff changeset
50 #print rval[0]
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
51
218
df3fae88ab46 small debugging
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 215
diff changeset
52 def __call__(self, testset, fieldnames=['output_class'],input='input',target='target'):
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
53 """Apply this model (as a function) to new data"""
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
54 #TODO: cache fn between calls
218
df3fae88ab46 small debugging
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 215
diff changeset
55 assert input == testset.fieldNames()[0] # why first one???
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
56 assert len(testset.fieldNames()) <= 2
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
57 v = self.algo.v
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
58 outputs = [getattr(v, name) for name in fieldnames]
218
df3fae88ab46 small debugging
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 215
diff changeset
59 inputs = [v.input] + ([v.target] if target in testset else [])
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
60 inputs.extend(v.params)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
61 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
62 lambda: self.algo._fn(inputs, outputs))
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
63 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
64 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
65
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
66 class AutonameVars(object):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
67 def __init__(self, dct):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
68 for key, val in dct.items():
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
69 if type(key) is str and hasattr(val, 'name'):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
70 val.name = key
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
71 self.__dict__.update(dct)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
72
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
73 class MultiLayerPerceptron(object):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
74
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
75 def __init__(self, ninputs, nhid, nclass, lr,
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
76 l2coef=0.0,
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
77 linker='c&py',
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
78 hidden_layer=None,
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
79 early_stopper=None,
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
80 validation_portion=0.2,
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
81 V_extern=None):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
82 class V_intern(AutonameVars):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
83 def __init__(v_self, lr, l2coef, **kwargs):
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
84 lr = t.constant(lr)
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
85 l2coef = t.constant(l2coef)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
86 input = t.matrix() # n_examples x n_inputs
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
87 target = t.ivector() # len: n_examples
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
88 W2, b2 = t.matrix(), t.vector()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
89
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
90 if hidden_layer:
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
91 hid, hid_params, hid_ivals, hid_regularization = hidden_layer(input)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
92 else:
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
93 W1, b1 = t.matrix(), t.vector()
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
94 hid = t.tanh(b1 + t.dot(input, W1))
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
95 hid_params = [W1, b1]
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
96 hid_regularization = l2coef * t.sum(W1*W1)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
97 hid_ivals = lambda : [_randshape(ninputs, nhid), _randshape(nhid)]
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
98
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
99 params = [W2, b2] + hid_params
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
100 activations = b2 + t.dot(hid, W2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
101 nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
102 regularization = l2coef * t.sum(W2*W2) + hid_regularization
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
103 output_class = t.argmax(activations,1)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
104 loss_01 = t.neq(output_class, target)
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
105 g_params = t.grad(nll + regularization, params)
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
106 new_params = [t.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
107 self.__dict__.update(locals()); del self.self
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
108 AutonameVars.__init__(v_self, locals())
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
109 self.nhid = nhid
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
110 self.nclass = nclass
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
111 self.v = V_intern(**locals()) if V_extern is None else V_extern(**locals())
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
112 self.linker = linker
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
113 self.early_stopper = early_stopper if early_stopper is not None else lambda: stopper.NStages(10,1)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
114 self.validation_portion = validation_portion
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
115
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
116 def _fn(self, inputs, outputs):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
117 # Caching here would hamper multi-threaded apps
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
118 # prefer caching in _Model.__call__
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
119 return theano.function(inputs, outputs, unpack_single=False, linker=self.linker)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
120
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
121 def __call__(self, trainset=None, iparams=None, input='input', target='target'):
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
122 """Allocate and optionally train a model"""
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
123 if iparams is None:
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
124 iparams = [_randshape(self.nhid, self.nclass), _randshape(self.nclass)]\
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
125 + self.v.hid_ivals()
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
126 rval = _Model(self, iparams)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
127 if trainset:
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
128 if len(trainset) == sys.maxint:
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
129 raise NotImplementedError('Learning from infinite streams is not supported')
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
130 nval = int(self.validation_portion * len(trainset))
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
131 nmin = len(trainset) - nval
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
132 assert nmin >= 0
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
133 minset = trainset[:nmin] #real training set for minimizing loss
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
134 valset = trainset[nmin:] #validation set for early stopping
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
135 best = rval
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
136 for stp in self.early_stopper():
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
137 rval.update(
211
bd728c83faff in __get__, problem if the i.stop was None, i being the slice, added one line replacing None by the len(self)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 208
diff changeset
138 minset.minibatches([input, target], minibatch_size=min(32,
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
139 len(minset))))
212
9b57ea8c767f previous commit was supposed to concern only one file, dataset.py, try to undo my other changes with this commit (nothing was broken though, just useless debugging prints)
Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
parents: 211
diff changeset
140 #print 'mlp.__call__(), we did an update'
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
141 if stp.set_score:
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
142 stp.score = rval(valset, ['loss_01'])
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
143 if (stp.score < stp.best_score):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
144 best = copy.copy(rval)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
145 rval = best
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
146 return rval
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
147
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
148
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
149 import unittest
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
150
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
151 class TestMLP(unittest.TestCase):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
152 def test0(self):
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
153
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
154 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
155 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
156 [1, 0, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
157 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
158 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
159 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
160 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
161 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
162 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
163 {'input':slice(2),'target':2})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
164 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
165 [0, 1, 1],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
166 [1, 0, 0],
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
167 [1, 1, 1]]),
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
168 {'input':slice(2)})
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
169
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
170 learn_algo = MultiLayerPerceptron(2, 10, 2, .1
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
171 , linker='c&py'
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
172 , early_stopper = lambda:stopper.NStages(100,1))
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
173
232
c047238e5b3f Fixed by James
delallea@opale.iro.umontreal.ca
parents: 226
diff changeset
174 model1 = learn_algo(training_set1)
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
175
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
176 model2 = learn_algo(training_set2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
177
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
178 n_match = 0
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
179 for o1, o2 in zip(model1(test_data), model2(test_data)):
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
180 #print o1
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
181 #print o2
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
182 n_match += (o1 == o2)
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
183
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
184 assert n_match == (numpy.sum(training_set1.fields()['target'] ==
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
185 training_set2.fields()['target']))
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
186
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
187 if __name__ == '__main__':
208
bf320808919f back to James' version
Yoshua Bengio <bengioy@iro.umontreal.ca>
parents: 207
diff changeset
188 unittest.main()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
189