# HG changeset patch # User Olivier Breuleux # Date 1210708179 14400 # Node ID 1b06bc2c3ca979ca47229ed5ca94a40bbbf71ae3 # Parent 2698c0feeb54e7d77505ffc1b4b2f61338ed7d3f fixed c_code for the ops in nnet_ops.py diff -r 2698c0feeb54 -r 1b06bc2c3ca9 nnet_ops.py --- a/nnet_ops.py Tue May 13 15:35:43 2008 -0400 +++ b/nnet_ops.py Tue May 13 15:49:39 2008 -0400 @@ -20,15 +20,15 @@ def grad(self, (x,), (gz,)): y = scalar_sigmoid(x) return [gz * y * (1.0 - y)] - def c_code(self, (x,), (z,), sub): - if 'float' in self.inputs[0].dtype: + def c_code(self, node, name, (x,), (z,), sub): + if node.inputs[0].type in [scalar.float32, scalar.float64]: return """%(z)s = %(x)s < -30.0 ? 0.0 : %(x)s > 30.0 ? 1.0 : 1.0 /(1.0+exp(-%(x)s));""" % locals() - return NotImplemented#Error('only floatingpoint is implemented') + raise NotImplementedError('only floatingpoint is implemented') scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid') sigmoid = tensor.Elemwise(scalar_sigmoid, name='sigmoid') @@ -44,15 +44,15 @@ return ScalarSoftplus.static_impl(x) def grad(self, (x,), (gz,)): return [gz * scalar_sigmoid(x)] - def c_code(self, (x,), (z,), sub): - if 'float' in self.inputs[0].dtype: + def c_code(self, name, node, (x,), (z,), sub): + if node.inputs[0].type in [scalar.float32, scalar.float64]: return """%(z)s = %(x)s < -30.0 ? 0.0 : %(x)s > 30.0 ? %(x)s : log1p(exp(%(x)s));""" % locals() - return NotImplemented#Error('only floating point x is implemented') + raise NotImplementedError('only floating point x is implemented') scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus') softplus = tensor.Elemwise(scalar_softplus, name='softplus') @@ -135,7 +135,7 @@ return dx, db, None def c_headers(self): return [''] - def c_code(self, (x, b, y_idx), (nll, sm), sub): + def c_code(self, node, name, (x, b, y_idx), (nll, sm), sub): # this implementation was lifted from # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx @@ -302,7 +302,7 @@ output_storage[0][0] = dx def grad(self, *args): raise NotImplementedError() - def c_code(self, (dnll, sm, y_idx), (dx,), sub): + def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): return """ if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)