Mercurial > pylearn
changeset 191:e816821c1e50
added early stopping to mlp.__call__
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 14 May 2008 20:04:44 -0400 |
parents | aa7a3ecbcc90 |
children | f62a03c9d485 |
files | mlp_factory_approach.py |
diffstat | 1 files changed, 125 insertions(+), 95 deletions(-) [+] |
line wrap: on
line diff
--- a/mlp_factory_approach.py Wed May 14 16:24:10 2008 -0400 +++ b/mlp_factory_approach.py Wed May 14 20:04:44 2008 -0400 @@ -1,84 +1,80 @@ -import copy +import copy, sys import numpy import theano -import theano.tensor as t +from theano import tensor as t -import dataset -import nnet_ops +from tlearn import dataset, nnet_ops, stopper def _randshape(*shape): return (numpy.random.rand(*shape) -0.5) * 0.001 -class NeuralNet(object): - - class _Model(object): - def __init__(self, nnet, params): - self.nnet = nnet - self.params = params - - def __copy__(self): - return _Model(self.nnet, [copy.copy(p) for p in params]) +def _cache(d, key, valfn): + #valfn() is only evaluated if key isn't in dictionary d + if key not in d: + d[key] = valfn() + return d[key] - def update(self, trainset, stopper=None): - """Update this model from more training data.""" - v = self.nnet.v - params = self.params - update_fn = self.nnet._fn([v.input, v.target] + v.params, [v.nll] + v.new_params) - if stopper is not None: - raise NotImplementedError() - else: - for i in xrange(100): - for input, target in trainset.minibatches(['input', 'target'], - minibatch_size=min(32, len(trainset))): - dummy = update_fn(input, target[:,0], *params) - if 0: print dummy[0] #the nll +class _Model(object): + def __init__(self, algo, params): + self.algo = algo + self.params = params + v = algo.v + self.update_fn = algo._fn([v.input, v.target] + v.params, [v.nll] + v.new_params) + self._fn_cache = {} + + def __copy__(self): + return _Model(self.algo, [copy.copy(p) for p in params]) + + def update(self, input_target): + """Update this model from more training data.""" + params = self.params + #TODO: why should we have to unpack target like this? + for input, target in input_target: + self.update_fn(input, target[:,0], *params) - def __call__(self, testset, - output_fieldnames=['output_class'], - test_stats_collector=None, - copy_inputs=False, - put_stats_in_output_dataset=True, - output_attributes=[]): - """Apply this model (as a function) to new data""" - v = self.nnet.v - outputs = [getattr(self.nnet.v, name) for name in output_fieldnames] - if 'target' in testset: - fn = self.nnet._fn([v.input, v.target] + v.params, outputs) - return dataset.ApplyFunctionDataSet(testset, - lambda input, target: fn(input, target[:,0], *self.params), - output_fieldnames) - else: - fn = self.nnet._fn([v.input] + v.params, outputs) - return dataset.ApplyFunctionDataSet(testset, - lambda input: fn(input, *self.params), - output_fieldnames) - def _fn(self, inputs, outputs): - #it is possible for this function to implement function caching - #... but not necessarily desirable. - #- caching ruins the possibility of multi-threaded learning - #- caching demands more efficiency in the face of resizing inputs - #- caching makes it really hard to borrow references to function outputs - return theano.function(inputs, outputs, unpack_single=False, linker=self.linker) + def __call__(self, testset, fieldnames=['output_class']): + """Apply this model (as a function) to new data""" + #TODO: cache fn between calls + assert 'input' == testset.fieldNames()[0] + assert len(testset.fieldNames()) <= 2 + v = self.algo.v + outputs = [getattr(v, name) for name in fieldnames] + inputs = [v.input] + ([v.target] if 'target' in testset else []) + inputs.extend(v.params) + theano_fn = _cache(self._fn_cache, (tuple(inputs), tuple(outputs)), + lambda: self.algo._fn(inputs, outputs)) + lambda_fn = lambda *args: theano_fn(*(list(args) + self.params)) + return dataset.ApplyFunctionDataSet(testset, lambda_fn, fieldnames) - def __init__(self, ninputs, nhid, nclass, lr, nepochs, +class AutonameVars(object): + def __init__(self, dct): + for key, val in dct.items(): + if type(key) is str and hasattr(val, 'name'): + val.name = key + self.__dict__.update(dct) + +class MultiLayerPerceptron(object): + + def __init__(self, ninputs, nhid, nclass, lr, l2coef=0.0, linker='c&py', - hidden_layer=None): - class Vars: - def __init__(self, lr, l2coef): + hidden_layer=None, + early_stopper=None, + validation_portion=0.2, + V_extern=None): + class V_intern(AutonameVars): + def __init__(v_self, lr, l2coef, **kwargs): lr = t.constant(lr) l2coef = t.constant(l2coef) - input = t.matrix('input') # n_examples x n_inputs - target = t.ivector('target') # n_examples x 1 - W2 = t.matrix('W2') - b2 = t.vector('b2') + input = t.matrix() # n_examples x n_inputs + target = t.ivector() # len: n_examples + W2, b2 = t.matrix(), t.vector() if hidden_layer: hid, hid_params, hid_ivals, hid_regularization = hidden_layer(input) else: - W1 = t.matrix('W1') - b1 = t.vector('b1') + W1, b1 = t.matrix(), t.vector() hid = t.tanh(b1 + t.dot(input, W1)) hid_params = [W1, b1] hid_regularization = l2coef * t.sum(W1*W1) @@ -93,50 +89,84 @@ g_params = t.grad(nll + regularization, params) new_params = [t.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)] self.__dict__.update(locals()); del self.self + AutonameVars.__init__(v_self, locals()) self.nhid = nhid self.nclass = nclass - self.nepochs = nepochs - self.v = Vars(lr, l2coef) - self.params = None + self.v = V_intern(**locals()) if V_extern is None else V_extern(**locals()) self.linker = linker + self.early_stopper = early_stopper if early_stopper is not None else lambda: stopper.NStages(10,1) + self.validation_portion = validation_portion + + def _fn(self, inputs, outputs): + # Caching here would hamper multi-threaded apps + # prefer caching in _Model.__call__ + return theano.function(inputs, outputs, unpack_single=False, linker=self.linker) def __call__(self, trainset=None, iparams=None): + """Allocate and optionally train a model""" if iparams is None: iparams = [_randshape(self.nhid, self.nclass), _randshape(self.nclass)]\ + self.v.hid_ivals() - rval = NeuralNet._Model(self, iparams) + rval = _Model(self, iparams) if trainset: - rval.update(trainset) + if len(trainset) == sys.maxint: + raise NotImplementedError('Learning from infinite streams is not supported') + nval = int(self.validation_portion * len(trainset)) + nmin = len(trainset) - nval + assert nmin >= 0 + minset = trainset[:nmin] #real training set for minimizing loss + valset = trainset[nmin:] #validation set for early stopping + best = rval + for stp in self.early_stopper(): + rval.update( + trainset.minibatches(['input', 'target'], minibatch_size=min(32, + len(trainset)))) + if stp.set_score: + stp.score = rval(valset, ['loss_01']) + if (stp.score < stp.best_score): + best = copy.copy(rval) + rval = best return rval +import unittest + +class TestMLP(unittest.TestCase): + def test0(self): + + training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], + [0, 1, 1], + [1, 0, 1], + [1, 1, 1]]), + {'input':slice(2),'target':2}) + training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], + [0, 1, 1], + [1, 0, 0], + [1, 1, 1]]), + {'input':slice(2),'target':2}) + test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0], + [0, 1, 1], + [1, 0, 0], + [1, 1, 1]]), + {'input':slice(2)}) + + learn_algo = MultiLayerPerceptron(2, 10, 2, .1 + , linker='c&py' + , early_stopper = lambda:stopper.NStages(100,1)) + + model1 = learn_algo(training_set1) + + model2 = learn_algo(training_set2) + + n_match = 0 + for o1, o2 in zip(model1(test_data), model2(test_data)): + #print o1 + #print o2 + n_match += (o1 == o2) + + assert n_match == (numpy.sum(training_set1.fields()['target'] == + training_set2.fields()['target'])) + if __name__ == '__main__': - training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], - [0, 1, 1], - [1, 0, 1], - [1, 1, 1]]), - {'input':slice(2),'target':2}) - training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0], - [0, 1, 1], - [1, 0, 0], - [1, 1, 1]]), - {'input':slice(2),'target':2}) - test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0], - [0, 1, 1], - [1, 0, 0], - [1, 1, 1]]), - {'input':slice(2)}) + unittest.main() - learn_algo = NeuralNet(2, 10, 3, .1, 1000) - - model1 = learn_algo(training_set1) - - model2 = learn_algo(training_set2) - - n_match = 0 - for o1, o2 in zip(model1(test_data), model2(test_data)): - n_match += (o1 == o2) - - print n_match, numpy.sum(training_set1.fields()['target'] == - training_set2.fields()['target']) -