Mercurial > pylearn
changeset 664:576eb7f77c35
Trying to fix gradient in logfactorial
author | Joseph Turian <turian@gmail.com> |
---|---|
date | Wed, 11 Mar 2009 00:09:22 -0400 |
parents | b282a5c2f76b |
children | 070a7d68d3a1 |
files | pylearn/algorithms/sandbox/cost.py pylearn/algorithms/sandbox/test_cost.py |
diffstat | 2 files changed, 24 insertions(+), 1 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/algorithms/sandbox/cost.py Tue Mar 10 20:58:16 2009 -0400 +++ b/pylearn/algorithms/sandbox/cost.py Wed Mar 11 00:09:22 2009 -0400 @@ -9,6 +9,18 @@ from theano import tensor, scalar import numpy +class UndefinedGradient(Exception): + """ + Raised by UndefinedGradientOp to indicate that the gradient is undefined mathematically. + """ + pass +from theano import gof +class UndefinedGradientOp(gof.Op): + def perform(self, x=None): + if x is not None: raise UndefinedGradient(x) + else: raise UndefinedGradient(x) +undefined_gradient = UndefinedGradientOp() + class LogFactorial(scalar.UnaryScalarOp): """ Compute log x!. @@ -28,6 +40,8 @@ return v def impl(self, x): return LogFactorial.st_impl(x) + def grad(self, (x,), (gz,)): + undefined_gradient(self) # def grad(self, (x,), (gz,)): # raise NotImplementedError('gradient not defined over discrete values') # return None
--- a/pylearn/algorithms/sandbox/test_cost.py Tue Mar 10 20:58:16 2009 -0400 +++ b/pylearn/algorithms/sandbox/test_cost.py Wed Mar 11 00:09:22 2009 -0400 @@ -38,7 +38,16 @@ # (goutput) = TT.grad(loss, [target]) f = T.function([], goutput) print f() - self.failUnless(f() - 33751.7816277 < 1e-5) + self.failUnless(numpy.all(f() - numpy.asarray([206., 559.96605666, 558.96605666, 205., 557.96605666, 204., 30473.11077513, 459.96605666] < 1e-5))) + + def test_gradient_fail(self): + target = TT.as_tensor([0, 0, 1, 1, 2, 2, 100, 100]) + output = TT.as_tensor([0., 1, 1., 0, 1, 0, 5, 1]) + loss = cost.nlpoisson(target, output) + (goutput) = TT.grad(loss, [target]) + f = T.function([], goutput) + print f() + self.failUnless(numpy.all(f() - numpy.asarray([206., 559.96605666, 558.96605666, 205., 557.96605666, 204., 30473.11077513, 459.96605666] < 1e-5))) if __name__ == '__main__': unittest.main()