changeset 450:117e5b09cf31

Added an XlogX op.
author Joseph Turian <turian@gmail.com>
date Thu, 04 Sep 2008 14:46:17 -0400
parents 2bb67e978c28
children d99fefbc9324
files _test_xlogx.py xlogx.py
diffstat 2 files changed, 55 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/_test_xlogx.py	Thu Sep 04 14:46:17 2008 -0400
@@ -0,0 +1,27 @@
+from xlogx import xlogx
+
+import unittest
+from theano import compile
+from theano import gradient
+
+from theano.tensor import as_tensor
+import theano._test_tensor as TT
+
+import random
+import numpy.random
+
+class T_XlogX(unittest.TestCase):
+    def test0(self):
+        x = as_tensor([1, 0])
+        y = xlogx(x)
+        y = compile.eval_outputs([y])
+        self.failUnless(numpy.all(y == numpy.asarray([0, 0.])))
+    def test1(self):
+        class Dummy(object):
+            def make_node(self, a):
+                return [xlogx(a)[:,2]]
+        TT.verify_grad(self, Dummy(), [numpy.random.rand(3,4)])
+
+
+if __name__ == '__main__':
+    unittest.main()
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/xlogx.py	Thu Sep 04 14:46:17 2008 -0400
@@ -0,0 +1,28 @@
+
+import theano
+from theano import tensor, scalar
+import numpy
+
+class XlogX(scalar.UnaryScalarOp):
+    """
+    Compute X * log(X), with special case 0 log(0) = 0.
+    """
+    @staticmethod
+    def st_impl(x):
+        if x == 0.0:
+            return 0.0
+        return x * numpy.log(x)
+    def impl(self, x):
+        return XlogX.st_impl(x)
+    def grad(self, (x,), (gz,)):
+        return [gz * (1 + scalar.log(x))]
+    def c_code(self, node, name, (x,), (z,), sub):
+        if node.inputs[0].type in [scalar.float32, scalar.float64]:
+            return """%(z)s =
+                %(x)s == 0.0
+                ? 0.0
+                : %(x)s * log(%(x)s);""" % locals()
+        raise NotImplementedError('only floatingpoint is implemented')
+scalar_xlogx  = XlogX(scalar.upgrade_to_float, name='scalar_xlogx')
+xlogx = tensor.Elemwise(scalar_xlogx, name='xlogx')
+