changeset 71:5b699b31770a

merge
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 02 May 2008 18:19:35 -0400
parents 76e5c0f37165 and dde1fb1b63ba (the diff shown is against dde1fb1b63ba)
children 2b6656b2ef52
files
diffstat 2 files changed, 195 insertions(+), 74 deletions(-)
--- a/_nnet_ops.py	Fri May 02 11:24:17 2008 -0400
+++ b/_nnet_ops.py	Fri May 02 18:19:35 2008 -0400
@@ -11,6 +11,11 @@
     def test_elemwise(self):
         TT.verify_grad(self, Sigmoid, [numpy.random.rand(3,4)])
 
+class T_softplus(unittest.TestCase):
+    def setUp(self):
+        numpy.random.seed(9999)
+    def test_elemwise(self):
+        TT.verify_grad(self, Softplus, [numpy.random.rand(3,4)])
 
 class T_CrossentropySoftmax1Hot(unittest.TestCase):
     def setUp(self):
@@ -18,10 +23,16 @@
     def test0(self):
         y_idx = [0,1,3]
         def output1(a,b):
-            return crossentropy_softmax_1hot(a, b, y_idx)[0:1]
+            return crossentropy_softmax_1hot_with_bias(a, b, y_idx)[0:1]
         TT.verify_grad(self, output1, [numpy.random.rand(3,4),
             numpy.random.rand(4)])
 
+    def test1(self):
+        y_idx = [0,1,3]
+        def output1(a):
+            return crossentropy_softmax_1hot(a, y_idx)[0:1]
+        TT.verify_grad(self, output1, [numpy.random.rand(3,4)])
+
 
 
 if __name__ == '__main__':
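
For context on what these tests exercise: TT.verify_grad compares an op's symbolic
gradient against a finite-difference estimate. A minimal numpy sketch of the same
check applied to softplus (illustration only; softplus_ref and finite_diff_grad are
names introduced here, not part of the module):

    import numpy

    def softplus_ref(x):
        # numerically stable softplus, mirroring ScalarSoftplus.static_impl
        return numpy.where(x > 30.0, x,
                           numpy.log1p(numpy.exp(numpy.minimum(x, 30.0))))

    def finite_diff_grad(f, x, eps=1e-6):
        # central differences of sum(f(x)) w.r.t. each entry of x
        g = numpy.zeros_like(x)
        for idx in numpy.ndindex(*x.shape):
            xp = x.copy(); xp[idx] += eps
            xm = x.copy(); xm[idx] -= eps
            g[idx] = (f(xp).sum() - f(xm).sum()) / (2 * eps)
        return g

    numpy.random.seed(9999)
    x = numpy.random.rand(3, 4)
    analytic = 1.0 / (1.0 + numpy.exp(-x))      # d softplus / dx = sigmoid(x)
    assert numpy.allclose(analytic, finite_diff_grad(softplus_ref, x), atol=1e-5)
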
--- a/nnet_ops.py	Fri May 02 11:24:17 2008 -0400
+++ b/nnet_ops.py	Fri May 02 18:19:35 2008 -0400
@@ -2,32 +2,93 @@
 from theano import tensor, gof, scalar
 import numpy
 
-class ScalarSigmoid(scalar.UnaryScalarOp):
+############
+#
+# SCALAR OPS
+#
+
+class ScalarSigmoid(scalar.FloatUnaryScalarOp):
+    @staticmethod
+    def st_impl(x):
+        if x < -30.0:
+            return 0.0
+        if x > 30.0:
+            return 1.0 
+        return 1.0 / (1.0 + numpy.exp(-x))
     def impl(self, x):
-        return 1.0 / (1 + numpy.exp(-x))
+        return ScalarSigmoid.st_impl(x)
     def grad(self, (x,), (gz,)):
-        return gz * scalar_sigmoid(x) * (1.0 - scalar_sigmoid(x)),
-    def c_foreach(self, (x,), (z,)): 
-        return "%(z)s = 1.0 / (1 + exp(-%(x)s));" % locals()
+        y = scalar_sigmoid(x)
+        return [gz * y * (1.0 - y)]
+    def c_foreach(self, (x,), (z,), sub):
+        if 'float' in self.inputs[0].dtype:
+            return """%(z)s =
+                %(x)s < -30.0 
+                ? 0.0 
+                : %(x)s > 30.0 
+                   ? 1.0
+                   : 1.0 /(1.0+exp(-%(x)s));""" % locals()
+        raise NotImplementedError('only floating point is implemented')
 scalar_sigmoid = gof.op.constructor(ScalarSigmoid)
-Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace \
-        = theano.tensor.broadcast(ScalarSigmoid, 'Sigmoid')
+Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace =\
+        tensor.broadcast(ScalarSigmoid, 'Sigmoid')
 
+class ScalarSoftplus(scalar.FloatUnaryScalarOp):
+    @staticmethod
+    def static_impl(x):
+        if x < -30.0:
+            return 0.0
+        if x > 30.0:
+            return x
+        return numpy.log1p(numpy.exp(x))
+    def impl(self, x):
+        return ScalarSoftplus.static_impl(x)
+    def grad(self, (x,), (gz,)):
+        return [gz * scalar_sigmoid(x)]
+    def c_foreach(self, (x,), (z,), sub):
+        if 'float' in self.inputs[0].dtype:
+            return """%(z)s =
+                %(x)s < -30.0 
+                ? 0.0 
+                : %(x)s > 30.0 
+                   ? %(x)s
+                   : log1p(exp(%(x)s));""" % locals()
+        raise NotImplementedError('only floating point x is implemented')
+scalar_softplus = gof.op.constructor(ScalarSoftplus)
+Softplus, softplus, SoftplusInplace, softplus_inplace =\
+        tensor.broadcast(ScalarSoftplus, 'Softplus')
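
Both scalar ops above clamp their argument at +/-30 before calling exp(), so neither
the Python impl nor the generated C code can overflow. A standalone sketch of the
same branches (illustration only, not part of the patch):

    import numpy

    def stable_sigmoid(x):
        # mirrors ScalarSigmoid.st_impl: sigmoid(-30) ~ 1e-13, so clamping loses
        # essentially nothing while keeping exp(-x) finite
        if x < -30.0:
            return 0.0
        if x > 30.0:
            return 1.0
        return 1.0 / (1.0 + numpy.exp(-x))

    def stable_softplus(x):
        # mirrors ScalarSoftplus.static_impl: softplus(x) ~ x for large x
        if x < -30.0:
            return 0.0
        if x > 30.0:
            return x
        return numpy.log1p(numpy.exp(x))

    assert stable_sigmoid(-745.0) == 0.0      # exp(745.0) would overflow a double
    assert stable_softplus(745.0) == 745.0
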
 
 
-class CrossentropySoftmax1Hot(gof.op.Op):
-    """A special compound Op for the output of neural-net classifiers.
+############
+#
+# TENSOR OPS
+#
+
+
+class CrossentropySoftmax1HotWithBias(gof.op.Op):
+    """A special compound L{Op} for the output of neural-net classifiers.
+
+    @type x: a matrix of floats (32 or 64)
+    @type b: a [row] vector of floats (32 or 64); length is the number of columns in x
+    @type y_idx: a [column] vector of ints (32 or 64); length is the number of rows in x
 
-    This Op has two outputs:
-    - KL(softmax(x), y)
-    - softmax(x)
+    @precondition: every entry in y_idx is a valid (non-negative) column index into x
+
+    This L{Op} has two outputs:
+     - KL(softmax(x+b), y)
+     - softmax(x+b)
 
-    x[i] is assumed to be a dense vector
+    
     softmax(x[i]) is the i'th distribution over len(x[i]) options
-    y[i] is an integer index, encoding a 1-hot distribution
+
+    y_idx[i] is an integer index, encoding a 1-hot distribution.
+
+    In practice, when we are doing classification, x and y_idx have one row and
+    one entry per example, and y_idx[i] is the index of the (correct) class of
+    the i'th example.
 
     """
-    nin=2
+    nin=3
     nout=2
     def __init__(self, x, b, y_idx, **kwargs):
         x = tensor._as_tensor(x)
@@ -52,7 +113,9 @@
     def perform(self):
         x, b, y_idx = [i.data for i in self.inputs]
         if b.shape[0] != x.shape[1]:
-            raise ValueError('b must have same shape as x[0]')
+            raise ValueError('b must have same number of columns as x')
+        if y_idx.shape[0] != x.shape[0]:
+            raise ValueError('y_idx must have same number of rows as x')
 
         sm = numpy.zeros_like(x) # softmax
         nll = numpy.zeros(x.shape[0]) #nll(y | softmax(x))
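
perform() fills sm with the row-wise softmax of x + b and nll[i] with
-log(sm[i, y_idx[i]]). A compact numpy sketch of that forward pass (a reference
illustration, not the actual loop used in the op):

    import numpy

    def xent_softmax_1hot_with_bias_ref(x, b, y_idx):
        xb = x + b                                  # bias added to every row
        xb = xb - xb.max(axis=1, keepdims=True)     # shift for numerical stability
        e = numpy.exp(xb)
        sm = e / e.sum(axis=1, keepdims=True)       # softmax(x + b), row by row
        nll = -numpy.log(sm[numpy.arange(x.shape[0]), y_idx])
        return nll, sm
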
@@ -66,17 +129,12 @@
     def grad(self, (x, b, y_idx), (g_nll, g_sm)):
         if g_sm is not None:
             raise NotImplementedError()
-        nll, sm = crossentropy_softmax_1hot(x, b, y_idx)
-        dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0]
+        nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx)
+        dx = CrossentropySoftmax1HotWithBiasDx(g_nll, sm, y_idx).outputs[0]
         db = tensor.Sum(dx, axis = [0]).outputs[0]
         return dx, db, None
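
The gradient returned above delegates to CrossentropySoftmax1HotWithBiasDx; the
identity it implements is d nll / dx = g_nll * (softmax(x + b) - onehot(y_idx)), and
db is just the column sum of dx. A numpy sketch (illustration only):

    import numpy

    def xent_softmax_1hot_dx_ref(g_nll, sm, y_idx):
        dx = g_nll[:, None] * sm                        # g_nll_i * sm_i[j] for every j
        dx[numpy.arange(sm.shape[0]), y_idx] -= g_nll   # subtract g_nll_i at the target index
        return dx

    # db = dx.sum(axis=0), matching tensor.Sum(dx, axis=[0]) above
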
 
-    def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
-        """Not sure..."""
-        return ""
-    def c_support_code(self):
-        return """
-        """
+    def c_headers(self): return ['<iostream>']
     def c_code(self,  (x, b, y_idx), (nll, sm), sub):
         # this implementation was lifted from
         # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
@@ -89,25 +147,67 @@
         return """
         npy_intp* Nx = %(x)s->dimensions;
 
-        if (%(x)s->nd != 2) { %(fail)s }
-        if (%(b)s->nd != 1) { %(fail)s }
-        if (%(y_idx)s->nd != 1) { %(fail)s }
-        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
-
-        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
-        if(!%(nll)s){%(fail)s}
+        if (%(x)s->nd != 2)
+        {
+            PyErr_SetString(PyExc_ValueError, "a not 2d tensor");
+            %(fail)s;
+        }
+        if (%(b)s->nd != 1)
+        {
+            PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
+            %(fail)s;
+        }
+        if (%(y_idx)s->nd != 1)
+        {
+            PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor");
+            %(fail)s;
+        }
+        if (%(x)s->descr->type_num != PyArray_DOUBLE)
+        {
+            PyErr_SetString(PyExc_TypeError, "a not float64");
+            %(fail)s;
+        }
+        if (%(b)s->descr->type_num != PyArray_DOUBLE)
+        {
+            PyErr_SetString(PyExc_TypeError, "b not float64");
+            %(fail)s;
+        }
+        if (%(y_idx)s->descr->type_num != PyArray_INT64)
+        {
+            PyErr_SetString(PyExc_TypeError, "y_idx not int64");
+            %(fail)s;
+        }
+        if ((%(x)s->dimensions[1] != %(b)s->dimensions[0])
+         || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0]))
+        {
+            PyErr_SetString(PyExc_ValueError, "dimension mismatch in arguments");
+            %(fail)s;
+        }
 
-        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
-        if(!%(sm)s) {
-            // The normal cleanup code will take care of %(nll)s
-            // Py_XDECREF(%(nll)s); %(nll)s=NULL;
-            %(fail)s
+        if ((NULL == %(nll)s) //initial condition
+            || (%(nll)s->dimensions[0] != %(y_idx)s->dimensions[0]))
+        {
+            if (NULL != %(nll)s) Py_XDECREF(%(nll)s);
+            %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
+            if(!%(nll)s)
+            {
+                PyErr_SetString(PyExc_MemoryError, "failed to alloc nll output");
+                %(fail)s;
+            }
         }
-        if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s}
-        if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s}
-        if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}
+        if ((NULL == %(sm)s)
+            || (%(sm)s->dimensions[0] != %(x)s->dimensions[0])
+            || (%(sm)s->dimensions[1] != %(x)s->dimensions[1]))
+        {
+            if (NULL != %(sm)s) Py_XDECREF(%(sm)s);
+            %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
+            if(!%(sm)s) {
+                // The normal cleanup code will take care of %(nll)s
+                // Py_XDECREF(%(nll)s); %(nll)s=NULL;
+                PyErr_SetString(PyExc_MemoryError, "failed to alloc sm output");
+                %(fail)s
+            }
+        }
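
Much of the new C code above implements an output-reuse policy: keep the previously
allocated nll/sm buffers when their shapes still match, otherwise Py_XDECREF them and
allocate fresh arrays. Roughly, in Python terms (sketch only; reuse_or_alloc is a
name introduced here):

    import numpy

    def reuse_or_alloc(old, shape, dtype):
        # reuse the old output array if it exists and has the right shape,
        # otherwise drop it and allocate a new one
        if old is not None and old.shape == shape and old.dtype == dtype:
            return old
        return numpy.empty(shape, dtype=dtype)
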
 
         for (size_t i = 0; i < Nx[0]; ++i)
         {
@@ -181,11 +281,10 @@
         }
         """ % dict(locals(), **sub)
 
-
+crossentropy_softmax_1hot_with_bias = \
+        gof.op.constructor(CrossentropySoftmax1HotWithBias)
 
-crossentropy_softmax_1hot = gof.op.constructor(CrossentropySoftmax1Hot)
-
-class CrossentropySoftmax1HotDx (gof.op.Op):
+class CrossentropySoftmax1HotWithBiasDx (gof.op.Op):
     nin=3
     nout=1
     """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
@@ -204,36 +303,42 @@
         self.outputs[0].data = dx
     def grad(self, *args):
         raise NotImplementedError()
-    def c_validate_update(self, (dnll, sm, y_idx), (dx,), sub):
-        """Allocate output storage"""
-        return """
-        if (%(dnll)s->nd != 1) { %(fail)s }
-        if (%(sm)s->nd != 2) { %(fail)s }
-        if (%(y_idx)s->nd != 1) { %(fail)s }
-        if (%(dnll)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(sm)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
-
-        %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s);
-        if(!%(dx)s){%(fail)s}
-
-        """ % dict(locals(), **sub)
-    def c_validate_cleanup(self, inputs, outputs, sub):
-        """Not sure..."""
-        return ""
-    def c_support_code(self):
-        return """
-        """
     def c_code(self,  (dnll, sm, y_idx), (dx,), sub):
         return """
-        npy_intp* shape = %(dx)s->dimensions;
-        if (%(dnll)s->dimensions[0] != %(sm)s->dimensions[0]) {%(fail)s}
-        if (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]) {%(fail)s}
-        if (%(dnll)s->dimensions[0] != %(dx)s->dimensions[0]) {%(fail)s}
 
-        if (%(sm)s->dimensions[1] != %(dx)s->dimensions[1]) {%(fail)s}
+        if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)
+            || (%(sm)s->descr->type_num != PyArray_DOUBLE)
+            || (%(y_idx)s->descr->type_num != PyArray_INT64))
+        {
+            PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64");
+            %(fail)s;
+        }
+        if ((%(dnll)s->nd != 1)
+            || (%(sm)s->nd != 2)
+            || (%(y_idx)s->nd != 1))
+        {
+            PyErr_SetString(PyExc_ValueError, "rank error");
+            %(fail)s;
+        }
+        if ((%(dnll)s->dimensions[0] != %(sm)s->dimensions[0])
+            || (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]))
+        {
+            PyErr_SetString(PyExc_ValueError, "dimension mismatch");
+            %(fail)s;
+        }
+        if ((NULL == %(dx)s)
+            || (%(dx)s->dimensions[0] != %(sm)s->dimensions[0])
+            || (%(dx)s->dimensions[1] != %(sm)s->dimensions[1]))
+        {
+            if (NULL != %(dx)s) Py_XDECREF(%(dx)s);
+            %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s);
+            if(!%(dx)s) {
+                PyErr_SetString(PyExc_MemoryError, "failed to alloc dx output");
+                %(fail)s
+            }
+        }
 
-        for (size_t i = 0; i < shape[0]; ++i)
+        for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i)
         {
             const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0];
 
@@ -245,14 +350,19 @@
             double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i);
             npy_intp Sdx = %(dx)s->strides[1]/sizeof(double);
 
-            for (size_t j = 0; j < shape[1]; ++j)
+            for (size_t j = 0; j < %(dx)s->dimensions[1]; ++j)
             {
                 dx_i[j * Sdx] = dnll_i * sm_i[j * Ssm];
             }
-            if (y_i >= shape[1])
+            if (y_i >= %(dx)s->dimensions[1])
             {
                 %(fail)s;
             }
             dx_i[y_i * Sdx] -= dnll_i;
         }
         """ % dict(locals(), **sub)
+
+def crossentropy_softmax_1hot(x, y_idx, **kwargs):
+    b = tensor.zeros_like(x[0,:])
+    return crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs)
+
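
The convenience wrapper above just supplies a zero bias vector, so these two calls
build equivalent graphs (x and y_idx being tensor results as in the tests above):

    nll, sm = crossentropy_softmax_1hot(x, y_idx)
    nll, sm = crossentropy_softmax_1hot_with_bias(x, tensor.zeros_like(x[0, :]), y_idx)
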