changeset 32:039c0f249859

added C impl for softmax dx
author bergstrj@iro.umontreal.ca
date Sun, 13 Apr 2008 00:21:11 -0400
parents 2d6c49ec5749
children bb92087cb0f6
files nnet_ops.py
diffstat 1 files changed, 64 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/nnet_ops.py	Fri Apr 11 21:42:07 2008 -0400
+++ b/nnet_ops.py	Sun Apr 13 00:21:11 2008 -0400
@@ -101,11 +101,13 @@
         #TODO: put this into a templated function, in the support code
         #TODO: declare the max of each row as an Op output
 
+        #TODO: set error messages for failures in this code
+
         return """
         npy_intp* Nx = %(x)s->dimensions;
-        assert(%(x)s->dimensions[1] == %(b)s->dimensions[0]);
-        assert(%(sm)s->dimensions[0] == %(x)s->dimensions[0]);
-        assert(%(sm)s->dimensions[1] == %(x)s->dimensions[1]);
+        if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s}
+        if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s}
+        if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}
 
         for (size_t i = 0; i < Nx[0]; ++i)
         {
@@ -149,7 +151,11 @@
                     sum += sm_ij;
                     sm_i[j * Ssm] = sm_ij;
                 }
-                assert( (0.0 != sum) && (!isinf(sum))); //that was our best... 
+                if ( (0.0 == sum) || (isinf(sum)))
+                { 
+                    //that was our best... 
+                    %(fail)s;
+                }
                 //if we still can't sum it up, we're screwed.
                 //So far, this assertion has never failed...
             }
@@ -161,7 +167,10 @@
                 sm_i[j * Ssm] *= sum_inv;
             }
 
-            assert(y_i < Nx[1]);
+            if (y_i >= Nx[1])
+            {
+                %(fail)s;
+            }
 
             nll_i[0] = - x_i[y_i*Sx] 
                        - b_i[y_i*Sb]
@@ -195,7 +204,55 @@
         self.outputs[0].data = dx
     def grad(self, *args):
         raise NotImplementedError()
+    def c_validate_update(self, (dnll, sm, y_idx), (dx,), sub):
+        """Allocate output storage"""
+        return """
+        if (%(dnll)s->nd != 1) { %(fail)s }
+        if (%(sm)s->nd != 2) { %(fail)s }
+        if (%(y_idx)s->nd != 1) { %(fail)s }
+        if (%(dnll)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
+        if (%(sm)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
+        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
 
-#TODO: write a version of CrossentropySoftmax1Hot that accepts a bias for x, if
-# this op needs to be faster.
+        %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s);
+        if(!%(dx)s){%(fail)s}
+
+        """ % dict(locals(), **sub)
+    def c_validate_cleanup(self, inputs, outputs, sub):
+        """Not sure..."""
+        return ""
+    def c_support_code(self):
+        return """
+        """
+    def c_code(self,  (dnll, sm, y_idx), (dx,), sub):
+        return """
+        npy_intp* shape = %(dx)s->dimensions;
+        if (%(dnll)s->dimensions[0] != %(sm)s->dimensions[0]) {%(fail)s}
+        if (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]) {%(fail)s}
+        if (%(dnll)s->dimensions[0] != %(dx)s->dimensions[0]) {%(fail)s}
+
+        if (%(sm)s->dimensions[1] != %(dx)s->dimensions[1]) {%(fail)s}
 
+        for (size_t i = 0; i < shape[0]; ++i)
+        {
+            const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0];
+
+            const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
+
+            const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
+            npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);
+
+            double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i);
+            npy_intp Sdx = %(dx)s->strides[1]/sizeof(double);
+
+            for (size_t j = 0; j < shape[1]; ++j)
+            {
+                dx_i[j * Sdx] = dnll_i * sm_i[j * Ssm];
+            }
+            if (y_i >= shape[1])
+            {
+                %(fail)s;
+            }
+            dx_i[y_i * Sdx] -= dnll_i;
+        }
+        """ % dict(locals(), **sub)