Mercurial > pylearn
diff nnet_ops.py @ 181:1b06bc2c3ca9
fixed c_code for the ops in nnet_ops.py
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Tue, 13 May 2008 15:49:39 -0400 |
parents | 2ca8dccba270 |
children | 9a2aecc57a79 |
line wrap: on
line diff
--- a/nnet_ops.py Tue May 13 15:35:43 2008 -0400 +++ b/nnet_ops.py Tue May 13 15:49:39 2008 -0400 @@ -20,15 +20,15 @@ def grad(self, (x,), (gz,)): y = scalar_sigmoid(x) return [gz * y * (1.0 - y)] - def c_code(self, (x,), (z,), sub): - if 'float' in self.inputs[0].dtype: + def c_code(self, node, name, (x,), (z,), sub): + if node.inputs[0].type in [scalar.float32, scalar.float64]: return """%(z)s = %(x)s < -30.0 ? 0.0 : %(x)s > 30.0 ? 1.0 : 1.0 /(1.0+exp(-%(x)s));""" % locals() - return NotImplemented#Error('only floatingpoint is implemented') + raise NotImplementedError('only floatingpoint is implemented') scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid') sigmoid = tensor.Elemwise(scalar_sigmoid, name='sigmoid') @@ -44,15 +44,15 @@ return ScalarSoftplus.static_impl(x) def grad(self, (x,), (gz,)): return [gz * scalar_sigmoid(x)] - def c_code(self, (x,), (z,), sub): - if 'float' in self.inputs[0].dtype: + def c_code(self, name, node, (x,), (z,), sub): + if node.inputs[0].type in [scalar.float32, scalar.float64]: return """%(z)s = %(x)s < -30.0 ? 0.0 : %(x)s > 30.0 ? %(x)s : log1p(exp(%(x)s));""" % locals() - return NotImplemented#Error('only floating point x is implemented') + raise NotImplementedError('only floating point x is implemented') scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus') softplus = tensor.Elemwise(scalar_softplus, name='softplus') @@ -135,7 +135,7 @@ return dx, db, None def c_headers(self): return ['<iostream>'] - def c_code(self, (x, b, y_idx), (nll, sm), sub): + def c_code(self, node, name, (x, b, y_idx), (nll, sm), sub): # this implementation was lifted from # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx @@ -302,7 +302,7 @@ output_storage[0][0] = dx def grad(self, *args): raise NotImplementedError() - def c_code(self, (dnll, sm, y_idx), (dx,), sub): + def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): return """ if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)