Mercurial > pylearn
diff mlp_factory_approach.py @ 187:ebbb0e749565
added mlp_factory_approach
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 14 May 2008 11:51:08 -0400 |
parents | |
children | 8f58abb943d4 f2ddc795ec49 |
line wrap: on
line diff
# Reconstructed from the hgweb diff of changeset 187:ebbb0e749565, which adds
# this file (--- /dev/null, +++ b/mlp_factory_approach.py,
# Wed May 14 11:51:08 2008 -0400).
"""A small MLP classifier written in the "factory" style.

A ``NeuralNet`` instance is a learning algorithm: it owns the symbolic theano
graph and the hyper-parameters.  Calling it produces a ``NeuralNet.Model``,
which owns concrete parameter values, can be trained further with ``update``
and can be applied to a dataset by calling it.
"""
import dataset
import theano
import theano.tensor as t
import numpy
import nnet_ops


def _randshape(*shape):
    """Return an array of the given shape, uniform in [-0.0005, 0.0005)."""
    return (numpy.random.rand(*shape) - 0.5) * 0.001


def _function(inputs, outputs, linker='c&py'):
    """Compile ``inputs -> outputs`` with theano using the given linker.

    ``unpack_single=False`` is passed through to ``theano.function``
    (presumably so a single output still comes back inside a sequence;
    callers here index the result).
    """
    return theano.function(inputs, outputs, unpack_single=False, linker=linker)


class NeuralNet(object):
    """Learning algorithm for a one-hidden-layer softmax classifier."""

    class Model(object):
        """Parameter values paired with the ``NeuralNet`` that defines the graph."""

        def __init__(self, nnet, params):
            """Store the owning ``NeuralNet`` and the list of parameter arrays.

            ``params`` must be ordered like ``nnet.v.params``.
            """
            self.nnet = nnet
            self.params = params

        def update(self, trainset, stopper=None):
            """Update this model from more training data.

            Runs ``self.nnet.nepochs`` passes over ``trainset`` in minibatches
            of at most 32 examples.  The graph's ``new_params`` are built with
            ``sub_inplace``, so ``self.params`` is expected to be modified in
            place by each call of the compiled function.

            Raises NotImplementedError if ``stopper`` is given; early stopping
            is not supported yet.
            """
            # Fail before paying for compilation.
            if stopper is not None:
                raise NotImplementedError()

            v = self.nnet.v
            update_fn = _function([v.input, v.target] + v.params,
                                  [v.nll] + v.new_params,
                                  linker=self.nnet.linker)
            batch_size = min(32, len(trainset))
            # Honour the configured epoch count (this used to be a hard-coded
            # 100, which silently ignored the constructor's ``nepochs``).
            for _ in range(self.nnet.nepochs):
                for input, target in trainset.minibatches(
                        ['input', 'target'], minibatch_size=batch_size):
                    # The first returned value is the nll; it is not used.
                    update_fn(input, target[:, 0], *self.params)

        def __call__(self, testset,
                     output_fieldnames=None,
                     test_stats_collector=None,
                     copy_inputs=False,
                     put_stats_in_output_dataset=True,
                     output_attributes=None):
            """Apply this model (as a function) to new data.

            ``output_fieldnames`` names attributes of the symbolic graph
            (``self.nnet.v``) to compute; it defaults to ``['output_class']``.
            Returns a ``dataset.ApplyFunctionDataSet`` wrapping ``testset``.

            ``test_stats_collector``, ``copy_inputs``,
            ``put_stats_in_output_dataset`` and ``output_attributes`` are
            accepted but currently ignored.
            """
            # None sentinel instead of a shared mutable default list.
            if output_fieldnames is None:
                output_fieldnames = ['output_class']

            v = self.nnet.v
            inputs = [v.input, v.target] + v.params
            outputs = [getattr(v, name) for name in output_fieldnames]
            fn = _function(inputs, outputs, linker=self.nnet.linker)

            if 'target' in testset.fields():
                return dataset.ApplyFunctionDataSet(
                    testset,
                    lambda input, target: fn(input, target[:, 0], *self.params),
                    output_fieldnames)

            # No targets available: feed a dummy target so the compiled
            # function can still be called.
            # NOTE(review): outputs that depend on the target (e.g. 'nll',
            # 'loss_01') are presumably meaningless on this path -- confirm.
            return dataset.ApplyFunctionDataSet(
                testset,
                lambda input: fn(input, numpy.zeros(1, dtype='int64'),
                                 *self.params),
                output_fieldnames)

    def __init__(self, ninputs, nhid, nclass, lr, nepochs,
                 l2coef=0.0,
                 linker='c&py',
                 hidden_layer=None):
        """Build the symbolic graph and record the hyper-parameters.

        ninputs, nhid, nclass -- layer sizes used to shape initial parameters.
        lr -- learning rate of the gradient step.
        nepochs -- number of passes ``Model.update`` makes over its data.
        l2coef -- coefficient of the L2 penalty on the weight matrices.
        linker -- theano linker used when compiling functions (the default
            used to be misspelled 'c&yp' and the value was never used).
        hidden_layer -- optional callable taking the symbolic input and
            returning ``(hid, hid_params, hid_ivals, hid_regularization)``;
            when omitted a tanh layer is built.
        """
        class Vars:
            """Namespace holding every symbolic variable of the graph."""

            def __init__(self, lr, l2coef):
                lr = t.constant(lr)
                l2coef = t.constant(l2coef)
                input = t.matrix('input')  # n_examples x n_inputs
                target = t.ivector('target')  # n_examples x 1
                W2 = t.matrix('W2')
                b2 = t.vector('b2')

                if hidden_layer:
                    hid, hid_params, hid_ivals, hid_regularization = \
                        hidden_layer(input)
                else:
                    W1 = t.matrix('W1')
                    b1 = t.vector('b1')
                    hid = t.tanh(b1 + t.dot(input, W1))
                    hid_params = [W1, b1]
                    hid_regularization = l2coef * t.sum(W1 * W1)
                    hid_ivals = lambda: [_randshape(ninputs, nhid),
                                         _randshape(nhid)]

                params = [W2, b2] + hid_params
                nll, predictions = nnet_ops.crossentropy_softmax_1hot(
                    b2 + t.dot(hid, W2), target)
                regularization = l2coef * t.sum(W2 * W2) + hid_regularization
                output_class = t.argmax(predictions, 1)
                loss_01 = t.neq(output_class, target)
                g_params = t.grad(nll + regularization, params)
                new_params = [t.sub_inplace(p, lr * gp)
                              for p, gp in zip(params, g_params)]
                # Publish every local above as an attribute; the local names
                # are therefore part of this class's interface
                # (v.input, v.target, v.params, v.nll, v.new_params, ...).
                self.__dict__.update(locals())
                del self.self

        self.nhid = nhid
        self.nclass = nclass
        self.nepochs = nepochs
        self.linker = linker
        self.v = Vars(lr, l2coef)
        self.params = None

    def __call__(self, trainset=None, iparams=None):
        """Create a new ``Model``, optionally trained on ``trainset``.

        ``iparams`` gives initial parameter values ordered like
        ``self.v.params``; when omitted, small random values are drawn for the
        output layer and ``self.v.hid_ivals()`` supplies the hidden layer's.
        """
        if iparams is None:
            output_ivals = [_randshape(self.nhid, self.nclass),
                            _randshape(self.nclass)]
            iparams = output_ivals + self.v.hid_ivals()
        rval = NeuralNet.Model(self, iparams)
        if trainset:
            rval.update(trainset)
        return rval


if __name__ == '__main__':
    # Each row is (input0, input1, target).
    training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                      [0, 1, 1],
                                                      [1, 0, 1],
                                                      [1, 1, 1]]),
                                         {'input': slice(2), 'target': 2})
    training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                      [0, 1, 1],
                                                      [1, 0, 0],
                                                      [1, 1, 1]]),
                                         {'input': slice(2), 'target': 2})
    # Only the 'input' field is exposed, so the no-target path is exercised.
    test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                  [0, 1, 1],
                                                  [1, 0, 0],
                                                  [1, 1, 1]]),
                                     {'input': slice(2)})

    learn_algo = NeuralNet(2, 10, 3, .1, 1000)

    model1 = learn_algo(training_set1)

    model2 = learn_algo(training_set2)

    n_match = 0
    for o1, o2 in zip(model1(test_data), model2(test_data)):
        n_match += (o1 == o2)

    # Formatted so the output is identical under Python 2 and Python 3.
    print('%s %s' % (n_match,
                     numpy.sum(training_set1.fields()['target'] ==
                               training_set2.fields()['target'])))