# HG changeset patch # User Frederic Bastien # Date 1228166110 18000 # Node ID 83ebb313b2f1b6d20c3203571e05fa01fc05ea38 # Parent b58e71878bb5cac0f18e95a50a05619ae03a73be added a test for the WEIRD_STUFF flag in theano ticket 239 diff -r b58e71878bb5 -r 83ebb313b2f1 pylearn/algorithms/rnn.py --- a/pylearn/algorithms/rnn.py Mon Dec 01 16:14:14 2008 -0500 +++ b/pylearn/algorithms/rnn.py Mon Dec 01 16:15:10 2008 -0500 @@ -160,7 +160,18 @@ obj.u = rng.randn(n_hid, n_out) * 0.01 obj.c = N.zeros(n_out) obj.minimizer.initialize() + def __eq__(self, other): + if not isinstance(other.component, ExampleRNN): + raise NotImplemented + #we compare the member. + if self.n_vis != other.n_vis or slef.n_hid != other.n_hid or self.n_out != other.n_out: + return False + if (N.abs(self.z0-other.z0)<1e-8).all() and (N.abs(self.v-other.v)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.u-other.u)<1e-8).all() and (N.abs(self.c-other.c)<1e-8).all() and (N.abs(self.z0-other.z0)<1e-8).all(): + return True + return False + def __hash__(self): + raise NotImplemented def test_example_rnn(): minimizer_fn = make_minimizer('sgd', stepsize = 0.001) @@ -191,3 +202,24 @@ else: rnn.minimizer.step_cost(x, y) + + minimizer_fn = make_minimizer('sgd', stepsize = 0.001, WEIRD_STUFF = False) + rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn) + + rnn1 = rnn_module.make(mode='FAST_RUN') + + rng1 = N.random.RandomState(7722342) + + niter=15 + for i in xrange(niter): + rnn1.minimizer.step_cost(x, y) + + minimizer_fn = make_minimizer('sgd', stepsize = 0.001, WEIRD_STUFF = True) + + rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn) + rnn2 = rnn_module.make(mode='FAST_RUN') + + for i in xrange(niter): + rnn2.minimizer.step_cost(x, y) + + assert rnn1 == rnn2