view mlp_factory_approach.py @ 206:f2ddc795ec49

changes made with Pascal but should probably be discarded
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 16 May 2008 16:36:27 -0400
parents ebbb0e749565
children c5a7105fa40b
line wrap: on
line source

import dataset
import theano
import theano.tensor as t
import numpy
import nnet_ops

def _randshape(*shape):
    """Return small random values of the requested shape.

    Values are drawn with ``numpy.random.rand(*shape)`` (uniform on
    [0, 1)), shifted to be centred on zero and scaled by 0.001, so every
    entry lies in [-0.0005, 0.0005).
    """
    centred = numpy.random.rand(*shape) - 0.5
    return centred * 0.001
def _function(inputs, outputs, linker='c&py'):
    """Compile ``outputs`` as a function of ``inputs`` with theano.

    Thin wrapper around ``theano.function`` that always passes
    ``unpack_single=False`` and forwards the requested ``linker``.
    """
    compile_options = dict(unpack_single=False, linker=linker)
    return theano.function(inputs, outputs, **compile_options)

class NeuralNet(object):
    """Factory ("learning algorithm") for a softmax classifier on top of a hidden layer.

    Constructing a NeuralNet builds the symbolic graph once and stores it in
    ``self.v``.  Calling the instance returns a ``NeuralNet.Model`` that pairs
    this graph with concrete parameter values, optionally trained on a
    dataset first.
    """

    class Model(object):
        """Concrete parameter values bound to the symbolic graph of a NeuralNet."""

        def __init__(self, nnet, params):
            """Remember the owning ``nnet`` and the parameter values ``params``."""
            self.nnet = nnet
            self.params = params

        def update(self, trainset, stopper=None):
            """Update this model from more training data.

            Makes 100 passes over ``trainset``, in minibatches of at most 32
            examples, feeding the 'input' and 'target' fields to the compiled
            update function together with the current parameter values.

            Raises NotImplementedError when ``stopper`` is given: early
            stopping is not supported.
            """
            if stopper is not None:
                raise NotImplementedError()
            v = self.nnet.v
            params = self.params
            # The outputs are [nll] + new_params.  new_params is built with
            # sub_inplace, so the values in `params` are presumably updated
            # in place and the returned list does not need to be kept.
            update_fn = _function([v.input, v.target] + v.params,
                                  [v.nll] + v.new_params,
                                  linker=self.nnet.linker)
            # NOTE(review): the number of passes is hard-coded to 100 and the
            # `nepochs` given to NeuralNet is not consulted -- confirm intent.
            for _epoch in range(100):
                for input, target in trainset.minibatches(['input', 'target'],
                        minibatch_size=min(32, len(trainset))):
                    # target arrives with one column per example row; the
                    # graph takes a 1-D vector of class indices.
                    update_fn(input, target[:,0], *params)

        def __call__(self, testset,
                output_fieldnames=None,
                test_stats_collector=None,
                copy_inputs=False,
                put_stats_in_output_dataset=True,
                output_attributes=None):
            """Apply this model (as a function) to new data.

            ``output_fieldnames`` names attributes of the graph holder
            (``nnet.v``) to compute; it defaults to ``['output_class']``.
            Returns a ``dataset.ApplyFunctionDataSet`` over ``testset``.

            ``test_stats_collector``, ``copy_inputs``,
            ``put_stats_in_output_dataset`` and ``output_attributes`` are
            accepted but currently ignored.
            """
            if output_fieldnames is None:
                output_fieldnames = ['output_class']
            inputs = [self.nnet.v.input, self.nnet.v.target] + self.nnet.v.params
            fn = _function(inputs,
                           [getattr(self.nnet.v, name) for name in output_fieldnames],
                           linker=self.nnet.linker)
            if 'target' in testset.fieldNames():
                return dataset.ApplyFunctionDataSet(testset,
                    lambda input, target: fn(input, target[:,0], *self.params),
                    output_fieldnames)
            else:
                # The compiled function always takes a target argument, so a
                # dummy one is supplied when the dataset has no 'target' field.
                return dataset.ApplyFunctionDataSet(testset,
                    lambda input: fn(input, numpy.zeros(1,dtype='int64'), *self.params),
                    output_fieldnames)

    def __init__(self, ninputs, nhid, nclass, lr, nepochs,
                 l2coef=0.0,
                 linker='c&py',
                 hidden_layer=None):
        """Build the symbolic graph of the classifier.

        ninputs, nhid, nclass -- sizes of the input, hidden and output layers.
        lr -- learning rate used in the gradient step.
        nepochs -- stored on the instance.
        l2coef -- coefficient of the squared-weight penalty on W2.
        linker -- linker name passed to ``_function`` when compiling.
        hidden_layer -- callable layer object providing ``params()``,
            ``params_ivals()`` and ``regularization()``; when omitted an
            ``AffineSigmoidLayer`` is created.
        """
        # NOTE(review): AffineSigmoidLayer is not defined or imported in this
        # file -- it must be brought into scope before this default is used.
        if not hidden_layer:
            hidden_layer = AffineSigmoidLayer("hidden",ninputs,nhid,l2coef)
        class Vars:
            """Holder for every symbolic variable of the graph.

            All locals of ``__init__`` are published as attributes by
            ``setattr_and_name``, so do not introduce temporary locals here.
            """
            def __init__(self, lr, l2coef):
                lr = t.constant(lr)
                l2coef = t.constant(l2coef)
                input = t.matrix('input') # n_examples x n_inputs
                target = t.ivector('target') # n_examples class indices (1-D)
                W2 = t.matrix('W2')
                b2 = t.vector('b2')

                hid = hidden_layer(input)
                hid_params = hidden_layer.params()
                hid_params_init_vals = hidden_layer.params_ivals()
                hid_regularization = hidden_layer.regularization()

                params = [W2, b2] + hid_params
                nll, predictions = nnet_ops.crossentropy_softmax_1hot( b2 + t.dot(hid, W2), target)
                regularization = l2coef * t.sum(W2*W2) + hid_regularization
                output_class = t.argmax(predictions,1)
                loss_01 = t.neq(output_class, target)
                g_params = t.grad(nll + regularization, params)
                # One gradient-descent step per parameter: p - lr * dcost/dp.
                new_params = [t.sub_inplace(p, lr * gp) for p,gp in zip(params, g_params)]
                setattr_and_name(self, locals())
        self.nhid = nhid
        self.nclass = nclass
        self.nepochs = nepochs
        self.linker = linker
        self.v = Vars(lr, l2coef)
        self.params = None

    def __call__(self, trainset=None, iparams=None):
        """Return a new Model, trained on ``trainset`` when one is given.

        ``iparams`` supplies the initial parameter values; when omitted,
        small random values are drawn for the output layer and combined
        with the hidden layer's initial values.
        """
        if iparams is None:
            # NOTE(review): LookupList is not defined or imported in this
            # file.  hid_params_init_vals already holds the result of
            # hidden_layer.params_ivals() and is called again here; confirm
            # that it is meant to be a callable.
            iparams = (LookupList(["W","b"],
                                  [_randshape(self.nhid, self.nclass),
                                   _randshape(self.nclass)])
                       + self.v.hid_params_init_vals())
        rval = NeuralNet.Model(self, iparams)
        if trainset:
            rval.update(trainset)
        return rval


def setattr_and_name(self, dict):
    """Set every entry of ``dict`` as an attribute of ``self``.

    The entry whose value is ``self`` itself is skipped.  In addition, any
    value that has a ``name`` attribute which is empty or None gets its
    ``name`` set to its key in the dictionary.

    Typical usage:  setattr_and_name(self, locals())
    """
    # Iterate over the dictionary argument; the builtin ``locals`` function
    # has no ``items`` attribute.
    for varname, var in dict.items():
        if var is self:
            continue
        if hasattr(var, "name") and not var.name:
            var.name = varname
        setattr(self, varname, var)


if __name__ == '__main__':
    # Small demo: train two models on toy 2-input problems whose targets
    # differ on one row, then count how often their outputs agree on the
    # test set.  Each row is [input0, input1, target].
    rows1 = numpy.array([[0, 0, 0],
                         [0, 1, 1],
                         [1, 0, 1],
                         [1, 1, 1]])
    training_set1 = dataset.ArrayDataSet(rows1, {'input':slice(2),'target':2})

    rows2 = numpy.array([[0, 0, 0],
                         [0, 1, 1],
                         [1, 0, 0],
                         [1, 1, 1]])
    training_set2 = dataset.ArrayDataSet(rows2, {'input':slice(2),'target':2})

    # Only the two input columns are exposed as a field for the test data.
    test_rows = numpy.array([[0, 0, 0],
                             [0, 1, 1],
                             [1, 0, 0],
                             [1, 1, 1]])
    test_data = dataset.ArrayDataSet(test_rows, {'input':slice(2)})

    learn_algo = NeuralNet(ninputs=2, nhid=10, nclass=3, lr=.1, nepochs=1000)

    # An untrained model, then one trained model per training set.
    model = learn_algo()
    model1 = learn_algo(training_set1)
    model2 = learn_algo(training_set2)

    outputs1 = model1(test_data)
    outputs2 = model2(test_data)
    n_match = 0
    for o1, o2 in zip(outputs1, outputs2):
        n_match += (o1 == o2)

    n_same_targets = numpy.sum(training_set1.fields()['target'] ==
                               training_set2.fields()['target'])
    print("%s %s" % (n_match, n_same_targets))