Mercurial > pylearn
comparison test_mlp.py @ 186:562f308873f0
added ManualNNet
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 13 May 2008 20:10:03 -0400 |
parents | 25d0a0c713da |
children | ebbb0e749565 |
comparison
Legend (colour-coded in the original view): equal, deleted, inserted, replaced
185:3d953844abd3 | 186:562f308873f0 |
---|---|
1 | 1 |
2 from mlp import * | 2 from mlp import * |
3 import dataset | 3 import dataset |
| | 4 import nnet_ops |
4 | 5 |
5 | 6 |
6 from functools import partial | 7 from functools import partial |
7 def separator(debugger, i, node, *ths): | 8 def separator(debugger, i, node, *ths): |
8 print "===================" | 9 print "===================" |
62 output_ds = fprop(training_set) | 63 output_ds = fprop(training_set) |
63 | 64 |
64 for fieldname in output_ds.fieldNames(): | 65 for fieldname in output_ds.fieldNames(): |
65 print fieldname+"=",output_ds[fieldname] | 66 print fieldname+"=",output_ds[fieldname] |
66 | 67 |
67 test0() | 68 def test1(): |
| | 69 nnet = ManualNNet(2, 10,3,.1,1000) |
| | 70 training_set = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
| | 71 [0, 1, 1], |
| | 72 [1, 0, 1], |
| | 73 [1, 1, 1]]), |
| | 74 {'input':slice(2),'target':2}) |
| | 75 fprop=nnet(training_set) |
68 | 76 |
| | 77 output_ds = fprop(training_set) |
| | 78 |
| | 79 for fieldname in output_ds.fieldNames(): |
| | 80 print fieldname+"=",output_ds[fieldname] |
| | 81 |
| | 82 def test2(): |
| | 83 training_set = dataset.ArrayDataSet(numpy.array([[0, 0, 0], |
| | 84 [0, 1, 1], |
| | 85 [1, 0, 1], |
| | 86 [1, 1, 1]]), |
| | 87 {'input':slice(2),'target':2}) |
| | 88 nin, nhid=2, 10 |
| | 89 def sigm_layer(input): |
| | 90 W1 = t.matrix('W1') |
| | 91 b1 = t.vector('b1') |
| | 92 return (nnet_ops.sigmoid(b1 + t.dot(input, W1)), |
| | 93 [W1, b1], |
| | 94 [(numpy.random.rand(nin, nhid) -0.5) * 0.001, numpy.zeros(nhid)]) |
| | 95 nnet = ManualNNet(nin, nhid, 3, .1, 1000, hidden_layer=sigm_layer) |
| | 96 fprop=nnet(training_set) |
| | 97 |
| | 98 output_ds = fprop(training_set) |
| | 99 |
| | 100 for fieldname in output_ds.fieldNames(): |
| | 101 print fieldname+"=",output_ds[fieldname] |
| | 102 test1() |
| | 103 test2() |
| | 104 |