diff xlogx.py @ 454:6e7509acb1c0

Merged
author delallea@valhalla.apstat.com
date Thu, 02 Oct 2008 13:41:43 -0400
parents 117e5b09cf31
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/xlogx.py	Thu Oct 02 13:41:43 2008 -0400
@@ -0,0 +1,28 @@
+
+import theano
+from theano import tensor, scalar
+import numpy
+
+class XlogX(scalar.UnaryScalarOp):
+    """
+    Compute X * log(X), with special case 0 log(0) = 0.
+    """
+    @staticmethod
+    def st_impl(x):
+        if x == 0.0:
+            return 0.0
+        return x * numpy.log(x)
+    def impl(self, x):
+        return XlogX.st_impl(x)
+    def grad(self, (x,), (gz,)):
+        return [gz * (1 + scalar.log(x))]
+    def c_code(self, node, name, (x,), (z,), sub):
+        if node.inputs[0].type in [scalar.float32, scalar.float64]:
+            return """%(z)s =
+                %(x)s == 0.0
+                ? 0.0
+                : %(x)s * log(%(x)s);""" % locals()
+        raise NotImplementedError('only floatingpoint is implemented')
+scalar_xlogx  = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
+xlogx = tensor.Elemwise(scalar_xlogx, name='xlogx')
+