Mercurial > pylearn
annotate pylearn/algorithms/rnn.py @ 784:ba65e95d1221
removed manual call to Member and Variable as this is deprecated in theano.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 29 Jun 2009 09:49:28 -0400 |
parents | a4f65f1d2b18 |
children |
rev | line source |
---|---|
587
a4f65f1d2b18
made the file executable.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
581
diff
changeset
|
1 #!/usr/bin/env python |
550 | 2 import numpy as N |
784
ba65e95d1221
removed manual call to Member and Variable as this is deprecated in theano.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
587
diff
changeset
|
3 from theano import Op, Apply, tensor as T, Module, Method, Mode, compile |
550 | 4 from theano.gof import OpSub, TopoOptimizer |
5 | |
574
220044be9fd8
added test for a bug that James reported to me.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
560
diff
changeset
|
6 from minimizer import make_minimizer # minimizer |
550 | 7 from theano.printing import Print |
8 import sgd #until Olivier's module-import thing works better | |
9 | |
10 #################### | |
11 # Library-type stuff | |
12 #################### | |
13 | |
class TanhRnn(Op):
    r"""
    Recurrent core of a tanh recurrent neural network, packaged as a single Op.

    Theano has no neat, fine-grained way of expressing this recurrence at the
    moment, so the whole loop lives in one relatively complicated Op that could
    be broken into smaller constituents later.

    The recursion computed is:

    .. latex-eqn:
        z_t &= \tanh( z_{t-1} A + x_{t-1})

    For z0 a vector and x a TxM matrix, the result is a matrix z of shape
    (T+1, M) in which z[0] = z0.

    """

    def make_node(self, x, z0, A):
        """
        :type x:  matrix (each row is an x_t) (shape: (T, M))
        :type z0:  vector (the first row of output) (shape: M)
        :type A: matrix (M by M)

        """
        inputs = [T.as_tensor(arg) for arg in (x, z0, A)]
        # The single output is a fresh symbolic variable typed like x.
        return Apply(self, inputs, [inputs[0].type()])

    def perform(self, node, inputs, out):
        """Run the recurrence numerically and store the (T+1, M) result in out[0][0]."""
        x, z0, A = inputs
        n_steps, width = x.shape
        # NOTE(review): the buffer is always float64, whatever the dtype of x,
        # while make_node declares the output with x's type -- confirm this is
        # acceptable for non-float64 inputs.
        z = N.zeros((n_steps + 1, width))
        z[0] = z0
        for t, x_t in enumerate(x):
            z[t + 1] = N.tanh(N.dot(z[t], A) + x_t)
        out[0][0] = z

    def grad(self, inputs, output_grads):
        """Return the symbolic gradients with respect to (x, z0, A), in that order."""
        x, z0, A = inputs
        gz, = output_grads
        z = tanh_rnn(x, z0, A)
        gz_incl_rnn, gx = tanh_rnn_grad(A, z, gz)
        grad_z0 = gz_incl_rnn[0]
        grad_A = T.dot(z[:-1].T, gx)
        return [gx, grad_z0, grad_A]
tanh_rnn = TanhRnn()
58 | |
class TanhRnnGrad(Op):
    """Gradient calculation for TanhRnn.

    Inputs are (A, z, gz): the recurrent matrix, the (T+1, M) output of the
    forward recurrence, and the gradient on that output.  Two outputs are
    produced: gz with the contribution flowing back through the recurrence
    added in, and gx, a (T, M) array with the gradient on the tanh arguments.

    :param inplace: when true, perform() accumulates directly into the gz
        buffer it is given instead of working on a copy.
    """

    def __init__(self, inplace):
        self.inplace = inplace

        if self.inplace:
            # Output 0 reuses (and overwrites) input 2, i.e. gz.
            self.destroy_map = {0: [2]}

    def __eq__(self, other):
        return (type(self) == type(other)) and (self.inplace == other.inplace)

    def __hash__(self):
        # Must take only `self`: the previous (self, other) signature made
        # hash(op) raise TypeError.  Kept consistent with __eq__.
        return hash(type(self)) ^ hash(self.inplace)

    def make_node(self, A, z, gz):
        return Apply(self, [A,z,gz], (z.type(), gz.type()))

    def perform(self, node, inputs, out):
        """Back-propagate through the recurrence, from the last step to the first."""
        A, z, gz = inputs
        n_steps = z.shape[0] - 1
        width = z.shape[1]
        gx = N.zeros((n_steps, width))

        if not self.inplace:
            # Work on a private copy so the caller's gz is left untouched.
            gz = gz.copy()

        for i in reversed(range(n_steps)):
            # back through the tanh: d tanh(a)/da = 1 - tanh(a)**2
            gx[i] = gz[i+1] * (1.0 - z[i+1] * z[i+1])
            # back through the product with A, into the previous state
            gz[i] += N.dot(A, gx[i])

        out[0][0] = gz
        out[1][0] = gx

    def __str__(self):
        if self.inplace:
            return 'Inplace' + super(TanhRnnGrad, self).__str__()
        else:
            return super(TanhRnnGrad, self).__str__()

tanh_rnn_grad = TanhRnnGrad(inplace=False)
tanh_rnn_grad_inplace = TanhRnnGrad(inplace=True)

# Substitute the in-place variant for the copying one under the 'fast_run' and
# 'inplace' optimizer tags.
compile.optdb.register('inplace_rnngrad', TopoOptimizer(OpSub(tanh_rnn_grad, tanh_rnn_grad_inplace)), 60, 'fast_run', 'inplace')
103 | |
104 | |
105 ####################### | |
106 # Experiment-type stuff | |
107 ####################### | |
108 | |
109 | |
110 | |
class ExampleRNN(Module):
    """Small recurrent network built on tanh_rnn, trained on a squared-error cost.

    The input x goes through an affine map (v, b) into the latent space, the
    recurrence tanh_rnn is run with initial state z0 and recurrent matrix w,
    and the latent states z[1:] go through a second affine map (u, c) to give
    the prediction.  The cost is the sum of squared differences between the
    target y and that prediction.

    :param n_vis: number of input units
    :param n_hid: number of latent (recurrent) units
    :param n_out: number of output units
    :param minimizer: callable following the make_minimizer protocol; it is
        called as minimizer([x, y], cost, params)
    """

    def __init__(self, n_vis, n_hid, n_out, minimizer):
        super(ExampleRNN, self).__init__()

        def affine(weight, bias):
            return (lambda a : T.dot(a, weight) + bias)

        self.n_vis = n_vis
        self.n_hid = n_hid
        self.n_out = n_out

        #affine transformation x -> latent space
        self.v, self.b = T.dmatrix(), T.dvector()
        input_transform = affine(self.v, self.b)

        #recurrent weight matrix in latent space
        self.z0 = T.dvector()
        self.w = T.dmatrix()

        #affine transformation latent -> output space
        self.u, self.c = T.dmatrix(), T.dvector()
        output_transform = affine(self.u, self.c)

        # z0 is deliberately absent here: it is not handed to the minimizer.
        self.params = [self.v, self.b, self.w, self.u, self.c]

        #input and target
        x, y = T.dmatrix(), T.dmatrix()

        z = tanh_rnn(input_transform(x), self.z0, self.w)
        # z[0] is the initial state, so only the later rows are mapped to outputs
        yhat = output_transform(z[1:])
        self.cost = T.sum((y - yhat)**2)

        self.blah = Method([x,y], self.cost)

        # using the make_minimizer protocol
        self.minimizer = minimizer([x, y], self.cost, self.params)

    def _instance_initialize(self, obj):
        """Give `obj` its starting values: small random weights, zero biases and zero z0."""
        n_vis = self.n_vis
        n_hid = self.n_hid
        n_out = self.n_out

        # fixed seed, so every instance starts from the same weights
        rng = N.random.RandomState(2342)

        obj.z0 = N.zeros(n_hid)
        obj.v = rng.randn(n_vis, n_hid) * 0.01
        obj.b = N.zeros(n_hid)
        obj.w = rng.randn(n_hid, n_hid) * 0.01
        obj.u = rng.randn(n_hid, n_out) * 0.01
        obj.c = N.zeros(n_out)
        obj.minimizer.initialize()

    def _instance__eq__(self, other):
        """Return True when z0, v, b, w, u and c all agree with `other` to within 1e-8.

        :raises NotImplementedError: if `other.component` is not an ExampleRNN.
        """
        if not isinstance(other.component, ExampleRNN):
            # `raise NotImplemented` raised a non-exception value, which only
            # produced an unrelated TypeError.
            raise NotImplementedError(
                "can only compare with an instance of ExampleRNN")
        # we compare the members.
        # NOTE(review): comparing n_vis / n_hid / n_out was left disabled in the
        # original code; only the parameter values are compared.
        member_names = ('z0', 'v', 'b', 'w', 'u', 'c')
        return all((N.abs(getattr(self, name) - getattr(other, name)) < 1e-8).all()
                   for name in member_names)

    def _instance__hash__(self):
        """Hashing is not supported for instances."""
        raise NotImplementedError("instances of ExampleRNN are not hashable")
550 | 175 |
def test_example_rnn():
    """Build an ExampleRNN in FAST_RUN mode and take 1500 'sgd' steps on a lagged-copy task.

    The compiled step_cost graph is printed first; after that the value
    returned by step_cost and the current stepsize are printed every 100
    iterations.
    """
    n_vis, n_out, n_hid = 5, 3, 4

    minimizer_fn = make_minimizer('sgd', stepsize = 0.001)
    rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn)
    rnn = rnn_module.make(mode='FAST_RUN')

    rng = N.random.RandomState(7722342)
    x = rng.randn(10, n_vis)
    y = rng.randn(10, n_out)

    # set y to be like x with a lag of LAG
    LAG = 4
    y[LAG:] = x[:-LAG, 0:n_out]

    show_graph = True
    if show_graph:
        topo = rnn.minimizer.step_cost.maker.env.toposort()
        for i, node in enumerate(topo):
            print("%s %s" % (i, node))

    niter = 1500
    for i in range(niter):
        cost = rnn.minimizer.step_cost(x, y)
        if i % 100 == 0:
            print("%s %s %s" % (i, cost, rnn.minimizer.stepsize))
204 | |
560
96221aa02fcb
put the new test in a different test fct.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
559
diff
changeset
|
def test_WEIRD_STUFF():
    """Train two ExampleRNN instances, one FAST_RUN and one FAST_COMPILE, and compare them.

    Both instances take 50 minimizer steps on the same data; afterwards z0, v,
    b, w, u and c must agree between the two to within 1e-8.
    """
    n_vis, n_out, n_hid = 5, 3, 4
    rng = N.random.RandomState(7722342)
    x = rng.randn(10, n_vis)
    y = rng.randn(10, n_out)

    # set y to be like x with a lag of LAG
    LAG = 4
    y[LAG:] = x[:-LAG, 0:n_out]

    minimizer_fn1 = make_minimizer('sgd', stepsize = 0.001)
    minimizer_fn2 = make_minimizer('sgd', stepsize = 0.001)
    rnn_module1 = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn1)
    rnn_module2 = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn2)
    # Note the crossed numbering: rnn1 comes from module 2, rnn2 from module 1.
    rnn1 = rnn_module2.make(mode='FAST_RUN')
    rnn2 = rnn_module1.make(mode='FAST_COMPILE')

    # Debugging aid, off by default: print both compiled graphs side by side.
    dump_graphs = False
    if dump_graphs:
        topo1 = rnn1.minimizer.step_cost.maker.env.toposort()
        topo2 = rnn2.minimizer.step_cost.maker.env.toposort()
        for i, node1 in enumerate(topo1):
            print("1 %s %s" % (i, node1))
            print("2 %s %s" % (i, topo2[i]))

    niter = 50
    for _ in range(niter):
        rnn1.minimizer.step(x, y)
        rnn2.minimizer.step(x, y)

    # Only the parameter values are compared, not n_vis / n_hid / n_out.
    for name in ('z0', 'v', 'b', 'w', 'u', 'c'):
        assert (N.abs(getattr(rnn1, name) - getattr(rnn2, name)) < 1e-8).all()
559
83ebb313b2f1
added a test for the WEIRD_STUFF flag in theano ticket 239
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
550
diff
changeset
|
240 |
574
220044be9fd8
added test for a bug that James reported to me.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
560
diff
changeset
|
241 # assert b |
559
83ebb313b2f1
added a test for the WEIRD_STUFF flag in theano ticket 239
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
550
diff
changeset
|
242 |
574
220044be9fd8
added test for a bug that James reported to me.
Frederic Bastien <bastienf@iro.umontreal.ca>
parents:
560
diff
changeset
|
if __name__ == '__main__':
    # When executed as a script, call the two test functions directly.
    # The theano.tests runner below was left disabled in the original code.
#    from theano.tests import main
#    main(__file__)
    test_example_rnn()
    test_WEIRD_STUFF()