Mercurial > pylearn
view pylearn/algorithms/tests/test_exponential_mean.py @ 1508:b28e8730c948
fix test.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Mon, 12 Sep 2011 11:45:56 -0400 |
parents | be420f1836bb |
children |
line wrap: on
line source
import numpy
import theano
from theano.compile.debugmode import DebugMode

from pylearn.algorithms import exponential_mean


def _check_windowed_stat(make_stat, numpy_stat):
    """Drive a windowed running-statistic module and compare it to NumPy.

    ``make_stat(x, shape, window)`` builds the module under test; it must
    expose ``curval``, ``updates()`` and ``make()`` (as
    ``exponential_mean.exp_mean`` / ``exponential_mean.exp_var`` do).
    ``numpy_stat`` is the exact reference reduction (``numpy.mean`` or
    ``numpy.var``), called with ``axis=0`` over all rows seen so far.

    For the first ``window`` rows the running value must match the exact
    statistic; for every row after that it must no longer match it.
    """
    x = theano.tensor.dvector()
    rows_to_test = 10
    D = make_stat(x, (4,), rows_to_test)
    D.f = theano.Method([x], D.curval, D.updates())
    d = D.make()

    rng = numpy.random.RandomState(3284)
    xval = rng.randn(rows_to_test + 3, 4)
    for i, xrow in enumerate(xval):
        running = d.f(xrow)
        exact = numpy_stat(xval[:i + 1], axis=0)
        if i < rows_to_test:
            assert numpy.allclose(running, exact)
        else:
            assert not numpy.allclose(running, exact)
    # Sanity check: the loop really went past the window, so the
    # "diverges" branch above was exercised.
    assert i > rows_to_test


def test_mean():
    """exp_mean matches numpy.mean for its first 10 rows, then diverges."""
    _check_windowed_stat(exponential_mean.exp_mean, numpy.mean)


def test_var():
    """exp_var matches numpy.var for its first 10 rows, then diverges."""
    _check_windowed_stat(exponential_mean.exp_var, numpy.var)


def test_dynamic_normalizer():
    """DynamicNormalizer output ends with running mean ~0 and variance ~1."""
    mode = theano.compile.mode.get_default_mode()
    if isinstance(mode, DebugMode):
        # NOTE(review): DebugMode is swapped for FAST_RUN here, presumably
        # because it is too slow/strict for this stateful graph -- confirm.
        mode = 'FAST_RUN'

    x = theano.tensor.dvector()
    rows_to_test = 100
    cols = 2

    D = exponential_mean.DynamicNormalizer(x, (cols,), rows_to_test)

    M = theano.Module()
    M.dn = D
    # Running statistics of the normalized output...
    M.dn_mean = exponential_mean.exp_mean(D.output, (cols,), 50)
    M.dn_var = exponential_mean.exp_var(D.output, (cols,), 50)
    # ...and of the raw input (returned by M.f but not asserted on below).
    M.x_mean = exponential_mean.exp_mean(x, (cols,), 10)

    # Merge every component's state updates into one dict so a single call
    # to M.f advances all of them together.
    updates = D.updates()
    updates.update(M.dn_mean.updates())
    updates.update(M.dn_var.updates())
    updates.update(M.x_mean.updates())

    M.f = theano.Method(
        [x],
        [D.output, M.dn_mean.curval, M.dn_var.curval, M.x_mean.curval],
        updates)

    m = M.make(mode=mode)
    m.dn.initialize()
    m.dn_mean.initialize()
    m.dn_var.initialize()
    m.x_mean.initialize()

    rng = numpy.random.RandomState(3284)
    xval = rng.rand(rows_to_test + 100, cols)
    for i, xrow in enumerate(xval):
        n_x = m.f(xrow)

    # Checked on the final state only: after a single sample the running
    # variance is 0, so these bounds cannot hold from the first row.
    assert numpy.all(numpy.abs(n_x[1]) < 0.15)  # the means should be close to 0
    assert numpy.all(numpy.abs(n_x[2] - 1) < 0.07)  # the variance should be close to 1.0
    assert i > rows_to_test