Mercurial > pylearn
changeset 559:83ebb313b2f1
added a test for the WEIRD_STUFF flag in theano ticket 239
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 01 Dec 2008 16:15:10 -0500 |
parents | b58e71878bb5 |
children | 96221aa02fcb |
files | pylearn/algorithms/rnn.py |
diffstat | 1 files changed, 32 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/algorithms/rnn.py Mon Dec 01 16:14:14 2008 -0500 +++ b/pylearn/algorithms/rnn.py Mon Dec 01 16:15:10 2008 -0500 @@ -160,7 +160,18 @@ obj.u = rng.randn(n_hid, n_out) * 0.01 obj.c = N.zeros(n_out) obj.minimizer.initialize() + def __eq__(self, other): + if not isinstance(other.component, ExampleRNN): + raise NotImplemented + #we compare the member. + if self.n_vis != other.n_vis or slef.n_hid != other.n_hid or self.n_out != other.n_out: + return False + if (N.abs(self.z0-other.z0)<1e-8).all() and (N.abs(self.v-other.v)<1e-8).all() and (N.abs(self.b-other.b)<1e-8).all() and (N.abs(self.w-other.w)<1e-8).all() and (N.abs(self.u-other.u)<1e-8).all() and (N.abs(self.c-other.c)<1e-8).all() and (N.abs(self.z0-other.z0)<1e-8).all(): + return True + return False + def __hash__(self): + raise NotImplemented def test_example_rnn(): minimizer_fn = make_minimizer('sgd', stepsize = 0.001) @@ -191,3 +202,24 @@ else: rnn.minimizer.step_cost(x, y) + + minimizer_fn = make_minimizer('sgd', stepsize = 0.001, WEIRD_STUFF = False) + rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn) + + rnn1 = rnn_module.make(mode='FAST_RUN') + + rng1 = N.random.RandomState(7722342) + + niter=15 + for i in xrange(niter): + rnn1.minimizer.step_cost(x, y) + + minimizer_fn = make_minimizer('sgd', stepsize = 0.001, WEIRD_STUFF = True) + + rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn) + rnn2 = rnn_module.make(mode='FAST_RUN') + + for i in xrange(niter): + rnn2.minimizer.step_cost(x, y) + + assert rnn1 == rnn2