# HG changeset patch # User James Bergstra # Date 1216068482 14400 # Node ID 43d9aa93934ed15d00690e2cc0331e244df25774 # Parent 2ea14774eb075f7f7ab0bf1f42d4f8f5ea69929d added other_ops.py to nnet_ops; added basic tests, no docs. diff -r 2ea14774eb07 -r 43d9aa93934e _test_nnet_ops.py --- a/_test_nnet_ops.py Mon Jul 14 13:48:41 2008 -0400 +++ b/_test_nnet_ops.py Mon Jul 14 16:48:02 2008 -0400 @@ -1,5 +1,6 @@ import unittest +import theano import theano._test_tensor as TT import numpy @@ -35,6 +36,43 @@ return crossentropy_softmax_1hot(a, y_idx)[0:1] TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)]) +class T_prepend(unittest.TestCase): + def test0(self): + """basic functionality""" + x=tensor.matrix('x') + y=Prepend_scalar_constant_to_each_row(4.)(x) + f=theano.function([x],[y]) + m=numpy.random.rand(3,5) + my = f(m) + self.failUnless(my.shape == (3, 6), my.shape) + self.failUnless(numpy.all( my[:,0] == 4.0)) + + +class T_prepend(unittest.TestCase): + def test0(self): + """basic functionality""" + x=tensor.matrix('x') + y=Prepend_scalar_to_each_row()(5.,x) + f=theano.function([x],[y]) + m=numpy.ones((3,5),dtype="float32") + my = f(m) + self.failUnless(str(my.dtype) == 'float64') + self.failUnless(my.shape == (3, 6)) + self.failUnless(numpy.all(my[:,0] == 5.0)) + +class T_solve(unittest.TestCase): + def setUp(self): + self.rng = numpy.random.RandomState(666) + + def test0(self): + A=self.rng.randn(5,5) + b=numpy.array(range(5),dtype=float) + x=numpy.linalg.solve(A,b) + Ax = numpy.dot(A,x) + are = theano.gradient.numeric_grad.abs_rel_err(Ax, b) + self.failUnless(numpy.all(are < 1.0e-5), (are, Ax, b)) + #print A,b + #print numpy.dot(A,x) if __name__ == '__main__': diff -r 2ea14774eb07 -r 43d9aa93934e linear_regression.py --- a/linear_regression.py Mon Jul 14 13:48:41 2008 -0400 +++ b/linear_regression.py Mon Jul 14 16:48:02 2008 -0400 @@ -6,7 +6,7 @@ from pylearn.learner import OfflineLearningAlgorithm from theano import tensor as T -from theano.sandbox.others_ops import prepend_1_to_each_row +from nnet_ops import prepend_1_to_each_row from theano.scalar import as_scalar from common.autoname import AutoName import theano diff -r 2ea14774eb07 -r 43d9aa93934e nnet_ops.py --- a/nnet_ops.py Mon Jul 14 13:48:41 2008 -0400 +++ b/nnet_ops.py Mon Jul 14 16:48:02 2008 -0400 @@ -1,3 +1,6 @@ +## This file contain ops that are not currently integrated in the core of threano. +## Not all of those ops have been thoroughly tested. + import theano from theano import tensor, scalar import numpy @@ -387,3 +390,104 @@ @todo: Rewrite as a scalar, and then broadcast to tensor. """ return -(target * tensor.log(output) + (1 - target) * tensor.log(1 - output)) + + + +class Prepend_scalar_constant_to_each_row(theano.Op): + def __init__(self, val = 0): + if isinstance(val, float): + val = scalar.constant(val) + self.val = val + + def make_node(self, mat): + #check type of input + if not isinstance(mat,theano.Result) or not mat.type==tensor.matrix().type: + raise TypeError("Expected a matrix as input") + x = tensor.as_tensor(mat) + y = tensor.as_tensor(self.val) + if x.type.dtype != y.type.dtype: + TypeError("the value to prepend don't have the same type as the matrix") + + node = theano.Apply(op=self, inputs=[mat], outputs=[tensor.matrix()]) + return node + + def perform(self, node, (mat, ), (output, )): + new_shape=(mat.shape[0],mat.shape[1]+1) + if output[0] == None: + output[0]=numpy.empty(new_shape,dtype=mat.dtype) + out=output[0] + else: + if output[0].shape!=new_shape: + try: + output[0].resize(new_shape) + except: + output[0]=numpy.empty(new_shape, dtype=mat.dtype) + out=output[0] + + out[:,0].fill(self.val.data) + out[:,1:]=mat + + def grad(self, (mat,), (goutput,)): + return goutput[:,1:] + +class Prepend_scalar_to_each_row(theano.Op): + def make_node(self, val, mat): + #check type of input + if isinstance(val, float): + val = scalar.constant(val) + if not isinstance(mat,theano.Result) or not mat.type==tensor.matrix().type: + raise TypeError("Expected a matrix as input") + x = tensor.as_tensor(mat) + y = tensor.as_tensor(val) + if x.type.dtype != y.type.dtype: + TypeError("the value to prepend don't have the same type as the matrix") + + node = theano.Apply(op=self, inputs=[val,mat], outputs=[tensor.matrix()]) + return node + + def perform(self, node, (val,mat), (output, )): + new_shape=(mat.shape[0],mat.shape[1]+1) + if output[0] == None: + output[0]=numpy.empty(new_shape,dtype=mat.dtype) + out=output[0] + else: + if output[0].shape!=new_shape: + try: + output[0].resize(new_shape) + except: + output[0]=numpy.empty(new_shape, dtype=mat.dtype) + out=output[0] + out[:,0].fill(val) + out[:,1:]=mat + + def grad(self, (val, mat), (goutput,)): + return goutput[:,0], goutput[:,1:] + +prepend_scalar_to_each_row = Prepend_scalar_to_each_row() +prepend_0_to_each_row = Prepend_scalar_constant_to_each_row(0.) +prepend_1_to_each_row = Prepend_scalar_constant_to_each_row(1.) + +class solve(theano.Op): + """ + Find the solution to the linear equation Ax=b, + where A is a 2d matrix and b is a 1d or 2d matrix. + It use numpy.solve to find the solution. + """ + + def make_node(self, A, b): + if not isinstance(A, theano.Result) or not A.type==tensor.matrix().type: + raise TypeError("We expected that A had a matrix type") + if not isinstance(B, theano.Result) or not B.type==tensor.matrix().type: + raise TypeError("We expected that B had a matrix type") + + node = theano.Apply(op=self, inputs=[A, B], outputs=[tensor.matrix()]) + return node + + def perform(self, node, (A, B), (output, )): + ret=numpy.solve(A,B) + output[0]=ret + + def grad(self, (theta, A, B), (gtheta,)): + raise NotImplementedError() + +