# HG changeset patch # User Joseph Turian # Date 1236744562 14400 # Node ID 576eb7f77c35f249278b950fb5a8bb7884aa39dc # Parent b282a5c2f76b6d4f15796203aa8ae9d9b9bdcdd6 Trying to fix gradient in logfactorial diff -r b282a5c2f76b -r 576eb7f77c35 pylearn/algorithms/sandbox/cost.py --- a/pylearn/algorithms/sandbox/cost.py Tue Mar 10 20:58:16 2009 -0400 +++ b/pylearn/algorithms/sandbox/cost.py Wed Mar 11 00:09:22 2009 -0400 @@ -9,6 +9,18 @@ from theano import tensor, scalar import numpy +class UndefinedGradient(Exception): + """ + Raised by UndefinedGradientOp to indicate that the gradient is undefined mathematically. + """ + pass +from theano import gof +class UndefinedGradientOp(gof.Op): + def perform(self, x=None): + if x is not None: raise UndefinedGradient(x) + else: raise UndefinedGradient(x) +undefined_gradient = UndefinedGradientOp() + class LogFactorial(scalar.UnaryScalarOp): """ Compute log x!. @@ -28,6 +40,8 @@ return v def impl(self, x): return LogFactorial.st_impl(x) + def grad(self, (x,), (gz,)): + undefined_gradient(self) # def grad(self, (x,), (gz,)): # raise NotImplementedError('gradient not defined over discrete values') # return None diff -r b282a5c2f76b -r 576eb7f77c35 pylearn/algorithms/sandbox/test_cost.py --- a/pylearn/algorithms/sandbox/test_cost.py Tue Mar 10 20:58:16 2009 -0400 +++ b/pylearn/algorithms/sandbox/test_cost.py Wed Mar 11 00:09:22 2009 -0400 @@ -38,7 +38,16 @@ # (goutput) = TT.grad(loss, [target]) f = T.function([], goutput) print f() - self.failUnless(f() - 33751.7816277 < 1e-5) + self.failUnless(numpy.all(f() - numpy.asarray([206., 559.96605666, 558.96605666, 205., 557.96605666, 204., 30473.11077513, 459.96605666] < 1e-5))) + + def test_gradient_fail(self): + target = TT.as_tensor([0, 0, 1, 1, 2, 2, 100, 100]) + output = TT.as_tensor([0., 1, 1., 0, 1, 0, 5, 1]) + loss = cost.nlpoisson(target, output) + (goutput) = TT.grad(loss, [target]) + f = T.function([], goutput) + print f() + self.failUnless(numpy.all(f() - numpy.asarray([206., 559.96605666, 558.96605666, 205., 557.96605666, 204., 30473.11077513, 459.96605666] < 1e-5))) if __name__ == '__main__': unittest.main()