comparison _test_nnet_ops.py @ 419:43d9aa93934e

added other_ops.py to nnet_ops; added basic tests, no docs.
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 14 Jul 2008 16:48:02 -0400
parents 2ee53bae9ee0
children 9cfc2fc0f4d1
comparison
equal deleted inserted replaced
418:2ea14774eb07 419:43d9aa93934e
1 1
2 import unittest 2 import unittest
3 import theano
3 import theano._test_tensor as TT 4 import theano._test_tensor as TT
4 import numpy 5 import numpy
5 6
6 from nnet_ops import * 7 from nnet_ops import *
7 8
33 class Dummy(object): 34 class Dummy(object):
34 def make_node(self, a): 35 def make_node(self, a):
35 return crossentropy_softmax_1hot(a, y_idx)[0:1] 36 return crossentropy_softmax_1hot(a, y_idx)[0:1]
36 TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)]) 37 TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)])
37 38
class T_prepend_constant(unittest.TestCase):
    """Tests for ``Prepend_scalar_constant_to_each_row``.

    This class must not be named ``T_prepend``: a later class in this module
    already uses that name.  A second ``class`` statement with the same name
    rebinds the module attribute, so unittest would only discover the later
    class and these tests would silently never run.
    """

    def test0(self):
        """basic functionality"""
        x = tensor.matrix('x')
        # The constant is passed to the op's constructor; only x is a
        # symbolic input of the resulting graph.
        y = Prepend_scalar_constant_to_each_row(4.)(x)
        f = theano.function([x], [y])
        m = numpy.random.rand(3, 5)
        my = f(m)
        # Expect exactly one extra column on each of the 3 rows.
        self.failUnless(my.shape == (3, 6), my.shape)
        # The prepended first column holds the constant.
        self.failUnless(numpy.all(my[:, 0] == 4.0))
49
50
class T_prepend(unittest.TestCase):
    """Tests for ``Prepend_scalar_to_each_row``, whose scalar is supplied
    as a call-time input rather than baked into the op."""

    def test0(self):
        """basic functionality"""
        mat = tensor.matrix('x')
        prepended = Prepend_scalar_to_each_row()(5., mat)
        compiled = theano.function([mat], [prepended])
        result = compiled(numpy.ones((3, 5), dtype="float32"))
        # The input is float32, yet the test expects a float64 result.
        self.failUnless(str(result.dtype) == 'float64')
        # One column added to each of the 3 rows ...
        self.failUnless(result.shape == (3, 6))
        # ... and that leading column carries the scalar.
        self.failUnless(numpy.all(result[:, 0] == 5.0))
62
class T_solve(unittest.TestCase):
    """Sanity check of ``numpy.linalg.solve`` on a fixed-seed 5x5 system.

    NOTE(review): this test exercises only numpy (solve/dot) plus theano's
    ``abs_rel_err`` helper; it does not build or run any op from nnet_ops.
    """

    def setUp(self):
        # Fixed seed so the generated system is identical on every run.
        self.rng = numpy.random.RandomState(666)

    def test0(self):
        A = self.rng.randn(5, 5)
        b = numpy.arange(5, dtype=float)  # [0., 1., 2., 3., 4.]
        x = numpy.linalg.solve(A, b)
        Ax = numpy.dot(A, x)
        # Presumably an element-wise absolute/relative error between A.x
        # and b; every entry must fall under the tolerance.
        are = theano.gradient.numeric_grad.abs_rel_err(Ax, b)
        self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b))
38 76
39 77
if __name__ == '__main__':
    # Executed as a script: hand control to unittest's command-line runner,
    # which collects the TestCase classes defined in this module.
    unittest.main()