changeset 30:bf0145fa73e8

added C implementation for CrossentropySoftmax1Hot
author bergstrj@iro.umontreal.ca
date Fri, 11 Apr 2008 21:41:09 -0400
parents e6c550cb2896
children 2d6c49ec5749
files _nnet_ops.py nnet_ops.py
diffstat 2 files changed, 152 insertions(+), 31 deletions(-)
--- a/_nnet_ops.py	Fri Apr 11 11:16:09 2008 -0400
+++ b/_nnet_ops.py	Fri Apr 11 21:41:09 2008 -0400
@@ -17,9 +17,10 @@
         numpy.random.seed(9999)
     def test0(self):
         y_idx = [0,1,3]
-        def output1(a):
-            return crossentropy_softmax_1hot(a, y_idx)[0:1]
-        TT.verify_grad(self, output1, [numpy.random.rand(3,4)])
+        def output1(a,b):
+            return crossentropy_softmax_1hot(a, b, y_idx)[0:1]
+        TT.verify_grad(self, output1, [numpy.random.rand(3,4),
+            numpy.random.rand(4)])
 
 
 
--- a/nnet_ops.py	Fri Apr 11 11:16:09 2008 -0400
+++ b/nnet_ops.py	Fri Apr 11 21:41:09 2008 -0400
@@ -29,53 +29,173 @@
     """
     nin=2
     nout=2
-    def __init__(self, x, y_idx,**kwargs):
+    def __init__(self, x, b, y_idx, **kwargs):
         x = tensor._as_tensor(x)
+        b = tensor._as_tensor(b)
         y_idx = tensor._as_tensor(y_idx)
+        if len(x.broadcastable) != 2 \
+                or x.dtype not in ['float32', 'float64']:
+            raise ValueError('x must be 2-d tensor of floats')
+        if len(b.broadcastable) != 1 \
+                or b.dtype not in ['float32', 'float64']:
+            raise ValueError('b must be 1-d tensor of floats')
+        if len(y_idx.broadcastable) != 1 \
+                or y_idx.dtype not in ['int32', 'int64']:
+            raise ValueError('y_idx must be 1-d tensor of ints')
+
 #       TODO: Is this correct? It used to be y, not y_idx
         nll = tensor.Tensor(x.dtype, y_idx.broadcastable)
 #        nll = Tensor(x.dtype, y.broadcastable)
         sm = tensor.Tensor(x.dtype, x.broadcastable)
-        self.inputs = [x, y_idx]
-        self.outputs = [nll,sm]
+        self.inputs = [x, b, y_idx]
+        self.outputs = [nll, sm]
     def perform(self):
-        x, y_idx = [i.data for i in self.inputs]
+        x, b, y_idx = [i.data for i in self.inputs]
+        if b.shape[0] != x.shape[1]:
+            raise ValueError('b must have the same length as a row of x (x.shape[1])')
+
         sm = numpy.zeros_like(x) # softmax
         nll = numpy.zeros(x.shape[0]) #nll(y | softmax(x))
         for i in xrange(sm.shape[0]):
-            sm[i] = numpy.exp(x[i] - numpy.max(x[i])) #softmax
+            row = x[i] + b
+            sm[i] = numpy.exp(row - numpy.max(row)) #softmax
             sm[i] *= 1.0 / numpy.sum(sm[i]) #vector scale
             nll[i] = -numpy.log( sm[i, y_idx[i]]) #cross-entropy
         self.outputs[0].data = nll
         self.outputs[1].data = sm
-    def grad(self, (x, y_idx), (g_nll, g_sm)):
+    def grad(self, (x, b, y_idx), (g_nll, g_sm)):
         if g_sm is not None:
             raise NotImplementedError()
-        nll, sm = crossentropy_softmax_1hot(x, y_idx)
-        dx = CrossentropySoftmax1Hot.Dx(g_nll, sm, y_idx).outputs[0]
-        return dx, None
+        nll, sm = crossentropy_softmax_1hot(x, b, y_idx)
+        dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0]
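+        # dx[i] = g_nll[i] * (sm[i] - onehot(y_idx[i])), computed by the Dx op
+        # below; b is added to every row of x, so its gradient is dx summed
+        # over the example axis.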
+        db = tensor.Sum(dx, axis = [0]).outputs[0]
+        return dx, db, None
+
+    def c_validate_update(self, (x, b, y_idx), (nll, sm), sub):
+        """Allocate output storage"""
+        return """
+        if (%(x)s->nd != 2) { %(fail)s }
+        if (%(b)s->nd != 1) { %(fail)s }
+        if (%(y_idx)s->nd != 1) { %(fail)s }
+        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
+        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
+        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
+
+        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
+        if(!%(nll)s){%(fail)s}
+
+        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
+        if(!%(sm)s){Py_XDECREF(%(nll)s); %(fail)s}
+
+        """ % dict(locals(), **sub)
+    def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
+        """Not sure..."""
+        return ""
+    def c_support_code(self):
+        return """
+        """
+    def c_code(self,  (x, b, y_idx), (nll, sm), sub):
+        # this implementation was lifted from
+        # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
+
+        #TODO: put this into a templated function, in the support code
+        #TODO: declare the max of each row as an Op output
+
+        return """
+        npy_intp* Nx = %(x)s->dimensions;
+        assert(%(x)s->dimensions[1] == %(b)s->dimensions[0]);
+        assert(%(sm)s->dimensions[0] == %(x)s->dimensions[0]);
+        assert(%(sm)s->dimensions[1] == %(x)s->dimensions[1]);
+
+        for (size_t i = 0; i < Nx[0]; ++i)
+        {
+            size_t j;
+            double sum = 0.0;
+            bool  discount_max = false;
+
+            const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i);
+            const double* __restrict__ b_i = (double*)(%(b)s->data);
+            const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
+            double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
+            double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i);
+
+            npy_intp Sx = %(x)s->strides[1]/sizeof(double);
+            npy_intp Sb = %(b)s->strides[0]/sizeof(double);
+            npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);
 
-    class Dx (gof.op.Op):
-        nin=3
-        nout=1
-        """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
-        def __init__(self, dy, sm, y_idx,**kwargs):
-            dy = tensor._as_tensor(dy)
-            sm = tensor._as_tensor(sm)
-            y_idx = tensor._as_tensor(y_idx)
-            self.inputs = [dy, sm, y_idx]
-            self.outputs = [tensor.Tensor(sm.dtype, sm.broadcastable)]
-        def perform(self):
-            dy,sm,y_idx = [i.data for i in self.inputs]
-            dx = numpy.zeros_like(sm)
-            for i in xrange(sm.shape[0]):
-                dx[i] = dy[i] * sm[i] #vector scale
-                dx[i, y_idx[i]] -= dy[i] #scalar decrement
-            self.outputs[0].data = dx
-        def grad(self, *args):
-            raise NotImplementedError()
+            size_t row_max_j=0;
+            double row_max = x_i[0] + b_i[0];
+            //try to compute sum and sm the easy way
+            for (j = 0; j < Nx[1]; ++j)
+            {
+                double row_ij = x_i[j * Sx] +  b_i[j * Sb];
+                row_max_j = (row_ij > row_max) ? j : row_max_j;
+                row_max   = (row_ij > row_max) ? row_ij : row_max;
+
+                double sm_ij = exp(row_ij);
+                sum += sm_ij;
+                sm_i[j * Ssm] = sm_ij;
+            }
+            if ((0.0 == sum) || (isinf(sum))) 
+            {
+                //our cheap trick didn't work... try again and do it better.
+                discount_max = true;
+                sum = 0.0; //reset sum and recompute....
+                for (j = 0; j < Nx[1]; ++j)
+                {
+                    double row_ij = x_i[j * Sx] +  b_i[j * Sb];
+
+                    double sm_ij = exp(row_ij - row_max);
+                    sum += sm_ij;
+                    sm_i[j * Ssm] = sm_ij;
+                }
+                assert( (0.0 != sum) && (!isinf(sum))); //that was our best... 
+                //if we still can't sum it up, we're screwed.
+                //So far, this assertion has never failed...
+            }
+
+            //cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
+            double sum_inv = 1.0 / sum;
+            for (j = 0; j < Nx[1]; ++j)
+            {
+                sm_i[j * Ssm] *= sum_inv;
+            }
+
+            assert(y_i < Nx[1]);
+
+            nll_i[0] = - x_i[y_i*Sx] 
+                       - b_i[y_i*Sb]
+                       + (discount_max ? row_max : 0.0)
+                       + log(sum);
+              //mat_at(y,i,0) = -log( mat_at(s,i,t[i]));  //less accurate?
+              //mat_at(y,i,0) =  - mat_at(x,i,t[i]) - mat_at(b,0,t[i]) + (discount_max ? maxi : 0.0) + log(sum);
+        }
+        """ % dict(locals(), **sub)
+
+
+
 crossentropy_softmax_1hot = gof.op.constructor(CrossentropySoftmax1Hot)
 
+class CrossentropySoftmax1HotDx (gof.op.Op):
+    """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
+    nin=3
+    nout=1
+    def __init__(self, dy, sm, y_idx,**kwargs):
+        dy = tensor._as_tensor(dy)
+        sm = tensor._as_tensor(sm)
+        y_idx = tensor._as_tensor(y_idx)
+        self.inputs = [dy, sm, y_idx]
+        self.outputs = [tensor.Tensor(sm.dtype, sm.broadcastable)]
+    def perform(self):
+        dy,sm,y_idx = [i.data for i in self.inputs]
+        dx = numpy.zeros_like(sm)
+        for i in xrange(sm.shape[0]):
+            dx[i] = dy[i] * sm[i] #vector scale
+            dx[i, y_idx[i]] -= dy[i] #scalar decrement
+        self.outputs[0].data = dx
+    def grad(self, *args):
+        raise NotImplementedError()
+
 #TODO: write a version of CrossentropySoftmax1Hot that accepts a bias for x, if
 # this op needs to be faster.
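
For reference, a minimal NumPy sketch (not part of the changeset; the function name is illustrative) of the forward computation the new op performs. It mirrors the Python perform() above; the C code in c_code() computes the same quantities row by row, applying the max-subtraction only when the naive exponential sum overflows or vanishes.

import numpy

def crossentropy_softmax_1hot_ref(x, b, y_idx):
    """Stable softmax over rows of x + b, and nll[i] = -log(sm[i, y_idx[i]])."""
    act = x + b                                   # bias broadcasts over the rows of x
    act = act - act.max(axis=1, keepdims=True)    # subtract each row's max for stability
    sm = numpy.exp(act)
    sm /= sm.sum(axis=1, keepdims=True)
    nll = -numpy.log(sm[numpy.arange(x.shape[0]), y_idx])
    return nll, sm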