Mercurial > pylearn
diff xlogx.py @ 454:6e7509acb1c0
Merged
author | delallea@valhalla.apstat.com |
---|---|
date | Thu, 02 Oct 2008 13:41:43 -0400 |
parents | 117e5b09cf31 |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/xlogx.py Thu Oct 02 13:41:43 2008 -0400 @@ -0,0 +1,28 @@ + +import theano +from theano import tensor, scalar +import numpy + +class XlogX(scalar.UnaryScalarOp): + """ + Compute X * log(X), with special case 0 log(0) = 0. + """ + @staticmethod + def st_impl(x): + if x == 0.0: + return 0.0 + return x * numpy.log(x) + def impl(self, x): + return XlogX.st_impl(x) + def grad(self, (x,), (gz,)): + return [gz * (1 + scalar.log(x))] + def c_code(self, node, name, (x,), (z,), sub): + if node.inputs[0].type in [scalar.float32, scalar.float64]: + return """%(z)s = + %(x)s == 0.0 + ? 0.0 + : %(x)s * log(%(x)s);""" % locals() + raise NotImplementedError('only floatingpoint is implemented') +scalar_xlogx = XlogX(scalar.upgrade_to_float, name='scalar_xlogx') +xlogx = tensor.Elemwise(scalar_xlogx, name='xlogx') +