Mercurial > pylearn
comparison _test_nnet_ops.py @ 435:eac0a7d44ff0
merge
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Mon, 04 Aug 2008 16:29:30 -0400 |
parents | 43d9aa93934e |
children | 9cfc2fc0f4d1 |
comparison
equal
deleted
inserted
replaced
434:0f366ecb11ee | 435:eac0a7d44ff0 |
---|---|
1 | 1 |
2 import unittest | 2 import unittest |
3 import theano | |
3 import theano._test_tensor as TT | 4 import theano._test_tensor as TT |
4 import numpy | 5 import numpy |
5 | 6 |
6 from nnet_ops import * | 7 from nnet_ops import * |
7 | 8 |
33 class Dummy(object): | 34 class Dummy(object): |
34 def make_node(self, a): | 35 def make_node(self, a): |
35 return crossentropy_softmax_1hot(a, y_idx)[0:1] | 36 return crossentropy_softmax_1hot(a, y_idx)[0:1] |
36 TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)]) | 37 TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)]) |
37 | 38 |
39 class T_prepend(unittest.TestCase): | |
40 def test0(self): | |
41 """basic functionality""" | |
42 x=tensor.matrix('x') | |
43 y=Prepend_scalar_constant_to_each_row(4.)(x) | |
44 f=theano.function([x],[y]) | |
45 m=numpy.random.rand(3,5) | |
46 my = f(m) | |
47 self.failUnless(my.shape == (3, 6), my.shape) | |
48 self.failUnless(numpy.all( my[:,0] == 4.0)) | |
49 | |
50 | |
51 class T_prepend(unittest.TestCase): | |
52 def test0(self): | |
53 """basic functionality""" | |
54 x=tensor.matrix('x') | |
55 y=Prepend_scalar_to_each_row()(5.,x) | |
56 f=theano.function([x],[y]) | |
57 m=numpy.ones((3,5),dtype="float32") | |
58 my = f(m) | |
59 self.failUnless(str(my.dtype) == 'float64') | |
60 self.failUnless(my.shape == (3, 6)) | |
61 self.failUnless(numpy.all(my[:,0] == 5.0)) | |
62 | |
63 class T_solve(unittest.TestCase): | |
64 def setUp(self): | |
65 self.rng = numpy.random.RandomState(666) | |
66 | |
67 def test0(self): | |
68 A=self.rng.randn(5,5) | |
69 b=numpy.array(range(5),dtype=float) | |
70 x=numpy.linalg.solve(A,b) | |
71 Ax = numpy.dot(A,x) | |
72 are = theano.gradient.numeric_grad.abs_rel_err(Ax, b) | |
73 self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b)) | |
74 #print A,b | |
75 #print numpy.dot(A,x) | |
38 | 76 |
39 | 77 |
40 if __name__ == '__main__': | 78 if __name__ == '__main__': |
41 unittest.main() | 79 unittest.main() |