changeset 181:1b06bc2c3ca9

fixed c_code for the ops in nnet_ops.py
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Tue, 13 May 2008 15:49:39 -0400
parents 2698c0feeb54
children 4afb41e61fcf
files nnet_ops.py
diffstat 1 files changed, 8 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/nnet_ops.py	Tue May 13 15:35:43 2008 -0400
+++ b/nnet_ops.py	Tue May 13 15:49:39 2008 -0400
@@ -20,15 +20,15 @@
     def grad(self, (x,), (gz,)):
         y = scalar_sigmoid(x)
         return [gz * y * (1.0 - y)]
-    def c_code(self, (x,), (z,), sub):
-        if 'float' in self.inputs[0].dtype:
+    def c_code(self, node, name, (x,), (z,), sub):
+        if node.inputs[0].type in [scalar.float32, scalar.float64]:
             return """%(z)s =
                 %(x)s < -30.0 
                 ? 0.0 
                 : %(x)s > 30.0 
                    ? 1.0
                    : 1.0 /(1.0+exp(-%(x)s));""" % locals()
-        return NotImplemented#Error('only floatingpoint is implemented')
+        raise NotImplementedError('only floatingpoint is implemented')
 scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid')
 sigmoid = tensor.Elemwise(scalar_sigmoid, name='sigmoid')
 
@@ -44,15 +44,15 @@
         return ScalarSoftplus.static_impl(x)
     def grad(self, (x,), (gz,)):
         return [gz * scalar_sigmoid(x)]
-    def c_code(self, (x,), (z,), sub):
-        if 'float' in self.inputs[0].dtype:
+    def c_code(self, name, node, (x,), (z,), sub):
+        if node.inputs[0].type in [scalar.float32, scalar.float64]:
             return """%(z)s =
                 %(x)s < -30.0 
                 ? 0.0 
                 : %(x)s > 30.0 
                    ? %(x)s
                    : log1p(exp(%(x)s));""" % locals()
-        return NotImplemented#Error('only floating point x is implemented')
+        raise NotImplementedError('only floating point x is implemented')
 scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus')
 softplus = tensor.Elemwise(scalar_softplus, name='softplus')
 
@@ -135,7 +135,7 @@
         return dx, db, None
 
     def c_headers(self): return ['<iostream>']
-    def c_code(self,  (x, b, y_idx), (nll, sm), sub):
+    def c_code(self, node, name, (x, b, y_idx), (nll, sm), sub):
         # this implementation was lifted from
         # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
 
@@ -302,7 +302,7 @@
         output_storage[0][0] = dx
     def grad(self, *args):
         raise NotImplementedError()
-    def c_code(self,  (dnll, sm, y_idx), (dx,), sub):
+    def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
         return """
 
         if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)