view test_mlp.py @ 274:ed70580f2324

bugfix in FieldSubsetDataSet
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 05 Jun 2008 13:46:26 -0400
parents ebbb0e749565
children
line wrap: on
line source


from mlp import *
import dataset
import nnet_ops


from functools import partial
def separator(debugger, i, node, *ths):
    print "==================="

def what(debugger, i, node, *ths):
    print "#%i" % i, node

def parents(debugger, i, node, *ths):
    print [input.step for input in node.inputs]

def input_shapes(debugger, i, node, *ths):
    print "input shapes: ",
    for r in node.inputs:
        if hasattr(r.value, 'shape'):
            print r.value.shape,
        else:
            print "no_shape",
    print

def input_types(debugger, i, node, *ths):
    print "input types: ",
    for r in node.inputs:
        print r.type,
    print

def output_shapes(debugger, i, node, *ths):
    print "output shapes:",
    for r in node.outputs:
        if hasattr(r.value, 'shape'):
            print r.value.shape,
        else:
            print "no_shape",
    print

def output_types(debugger, i, node, *ths):
    print "output types:",
    for r in node.outputs:
        print r.type,
    print


def test0():
    linker = 'c|py'
    #linker = partial(theano.gof.DebugLinker, linkers = [theano.gof.OpWiseCLinker],
    #                 debug_pre = [separator, what, parents, input_types, input_shapes],
    #                 debug_post = [output_shapes, output_types],
    #                 compare_fn = lambda x, y: numpy.all(x == y))
    
    nnet = OneHiddenLayerNNetClassifier(10,2,.001,1000, linker = linker)
    training_set = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                     [0, 1, 1],
                                                     [1, 0, 1],
                                                     [1, 1, 1]]),
                                        {'input':slice(2),'target':2})
    fprop=nnet(training_set)

    output_ds = fprop(training_set)

    for fieldname in output_ds.fieldNames():
        print fieldname+"=",output_ds[fieldname]

def test1():
    nnet = ManualNNet(2, 10,3,.1,1000)
    training_set = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                     [0, 1, 1],
                                                     [1, 0, 1],
                                                     [1, 1, 1]]),
                                        {'input':slice(2),'target':2})
    fprop=nnet(training_set)

    output_ds = fprop(training_set)

    for fieldname in output_ds.fieldNames():
        print fieldname+"=",output_ds[fieldname]

def test2():
    training_set = dataset.ArrayDataSet(numpy.array([[0, 0, 0],
                                                     [0, 1, 1],
                                                     [1, 0, 1],
                                                     [1, 1, 1]]),
                                        {'input':slice(2),'target':2})
    nin, nhid=2, 10
    def sigm_layer(input):
        W1 = t.matrix('W1')
        b1 = t.vector('b1')
        return (nnet_ops.sigmoid(b1 + t.dot(input, W1)),
                [W1, b1],
                [(numpy.random.rand(nin, nhid) -0.5) * 0.001, numpy.zeros(nhid)])
    nnet = ManualNNet(nin, nhid, 3, .1, 1000, hidden_layer=sigm_layer)
    fprop=nnet(training_set)

    output_ds = fprop(training_set)

    for fieldname in output_ds.fieldNames():
        print fieldname+"=",output_ds[fieldname]

def test_interface_0():
    """Sketch of a desired learner interface; not runnable as written.

    NOTE(review): ``training_set`` and ``additional_data`` are not defined
    anywhere in this file. Unless ``from mlp import *`` happens to provide
    them, calling this raises NameError right after the learner is built.
    It appears to be kept as a design note rather than a working test;
    confirm before relying on it, especially under a test runner that
    collects ``test_*`` functions.
    """
    learner = ManualNNet(2, 10, 3, .1, 1000)

    model = learner(training_set)

    model2 = learner(training_set)    # trains model a second time

    learner.update(additional_data)   # modifies nnet and model by side-effect


def test_interface2_1():
    """Sketch of an alternative learner interface; not runnable as written.

    NOTE(review): ``training_set1``, ``training_set2``, ``additional_data``
    and ``test_data`` are not defined anywhere in this file. Unless
    ``from mlp import *`` provides them, calling this raises NameError. It is
    also unverified that ``ManualNNet`` instances accept a call with no
    arguments (``learn_algo()``) or that models expose ``update``/``use``.
    Presumably a design note; confirm before treating it as a test.
    """
    learn_algo = ManualNNet(2, 10, 3, .1, 1000)

    # Intended as the untrained model, judging by the name -- TODO confirm.
    prior = learn_algo()

    model1 = learn_algo(training_set1)

    model2 = learn_algo(training_set2)

    model2.update(additional_data)

    # Count positions where the two models' outputs compare equal; zip stops at
    # the shorter sequence. If outputs are arrays, ``o1 == o2`` is element-wise
    # rather than a single bool -- NOTE(review): verify the intended comparison.
    n_match = 0
    for o1, o2 in zip(model1.use(test_data), model2.use(test_data)):
        n_match += (o1 == o2) 

    print n_match

if __name__ == "__main__":
    # Run the smoke tests only when this file is executed as a script.
    # Previously they ran unconditionally at import time, so merely importing
    # the module (e.g. during test collection) trained two networks, and any
    # failure there aborted the import. test1/test2 are still collected as
    # ordinary test functions by a runner.
    test1()
    test2()