Mercurial > pylearn
view pylearn/algorithms/sandbox/test_cost.py @ 1506:2f69c9932d9a
Fix test in float32.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Mon, 12 Sep 2011 10:56:38 -0400 |
parents | ba068f7d4d3e |
children |
line wrap: on
line source
"""Unit tests for pylearn.algorithms.sandbox.cost (logfactorial, nlpoisson)."""
import pylearn.algorithms.sandbox.cost as cost
import unittest
import theano as T
import theano.tensor as TT
import numpy


class T_logfactorial(unittest.TestCase):
    """Tests for cost.logfactorial."""

    def test(self):
        """logfactorial(n) should equal ln(n!) for n = 0..9."""
        x = TT.as_tensor(list(range(10)))
        o = cost.logfactorial(x)
        f = T.function([], o)
        actual = f()
        # Reference ln(n!) = sum_{k=1..n} ln(k); the empty sum for n = 0 is 0.
        # NOTE(review): the previous hard-coded table held n*ln(n)
        # (e.g. 1.38629436 == 2*ln(2), while ln(2!) == 0.69314718).
        expected = numpy.asarray(
            [numpy.log(numpy.arange(1, n + 1)).sum() for n in range(10)])
        # allclose also applies a relative tolerance, which keeps the check
        # meaningful if the graph is evaluated in float32.
        self.assertTrue(numpy.allclose(actual, expected, atol=1e-5),
                        "%s != %s" % (actual, expected))

    def test_float(self):
        """Ensure we cannot use floats in logfactorial."""
        x = TT.as_tensor([0.5, 2.7])
        o = cost.logfactorial(x)
        f = T.function([], o)
        try:
            f()
        except TypeError as e:
            # Only the specific "must be int or long" rejection counts as a
            # pass; any other TypeError is a genuine failure.
            if "<type 'float'>, must be int or long" not in str(e):
                raise
        else:
            # self.fail (not `assert False`) so the check survives `python -O`.
            self.fail("logfactorial accepted float input")


class T_nlpoisson(unittest.TestCase):
    """Tests for cost.nlpoisson and its gradient."""

    @staticmethod
    def _inputs():
        """Return the (target, output) constant tensors shared by these tests."""
        target = TT.as_tensor([0, 0, 1, 1, 2, 2, 100, 100])
        output = TT.as_tensor([0., 1, 1., 0, 1, 0, 5, 1])
        return target, output

    def test(self):
        """nlpoisson on the fixed inputs matches the reference loss value."""
        target, output = self._inputs()
        o = cost.nlpoisson(target, output)
        f = T.function([], o)
        actual = f()
        # Two-sided comparison with a relative tolerance: an absolute 1e-5 is
        # below float32 resolution for a value of this magnitude (~3.4e4).
        self.assertTrue(numpy.allclose(actual, 33751.7816277, atol=1e-5),
                        "%s != 33751.7816277" % (actual,))

    def test_gradient(self):
        """Gradient of nlpoisson w.r.t. output matches reference values."""
        target, output = self._inputs()
        loss = cost.nlpoisson(target, output)
        # grad is given a one-element list of wrt variables; unpack its result.
        (goutput,) = TT.grad(loss, [output])
        f = T.function([], goutput)
        actual = f()
        expected = numpy.asarray([206., 559.96605666, 558.96605666, 205.,
                                  557.96605666, 204., 30473.11077513,
                                  459.96605666])
        self.assertTrue(numpy.allclose(actual, expected, atol=1e-5),
                        "%s != %s" % (actual, expected))

    def test_gradient_fail(self):
        """Gradient of nlpoisson w.r.t. the (integer) target."""
        # NOTE(review): intent unclear -- presumably this documents that the
        # gradient w.r.t. the integer target is not usable; TODO confirm.
        # The earlier assertion reused test_gradient's reference values, but
        # its `< 1e-5` was applied to the Python list inside numpy.asarray,
        # so what it actually checked was that every gradient entry is
        # nonzero. That effective check is kept here unchanged.
        target, output = self._inputs()
        loss = cost.nlpoisson(target, output)
        goutput = TT.grad(loss, [target])
        f = T.function([], goutput)
        self.assertTrue(numpy.all(f()))


if __name__ == '__main__':
    unittest.main()