comparison xlogx.py @ 454:6e7509acb1c0

Merged
author delallea@valhalla.apstat.com
date Thu, 02 Oct 2008 13:41:43 -0400
parents 117e5b09cf31
children
comparison
equal deleted inserted replaced
453:ce6b4fd3ab29 454:6e7509acb1c0
1
2 import theano
3 from theano import tensor, scalar
4 import numpy
5
class XlogX(scalar.UnaryScalarOp):
    """
    Compute X * log(X), with special case 0 log(0) = 0.

    The special case gives the continuous extension of x*log(x) at 0 (its
    limit as x -> 0+). Without it, the naive product would be 0 * -inf, which
    is nan under IEEE arithmetic.
    """

    @staticmethod
    def st_impl(x):
        """Return x * log(x) for a single value, with 0 * log(0) defined as 0."""
        if x == 0.0:
            return 0.0
        return x * numpy.log(x)

    def impl(self, x):
        """Python implementation of the op; delegates to ``st_impl``."""
        return XlogX.st_impl(x)

    def grad(self, inputs, output_gradients):
        """Return the symbolic gradient d/dx [x log x] = 1 + log(x).

        :param inputs: one-element sequence ``(x,)``.
        :param output_gradients: one-element sequence ``(gz,)``, the gradient
            flowing back into this op's output.
        :returns: one-element list, ``gz * (1 + log(x))``. Note that this is
            -inf at x == 0; the 0 log 0 special case only applies to the value.
        """
        # Explicit unpacking replaces the tuple parameters ``(x,), (gz,)``,
        # which are a SyntaxError in Python 3 (PEP 3113). Positional callers
        # are unaffected.
        x, = inputs
        gz, = output_gradients
        return [gz * (1 + scalar.log(x))]

    def c_code(self, node, name, inputs, outputs, sub):
        """Return C code assigning x * log(x) (0 when x == 0) to the output.

        :param inputs: one-element sequence with the C name of the input.
        :param outputs: one-element sequence with the C name of the output.
        :raises NotImplementedError: if the input type is not float32/float64.
        """
        x, = inputs
        z, = outputs
        if node.inputs[0].type in [scalar.float32, scalar.float64]:
            # Explicit mapping instead of ``% locals()`` so the template's
            # dependencies are visible and unaffected by other locals.
            return """%(z)s =
                %(x)s == 0.0
                ? 0.0
                : %(x)s * log(%(x)s);""" % dict(x=x, z=z)
        raise NotImplementedError('only floatingpoint is implemented')
# Scalar-level instance of the op. ``scalar.upgrade_to_float`` is passed as
# the first constructor argument; presumably it is the output-type preference
# that upcasts non-float inputs to float (confirm in theano.scalar).
scalar_xlogx = XlogX(
    scalar.upgrade_to_float,
    name='scalar_xlogx',
)

# Tensor-level version: the scalar op wrapped by ``tensor.Elemwise`` so it can
# be applied to tensors element by element.
xlogx = tensor.Elemwise(
    scalar_xlogx,
    name='xlogx',
)
28