changeset 559:83ebb313b2f1

added a test for the WEIRD_STUFF flag in theano ticket 239
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 01 Dec 2008 16:15:10 -0500
parents b58e71878bb5
children 96221aa02fcb
files pylearn/algorithms/rnn.py
diffstat 1 files changed, 32 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/rnn.py	Mon Dec 01 16:14:14 2008 -0500
+++ b/pylearn/algorithms/rnn.py	Mon Dec 01 16:15:10 2008 -0500
@@ -160,7 +160,18 @@
         obj.u = rng.randn(n_hid, n_out) * 0.01
         obj.c = N.zeros(n_out)
         obj.minimizer.initialize()
+    def __eq__(self, other):
+        if not isinstance(other.component, ExampleRNN):
+            raise NotImplemented
+         #we compare the member.
+        if self.n_vis != other.n_vis or slef.n_hid != other.n_hid or self.n_out != other.n_out:
+            return False
+        if (N.abs(self.z0-other.z0)<1e-8).all() and (N.abs(self.v-other.v)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.u-other.u)<1e-8).all() and (N.abs(self.c-other.c)<1e-8).all() and (N.abs(self.z0-other.z0)<1e-8).all():
+            return True
+        return False
 
+    def __hash__(self):
+        raise NotImplemented
 
 def test_example_rnn():
     minimizer_fn = make_minimizer('sgd', stepsize = 0.001)
@@ -191,3 +202,24 @@
         else:
             rnn.minimizer.step_cost(x, y)
 
+    
+    minimizer_fn = make_minimizer('sgd', stepsize = 0.001, WEIRD_STUFF = False)
+    rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn)
+
+    rnn1 = rnn_module.make(mode='FAST_RUN')
+
+    rng1 = N.random.RandomState(7722342)
+
+    niter=15
+    for i in xrange(niter):
+        rnn1.minimizer.step_cost(x, y)
+
+    minimizer_fn = make_minimizer('sgd', stepsize = 0.001, WEIRD_STUFF = True)
+
+    rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn)
+    rnn2 = rnn_module.make(mode='FAST_RUN')
+
+    for i in xrange(niter):
+        rnn2.minimizer.step_cost(x, y)
+
+    assert rnn1 == rnn2