Mercurial > pylearn
annotate mlp_factory_approach.py @ 228:6f55e301c687
optimisation of ArrayDataSet
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Fri, 16 May 2008 16:38:07 -0400 |
parents | e816821c1e50 |
children | c5a7105fa40b |
rev | line source |
---|---|
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
1 import copy, sys |
190
aa7a3ecbcc90
progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
189
diff
changeset
|
2 import numpy |
aa7a3ecbcc90
progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
189
diff
changeset
|
3 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 import theano |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
5 from theano import tensor as t |
190
aa7a3ecbcc90
progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
189
diff
changeset
|
6 |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
7 from tlearn import dataset, nnet_ops, stopper |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
8 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
9 def _randshape(*shape): |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
10 return (numpy.random.rand(*shape) -0.5) * 0.001 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
11 |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
12 def _cache(d, key, valfn): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
13 #valfn() is only evaluated if key isn't in dictionary d |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
14 if key not in d: |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
15 d[key] = valfn() |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
16 return d[key] |
190
aa7a3ecbcc90
progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
189
diff
changeset
|
17 |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
18 class _Model(object): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
19 def __init__(self, algo, params): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
20 self.algo = algo |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
21 self.params = params |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
22 v = algo.v |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
23 self.update_fn = algo._fn([v.input, v.target] + v.params, [v.nll] + v.new_params) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
24 self._fn_cache = {} |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
25 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
26 def __copy__(self): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
27 return _Model(self.algo, [copy.copy(p) for p in params]) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
28 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
29 def update(self, input_target): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
30 """Update this model from more training data.""" |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
31 params = self.params |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
32 #TODO: why should we have to unpack target like this? |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
33 for input, target in input_target: |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
34 self.update_fn(input, target[:,0], *params) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
35 |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
36 def __call__(self, testset, fieldnames=['output_class']): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
37 """Apply this model (as a function) to new data""" |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
38 #TODO: cache fn between calls |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
39 assert 'input' == testset.fieldNames()[0] |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
40 assert len(testset.fieldNames()) <= 2 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
41 v = self.algo.v |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
42 outputs = [getattr(v, name) for name in fieldnames] |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
43 inputs = [v.input] + ([v.target] if 'target' in testset else []) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
44 inputs.extend(v.params) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
45 theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
46 lambda: self.algo._fn(inputs, outputs)) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
47 lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
48 return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
49 |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
50 class AutonameVars(object): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
51 def __init__(self, dct): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
52 for key, val in dct.items(): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
53 if type(key) is str and hasattr(val, 'name'): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
54 val.name = key |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
55 self.__dict__.update(dct) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
56 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
57 class MultiLayerPerceptron(object): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
58 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
59 def __init__(self, ninputs, nhid, nclass, lr, |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
60 l2coef=0.0, |
189
8f58abb943d4
many changes to NeuralNet
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
187
diff
changeset
|
61 linker='c&py', |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
62 hidden_layer=None, |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
63 early_stopper=None, |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
64 validation_portion=0.2, |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
65 V_extern=None): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
66 class V_intern(AutonameVars): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
67 def __init__(v_self, lr, l2coef, **kwargs): |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
68 lr = t.constant(lr) |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
69 l2coef = t.constant(l2coef) |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
70 input = t.matrix() # n_examples x n_inputs |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
71 target = t.ivector() # len: n_examples |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
72 W2, b2 = t.matrix(), t.vector() |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
73 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
74 if hidden_layer: |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
75 hid, hid_params, hid_ivals, hid_regularization = hidden_layer(input) |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
76 else: |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
77 W1, b1 = t.matrix(), t.vector() |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
78 hid = t.tanh(b1 + t.dot(input, W1)) |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
79 hid_params = [W1, b1] |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
80 hid_regularization = l2coef * t.sum(W1*W1) |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
81 hid_ivals = lambda : [_randshape(ninputs, nhid), _randshape(nhid)] |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
82 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
83 params = [W2, b2] + hid_params |
189
8f58abb943d4
many changes to NeuralNet
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
187
diff
changeset
|
84 activations = b2 + t.dot(hid, W2) |
8f58abb943d4
many changes to NeuralNet
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
187
diff
changeset
|
85 nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
86 regularization = l2coef * t.sum(W2*W2) + hid_regularization |
189
8f58abb943d4
many changes to NeuralNet
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
187
diff
changeset
|
87 output_class = t.argmax(activations,1) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
88 loss_01 = t.neq(output_class, target) |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
89 g_params = t.grad(nll + regularization, params) |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
90 new_params = [t.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)] |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
91 self.__dict__.update(locals()); del self.self |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
92 AutonameVars.__init__(v_self, locals()) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
93 self.nhid = nhid |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
94 self.nclass = nclass |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
95 self.v = V_intern(**locals()) if V_extern is None else V_extern(**locals()) |
189
8f58abb943d4
many changes to NeuralNet
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
187
diff
changeset
|
96 self.linker = linker |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
97 self.early_stopper = early_stopper if early_stopper is not None else lambda: stopper.NStages(10,1) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
98 self.validation_portion = validation_portion |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
99 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
100 def _fn(self, inputs, outputs): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
101 # Caching here would hamper multi-threaded apps |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
102 # prefer caching in _Model.__call__ |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
103 return theano.function(inputs, outputs, unpack_single=False, linker=self.linker) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
104 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
105 def __call__(self, trainset=None, iparams=None): |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
106 """Allocate and optionally train a model""" |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
107 if iparams is None: |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
108 iparams = [_randshape(self.nhid, self.nclass), _randshape(self.nclass)]\ |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
109 + self.v.hid_ivals() |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
110 rval = _Model(self, iparams) |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
111 if trainset: |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
112 if len(trainset) == sys.maxint: |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
113 raise NotImplementedError('Learning from infinite streams is not supported') |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
114 nval = int(self.validation_portion * len(trainset)) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
115 nmin = len(trainset) - nval |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
116 assert nmin >= 0 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
117 minset = trainset[:nmin] #real training set for minimizing loss |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
118 valset = trainset[nmin:] #validation set for early stopping |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
119 best = rval |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
120 for stp in self.early_stopper(): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
121 rval.update( |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
122 trainset.minibatches(['input', 'target'], minibatch_size=min(32, |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
123 len(trainset)))) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
124 if stp.set_score: |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
125 stp.score = rval(valset, ['loss_01']) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
126 if (stp.score < stp.best_score): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
127 best = copy.copy(rval) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
128 rval = best |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
129 return rval |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
130 |
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
131 |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
132 import unittest |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
133 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
134 class TestMLP(unittest.TestCase): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
135 def test0(self): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
136 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
137 training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
138 [0, 1, 1], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
139 [1, 0, 1], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
140 [1, 1, 1]]), |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
141 {'input':slice(2),'target':2}) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
142 training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
143 [0, 1, 1], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
144 [1, 0, 0], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
145 [1, 1, 1]]), |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
146 {'input':slice(2),'target':2}) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
147 test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
148 [0, 1, 1], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
149 [1, 0, 0], |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
150 [1, 1, 1]]), |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
151 {'input':slice(2)}) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
152 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
153 learn_algo = MultiLayerPerceptron(2, 10, 2, .1 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
154 , linker='c&py' |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
155 , early_stopper = lambda:stopper.NStages(100,1)) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
156 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
157 model1 = learn_algo(training_set1) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
158 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
159 model2 = learn_algo(training_set2) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
160 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
161 n_match = 0 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
162 for o1, o2 in zip(model1(test_data), model2(test_data)): |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
163 #print o1 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
164 #print o2 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
165 n_match += (o1 == o2) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
166 |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
167 assert n_match == (numpy.sum(training_set1.fields()['target'] == |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
168 training_set2.fields()['target'])) |
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
169 |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
170 if __name__ == '__main__': |
191
e816821c1e50
added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
190
diff
changeset
|
171 unittest.main() |
187
ebbb0e749565
added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
172 |