view test_mlp.py @ 185:3d953844abd3

support for more int types in crossentropysoftmax1hot
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 19:37:29 -0400
parents 25d0a0c713da
children 562f308873f0
line wrap: on
line source


from mlp import *
import dataset


from functools import partial
def separator(debugger, i, node, *ths):
    """Debug hook: print a horizontal rule between per-node dumps.

    Every argument is ignored.  The (debugger, i, node, *ths) signature is
    shared by all the hooks in this module so they can be handed to a
    debugging linker together.
    """
    print("===================")

def what(debugger, i, node, *ths):
    """Debug hook: print ``#<i>`` followed by the node itself.

    ``i`` is rendered with ``%i`` (presumably the node's position in the
    execution order -- confirm against the linker); ``node`` is shown via
    its str().  ``debugger`` and ``ths`` are unused.
    """
    label = "#%i" % i
    print("%s %s" % (label, node))

def parents(debugger, i, node, *ths):
    """Debug hook: print the list of ``step`` attributes of ``node``'s inputs.

    The comprehension variable is deliberately not called ``input``.  That
    name would shadow the builtin, and under Python 2 a list-comprehension
    variable leaks into the enclosing function scope.
    ``debugger``, ``i`` and ``ths`` are unused.
    """
    print([var.step for var in node.inputs])

def input_shapes(debugger, i, node, *ths):
    """Debug hook: print, on one line, the shape of each input's value.

    Inputs whose ``value`` has no ``shape`` attribute are shown as
    ``no_shape``.  Items are separated by single spaces.  The label itself
    ends in a space, so two spaces follow it (and a bare trailing space
    remains when the node has no inputs).
    """
    shapes = [str(var.value.shape) if hasattr(var.value, 'shape') else "no_shape"
              for var in node.inputs]
    print(" ".join(["input shapes: "] + shapes))

def input_types(debugger, i, node, *ths):
    """Debug hook: print, on one line, the ``type`` of each of ``node``'s inputs.

    Each type is shown via its str() and separated by single spaces.  The
    label already ends in a space, so two spaces follow it.
    """
    type_names = [str(var.type) for var in node.inputs]
    print(" ".join(["input types: "] + type_names))

def output_shapes(debugger, i, node, *ths):
    """Debug hook: print, on one line, the shape of each output's value.

    Outputs whose ``value`` has no ``shape`` attribute are shown as
    ``no_shape``.  Items follow the label separated by single spaces.
    """
    shapes = [str(var.value.shape) if hasattr(var.value, 'shape') else "no_shape"
              for var in node.outputs]
    print(" ".join(["output shapes:"] + shapes))

def output_types(debugger, i, node, *ths):
    """Debug hook: print, on one line, the ``type`` of each of ``node``'s outputs.

    Each type is shown via its str(), following the label and separated by
    single spaces.
    """
    type_names = [str(var.type) for var in node.outputs]
    print(" ".join(["output types:"] + type_names))


def test0():
    """Smoke-test OneHiddenLayerNNetClassifier on a four-row truth table.

    The classifier is applied to the training set (presumably this trains
    it and returns a predictor).  That predictor is then run over the same
    set, and every field of the resulting dataset is printed.  Nothing is
    asserted; the output is for manual inspection.
    """
    linker = 'c|py'
    # Alternative linker that traces every node with the debug hooks defined
    # in this module:
    #linker = partial(theano.gof.DebugLinker, linkers = [theano.gof.OpWiseCLinker],
    #                 debug_pre = [separator, what, parents, input_types, input_shapes],
    #                 debug_post = [output_shapes, output_types],
    #                 compare_fn = lambda x, y: numpy.all(x == y))

    # See OneHiddenLayerNNetClassifier for the meaning of the positional args.
    nnet = OneHiddenLayerNNetClassifier(10, 2, .001, 1000, linker=linker)

    # Column 2 is the logical OR of columns 0 and 1.
    or_table = numpy.array([[0, 0, 0],
                            [0, 1, 1],
                            [1, 0, 1],
                            [1, 1, 1]])
    # Field map: 'input' -> slice(2), 'target' -> 2 (presumably column indices).
    field_map = {'input': slice(2), 'target': 2}
    training_set = dataset.ArrayDataSet(or_table, field_map)

    predictor = nnet(training_set)
    output_ds = predictor(training_set)

    for name in output_ds.fieldNames():
        print("%s %s" % (name + "=", output_ds[name]))

# Run the smoke test only when this file is executed as a script.  Calling it
# unconditionally would train a network as a side effect of merely importing
# the module.  A test collector would then run test0 a second time as a
# collected test.
if __name__ == "__main__":
    test0()