annotate mlp_factory_approach.py @ 191:e816821c1e50

added early stopping to mlp.__call__
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 May 2008 20:04:44 -0400
parents aa7a3ecbcc90
children c5a7105fa40b
rev   line source
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
1 import copy, sys
190
aa7a3ecbcc90 progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 189
diff changeset
2 import numpy
aa7a3ecbcc90 progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 189
diff changeset
3
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
4 import theano
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
5 from theano import tensor as t
190
aa7a3ecbcc90 progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 189
diff changeset
6
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
7 from tlearn import dataset, nnet_ops, stopper
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
8
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
def _randshape(*shape):
    """Return an array of the given shape filled with tiny random values.

    Entries are drawn uniformly from [0, 1), shifted to be centred on zero and
    scaled by 0.001, i.e. they lie in [-0.0005, 0.0005).
    """
    centred = numpy.random.rand(*shape) - 0.5
    return centred * 0.001
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
11
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
def _cache(d, key, valfn):
    """Return ``d[key]``, computing and storing it first if it is missing.

    ``valfn`` is a zero-argument callable that is invoked only when *key* is
    not already present in the dictionary *d*.
    """
    if key in d:
        return d[key]
    d[key] = valfn()
    return d[key]
190
aa7a3ecbcc90 progress toward early stopping
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 189
diff changeset
17
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
class _Model(object):
    """A model allocated by a learning algorithm such as MultiLayerPerceptron.

    Holds the learning algorithm ``algo`` (whose ``v`` attribute carries the
    symbolic graph) and ``params``, the list of numeric parameter values fed
    to the compiled functions. The update function is compiled from
    ``v.new_params`` and its outputs are discarded in ``update``, so the
    parameters are presumably modified in place (MultiLayerPerceptron builds
    them with ``t.sub_inplace``).
    """

    def __init__(self, algo, params):
        self.algo = algo
        self.params = params
        v = algo.v
        # Compiled once per model: (input, target, *params) -> [nll] + new_params.
        self.update_fn = algo._fn([v.input, v.target] + v.params,
                                  [v.nll] + v.new_params)
        # Compiled prediction functions keyed by (inputs, outputs); see __call__.
        self._fn_cache = {}

    def __copy__(self):
        """Return a new _Model sharing ``algo`` but owning copies of each parameter.

        Each parameter is copied individually (not just the list) so that later
        in-place updates of this model leave the copy untouched.
        """
        return _Model(self.algo, [copy.copy(p) for p in self.params])

    def update(self, input_target):
        """Update this model from more training data.

        :param input_target: iterable of ``(input, target)`` minibatch pairs.
            ``target`` is indexed as ``target[:, 0]``, so it is expected to be
            2-d with the class labels in its first column.
        """
        params = self.params
        #TODO: why should we have to unpack target like this?
        for input, target in input_target:
            self.update_fn(input, target[:, 0], *params)

    def __call__(self, testset, fieldnames=None):
        """Apply this model (as a function) to new data.

        :param testset: dataset whose first field is ``'input'`` and which has
            at most two fields; a ``'target'`` field, if present, is also fed
            to the compiled function.
        :param fieldnames: names of attributes of ``self.algo.v`` to compute
            for each example; defaults to ``['output_class']``.
        :returns: a ``dataset.ApplyFunctionDataSet`` wrapping ``testset``.
        """
        # None sentinel instead of a shared mutable list default.
        if fieldnames is None:
            fieldnames = ['output_class']
        assert 'input' == testset.fieldNames()[0]
        assert len(testset.fieldNames()) <= 2
        v = self.algo.v
        outputs = [getattr(v, name) for name in fieldnames]
        inputs = [v.input] + ([v.target] if 'target' in testset else [])
        inputs.extend(v.params)
        # Compile each (inputs, outputs) signature at most once per model.
        theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)),
                lambda: self.algo._fn(inputs, outputs))
        # The dataset supplies its fields; the current parameters are appended.
        lambda_fn = lambda *args: theano_fn(*(list(args) + self.params))
        return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames)
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
49
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
class AutonameVars(object):
    """Attribute bag built from a dictionary (typically ``locals()``).

    Each value stored under a plain ``str`` key that has a ``name`` attribute
    gets that attribute set to its key. Afterwards every key/value pair of the
    dictionary becomes an attribute of the instance.
    """
    def __init__(self, dct):
        for attr_name, obj in dct.items():
            if type(attr_name) is not str:
                continue
            if hasattr(obj, 'name'):
                obj.name = attr_name
        self.__dict__.update(dct)
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
56
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
class MultiLayerPerceptron(object):
    """Factory for MLP classifiers trained by stochastic gradient descent.

    The network has one hidden layer (tanh by default, or supplied through
    ``hidden_layer``) followed by a softmax output trained with cross-entropy
    plus L2 regularization. Calling an instance allocates a ``_Model`` and,
    when given a training set, trains it with early stopping on a held-out
    validation split.
    """

    def __init__(self, ninputs, nhid, nclass, lr,
            l2coef=0.0,
            linker='c&py',
            hidden_layer=None,
            early_stopper=None,
            validation_portion=0.2,
            V_extern=None):
        """Build the symbolic graph and store training options.

        :param ninputs: input dimensionality (used to initialize the default
            hidden layer).
        :param nhid: number of hidden units.
        :param nclass: number of output classes.
        :param lr: learning rate of the gradient step.
        :param l2coef: L2 weight-decay coefficient.
        :param linker: linker name passed to ``theano.function``.
        :param hidden_layer: optional callable ``input -> (hid, hid_params,
            hid_ivals, hid_regularization)`` replacing the default tanh layer.
        :param early_stopper: zero-argument callable returning an iterable of
            stopper states; defaults to ``stopper.NStages(10, 1)``.
        :param validation_portion: fraction of the training set held out for
            early stopping.
        :param V_extern: optional replacement for the internal ``V_intern``
            graph-building class; it is called with this method's locals.
        """
        class V_intern(AutonameVars):
            def __init__(v_self, lr, l2coef, **kwargs):
                lr = t.constant(lr)
                l2coef = t.constant(l2coef)
                input = t.matrix() # n_examples x n_inputs
                target = t.ivector() # len: n_examples
                W2, b2 = t.matrix(), t.vector()

                if hidden_layer:
                    hid, hid_params, hid_ivals, hid_regularization = hidden_layer(input)
                else:
                    W1, b1 = t.matrix(), t.vector()
                    hid = t.tanh(b1 + t.dot(input, W1))
                    hid_params = [W1, b1]
                    hid_regularization = l2coef * t.sum(W1*W1)
                    # called later to produce initial values for [W1, b1]
                    hid_ivals = lambda : [_randshape(ninputs, nhid), _randshape(nhid)]

                params = [W2, b2] + hid_params
                activations = b2 + t.dot(hid, W2)
                nll, predictions = nnet_ops.crossentropy_softmax_1hot(activations, target)
                regularization = l2coef * t.sum(W2*W2) + hid_regularization
                output_class = t.argmax(activations,1)
                loss_01 = t.neq(output_class, target)
                g_params = t.grad(nll + regularization, params)
                # one gradient step, applied in place to each parameter
                new_params = [t.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]
                # NOTE(review): ``self`` here is the enclosing MultiLayerPerceptron,
                # not v_self, so this copies every local onto the factory object.
                # Looks like a leftover from an earlier revision; kept for
                # backward compatibility -- TODO confirm nobody relies on it.
                self.__dict__.update(locals()); del self.self
                # expose every local as an attribute of v_self, naming symbolic vars
                AutonameVars.__init__(v_self, locals())
        self.nhid = nhid
        self.nclass = nclass
        self.v = V_intern(**locals()) if V_extern is None else V_extern(**locals())
        self.linker = linker
        self.early_stopper = early_stopper if early_stopper is not None else lambda: stopper.NStages(10,1)
        self.validation_portion = validation_portion

    def _fn(self, inputs, outputs):
        """Compile a theano function from ``inputs`` to ``outputs``."""
        # Caching here would hamper multi-threaded apps
        # prefer caching in _Model.__call__
        return theano.function(inputs, outputs, unpack_single=False, linker=self.linker)

    def __call__(self, trainset=None, iparams=None):
        """Allocate and optionally train a model.

        :param trainset: optional dataset with ``'input'`` and ``'target'``
            fields. When given (and non-empty), its last ``validation_portion``
            is held out for early stopping and the model is trained on the
            remaining examples only.
        :param iparams: optional initial parameter values
            ``[W2, b2] + hidden-layer params``; small random values by default.
        :returns: a ``_Model``. After training it is the copy taken when the
            stopper last saw an improved score, or the trained model itself if
            no improvement was recorded.
        :raises NotImplementedError: if ``trainset`` is an infinite stream.
        :raises ValueError: if the held-out portion leaves no training examples.
        """
        if iparams is None:
            iparams = ([_randshape(self.nhid, self.nclass), _randshape(self.nclass)]
                    + self.v.hid_ivals())
        rval = _Model(self, iparams)
        if trainset:
            if len(trainset) == sys.maxint:
                raise NotImplementedError('Learning from infinite streams is not supported')
            nval = int(self.validation_portion * len(trainset))
            nmin = len(trainset) - nval
            if nmin <= 0:
                raise ValueError('validation_portion leaves no examples for training')
            minset = trainset[:nmin] #real training set for minimizing loss
            valset = trainset[nmin:] #validation set for early stopping
            best = rval
            for stp in self.early_stopper():
                # Train on minset only: using the full trainset would leak the
                # validation examples into training and bias early stopping.
                rval.update(
                    minset.minibatches(['input', 'target'],
                        minibatch_size=min(32, nmin)))
                if stp.set_score:
                    # NOTE(review): _Model.__call__ returns a dataset, not a
                    # scalar; comparing it with best_score looks suspect --
                    # TODO confirm what the stopper expects here.
                    stp.score = rval(valset, ['loss_01'])
                    if (stp.score < stp.best_score):
                        best = copy.copy(rval)
            rval = best
        return rval
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
130
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
131
191
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
132 import unittest
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
133
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
class TestMLP(unittest.TestCase):
    def test0(self):
        """Train two MLPs on datasets that differ in one target and compare them.

        The number of test examples on which the two models' outputs compare
        equal must match the number of training rows whose targets agree.
        """
        # columns: input0, input1, target
        rows_set1 = [[0, 0, 0],
                     [0, 1, 1],
                     [1, 0, 1],
                     [1, 1, 1]]
        rows_set2 = [[0, 0, 0],
                     [0, 1, 1],
                     [1, 0, 0],
                     [1, 1, 1]]
        training_set1 = dataset.ArrayDataSet(numpy.array(rows_set1),
                                             {'input': slice(2), 'target': 2})
        training_set2 = dataset.ArrayDataSet(numpy.array(rows_set2),
                                             {'input': slice(2), 'target': 2})
        test_data = dataset.ArrayDataSet(numpy.array(rows_set2),
                                         {'input': slice(2)})

        learn_algo = MultiLayerPerceptron(2, 10, 2, .1,
                                          linker='c&py',
                                          early_stopper=lambda: stopper.NStages(100, 1))

        model1 = learn_algo(training_set1)
        model2 = learn_algo(training_set2)

        n_match = sum(out1 == out2
                      for out1, out2 in zip(model1(test_data), model2(test_data)))

        expected = numpy.sum(training_set1.fields()['target'] ==
                             training_set2.fields()['target'])
        assert n_match == expected
e816821c1e50 added early stopping to mlp.__call__
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 190
diff changeset
169
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
if '__main__' == __name__:
    # Executed as a script: run the unit tests defined in this module.
    unittest.main()
187
ebbb0e749565 added mlp_factory_approach
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff changeset
172