diff nnet_ops.py @ 34:1b152f46ad0c

consolidated code
author bergstrj@iro.umontreal.ca
date Thu, 17 Apr 2008 12:49:48 -0400
parents 039c0f249859
children 810a8e3c85e1
line wrap: on
line diff
--- a/nnet_ops.py	Thu Apr 17 12:49:33 2008 -0400
+++ b/nnet_ops.py	Thu Apr 17 12:49:48 2008 -0400
@@ -71,23 +71,6 @@
         db = tensor.Sum(dx, axis = [0]).outputs[0]
         return dx, db, None
 
-    def c_validate_update(self, (x, b, y_idx), (nll, sm), sub):
-        """Allocate output storage"""
-        return """
-        if (%(x)s->nd != 2) { %(fail)s }
-        if (%(b)s->nd != 1) { %(fail)s }
-        if (%(y_idx)s->nd != 1) { %(fail)s }
-        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
-
-        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
-        if(!%(nll)s){%(fail)s}
-
-        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
-        if(!%(sm)s){Py_XDECREF(%(nll)s); %(fail)s}
-
-        """ % dict(locals(), **sub)
     def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
         """Not sure..."""
         return ""
@@ -105,6 +88,23 @@
 
         return """
         npy_intp* Nx = %(x)s->dimensions;
+
+        if (%(x)s->nd != 2) { %(fail)s }
+        if (%(b)s->nd != 1) { %(fail)s }
+        if (%(y_idx)s->nd != 1) { %(fail)s }
+        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
+        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
+        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
+
+        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
+        if(!%(nll)s){%(fail)s}
+
+        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
+        if(!%(sm)s) {
+            // The normal cleanup code will take care of %(nll)s
+            // Py_XDECREF(%(nll)s); %(nll)s=NULL;
+            %(fail)s
+        }
         if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s}
         if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s}
         if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}