Mercurial > pylearn
comparison xlogx.py @ 454:6e7509acb1c0
Merged
author | delallea@valhalla.apstat.com |
---|---|
date | Thu, 02 Oct 2008 13:41:43 -0400 |
parents | 117e5b09cf31 |
children |
comparison
equal
deleted
inserted
replaced
453:ce6b4fd3ab29 | 454:6e7509acb1c0 |
---|---|
1 | |
2 import theano | |
3 from theano import tensor, scalar | |
4 import numpy | |
5 | |
class XlogX(scalar.UnaryScalarOp):
    """
    Compute X * log(X), with special case 0 log(0) = 0.

    The special case matches the limit of x * log(x) as x -> 0+. Negative
    inputs are not special-cased and yield nan (log of a negative number).
    """

    @staticmethod
    def st_impl(x):
        """Return x * log(x) for a single scalar value, with 0 log(0) = 0."""
        if x == 0.0:
            # log(0) is -inf, and 0 * -inf would be nan; use the limit value.
            return 0.0
        return x * numpy.log(x)

    def impl(self, x):
        """Python implementation of the op; delegates to st_impl."""
        return XlogX.st_impl(x)

    def grad(self, inputs, output_gradients):
        """Return [gz * (1 + log(x))], the gradient of x * log(x) w.r.t. x.

        `inputs` is the one-element sequence (x,) and `output_gradients` is
        the one-element sequence (gz,). They are unpacked here rather than
        in the signature because tuple parameters are a SyntaxError in
        Python 3 (PEP 3113). Note that the gradient is not special-cased at
        x == 0, where log(x) is -inf.
        """
        (x,) = inputs
        (gz,) = output_gradients
        return [gz * (1 + scalar.log(x))]

    def c_code(self, node, name, inputs, outputs, sub):
        """Return a C statement computing z = x * log(x), with 0 log(0) = 0.

        `inputs` is the one-element sequence (x,) and `outputs` is the
        one-element sequence (z,), holding the C variable names to use.
        Only float32 / float64 inputs are supported; any other input type
        raises NotImplementedError.
        """
        (x,) = inputs
        (z,) = outputs
        if node.inputs[0].type in [scalar.float32, scalar.float64]:
            # Mirror st_impl: the ternary avoids evaluating 0 * log(0) = nan.
            return """%(z)s =
                %(x)s == 0.0
                ? 0.0
                : %(x)s * log(%(x)s);""" % {'x': x, 'z': z}
        raise NotImplementedError('only floatingpoint is implemented')
# Scalar-level instance of the op. The output-type rule passed here is
# scalar.upgrade_to_float (presumably promoting integer inputs to a float
# output type -- confirm against Theano's scalar module).
scalar_xlogx = XlogX(
    scalar.upgrade_to_float,
    name='scalar_xlogx',
)

# Elementwise (tensor) version of scalar_xlogx: applies x * log(x), with
# 0 log(0) = 0, to every element of its tensor input.
xlogx = tensor.Elemwise(
    scalar_xlogx,
    name='xlogx',
)
28 |