diff mlp_factory_approach.py @ 187:ebbb0e749565

added mlp_factory_approach
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 14 May 2008 11:51:08 -0400
parents
children 8f58abb943d4 f2ddc795ec49
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/mlp_factory_approach.py	Wed May 14 11:51:08 2008 -0400
@@ -0,0 +1,138 @@
+import dataset
+import theano
+import theano.tensor as t
+import numpy
+import nnet_ops
+
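+# helpers: small random initial weights, and a shorthand for compiling theano functions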
+def _randshape(*shape):
+    return (numpy.random.rand(*shape) - 0.5) * 0.001
+def _function(inputs, outputs, linker='c&py'):
+    return theano.function(inputs, outputs, unpack_single=False, linker=linker)
+
+class NeuralNet(object):
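+    """Learning algorithm (factory): calling an instance on a trainset builds a trained Model."""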
+
+    class Model(object):
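+        """Holds trained parameters; call it on a dataset to compute outputs."""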
+        def __init__(self, nnet, params):
+            self.nnet = nnet
+            self.params = params
+
+        def update(self, trainset, stopper=None):
+            """Update this model from more training data."""
+            v = self.nnet.v
+            params = self.params
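+            # compiled training step: returns the nll followed by the updated parameters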
+            update_fn = _function([v.input, v.target] + v.params, [v.nll] + v.new_params)
+            if stopper is not None: 
+                raise NotImplementedError()
+            else:
+                for i in xrange(self.nnet.nepochs):
+                    for input, target in trainset.minibatches(['input', 'target'],
+                            minibatch_size=min(32, len(trainset))):
+                        dummy = update_fn(input, target[:,0], *params)
+                        #print dummy[0]  # the minibatch nll; uncomment to monitor training
+
+        def __call__(self, testset,
+                output_fieldnames=['output_class'],
+                test_stats_collector=None,
+                copy_inputs=False,
+                put_stats_in_output_dataset=True,
+                output_attributes=[]):
+            """Apply this model (as a function) to new data"""
+            inputs = [self.nnet.v.input, self.nnet.v.target] + self.nnet.v.params
+            fn = _function(inputs, [getattr(self.nnet.v, name) for name in output_fieldnames])
+            if 'target' in testset.fields():
+                return dataset.ApplyFunctionDataSet(testset, 
+                    lambda input, target: fn(input, target[:,0], *self.params),
+                    output_fieldnames)
+            else:
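+                # no 'target' field: feed a placeholder label so the compiled fn can run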
+                return dataset.ApplyFunctionDataSet(testset, 
+                    lambda input: fn(input, numpy.zeros(1, dtype='int64'), *self.params),
+                    output_fieldnames)
+
+    def __init__(self, ninputs, nhid, nclass, lr, nepochs,
+            l2coef=0.0,
+            linker='c&py',
+            hidden_layer=None):
+        class Vars:
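+            """Symbolic variables and expressions that define the MLP."""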
+            def __init__(self, lr, l2coef):
+                lr = t.constant(lr)
+                l2coef = t.constant(l2coef)
+                input = t.matrix('input') # n_examples x n_inputs
+                target = t.ivector('target') # n_examples (integer class labels)
+                W2 = t.matrix('W2')
+                b2 = t.vector('b2')
+
+                if hidden_layer:
+                    hid, hid_params, hid_ivals, hid_regularization = hidden_layer(input)
+                else:
+                    W1 = t.matrix('W1')
+                    b1 = t.vector('b1')
+                    hid = t.tanh(b1 + t.dot(input, W1))
+                    hid_params = [W1, b1]
+                    hid_regularization = l2coef * t.sum(W1*W1)
+                    hid_ivals = lambda : [_randshape(ninputs, nhid), _randshape(nhid)]
+
+                params = [W2, b2] + hid_params
+                nll, predictions = nnet_ops.crossentropy_softmax_1hot(b2 + t.dot(hid, W2), target)
+                regularization = l2coef * t.sum(W2*W2) + hid_regularization
+                output_class = t.argmax(predictions, 1)
+                loss_01 = t.neq(output_class, target)
+                g_params = t.grad(nll + regularization, params)
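+                # gradient-descent step, applied in-place to the parameter storage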
+                new_params = [t.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]
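+                # expose every local (input, target, nll, ...) as an attribute of this Vars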
+                self.__dict__.update(locals()); del self.self
+        self.nhid = nhid
+        self.nclass = nclass
+        self.nepochs = nepochs
+        self.v = Vars(lr, l2coef)
+        self.params = None
+
+    def __call__(self, trainset=None, iparams=None):
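+        """Allocate initial parameters (unless given) and optionally train a Model."""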
+        if iparams is None:
+            iparams = [_randshape(self.nhid, self.nclass), _randshape(self.nclass)]\
+                    + self.v.hid_ivals()
+        rval = NeuralNet.Model(self, iparams)
+        if trainset:
+            rval.update(trainset)
+        return rval
+
+
+if __name__ == '__main__':
+    training_set1 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
+                                                      [0, 1, 1],
+                                                      [1, 0, 1],
+                                                      [1, 1, 1]]),
+                                         {'input': slice(2), 'target': 2})
+    training_set2 = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
+                                                      [0, 1, 1],
+                                                      [1, 0, 0],
+                                                      [1, 1, 1]]),
+                                         {'input': slice(2), 'target': 2})
+    test_data = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
+                                                  [0, 1, 1],
+                                                  [1, 0, 0],
+                                                  [1, 1, 1]]),
+                                     {'input': slice(2)})
+
+    learn_algo = NeuralNet(2, 10, 3, .1, 1000) # ninputs, nhid, nclass, lr, nepochs
+
+    model1 = learn_algo(training_set1)
+
+    model2 = learn_algo(training_set2)
+
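+    # count the test examples on which the two independently-trained models agree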
+    n_match = 0
+    for o1, o2 in zip(model1(test_data), model2(test_data)):
+        n_match += (o1 == o2) 
+
+    print n_match, numpy.sum(training_set1.fields()['target'] ==
+            training_set2.fields()['target'])
+