changeset 185:3d953844abd3

support for more int types in crossentropysoftmax1hot
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 19:37:29 -0400
parents 9a2aecc57a79
children 562f308873f0
files nnet_ops.py
diffstat 1 files changed, 20 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/nnet_ops.py	Tue May 13 18:39:58 2008 -0400
+++ b/nnet_ops.py	Tue May 13 19:37:29 2008 -0400
@@ -101,7 +101,7 @@
                 or x.type.dtype not in ['float32', 'float64']:
             raise ValueError('b must be 1-d tensor of floats')
         if y_idx.type.ndim != 1 \
-                or y_idx.type.dtype not in ['int32', 'int64']:
+                or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64']:
             raise ValueError('y_idx must be 1-d tensor of ints')
 
 #       TODO: Is this correct? It used to be y, not y_idx
@@ -109,7 +109,7 @@
                 y_idx.type.broadcastable).make_result()
 #        nll = Tensor(x.dtype, y.broadcastable)
         sm = x.type.make_result()
-        return theano.Apply(self, [x, b, y_idx],[nll, sm])
+        return theano.Apply(self, [x, b, y_idx], [nll, sm])
     def perform(self, node, input_storage, output_storage):
         x, b, y_idx = input_storage
         if b.shape[0] != x.shape[1]:
@@ -145,6 +145,7 @@
         #TODO: set error messages for failures in this code
 
         #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
+        y_idx_type = node.inputs[2].type.dtype_specs()[1]
 
         return """
         npy_intp* Nx = %(x)s->dimensions;
@@ -174,9 +175,12 @@
             PyErr_SetString(PyExc_TypeError, "b not float64");
             %(fail)s;
         }
-        if (%(y_idx)s->descr->type_num != PyArray_INT64)
+        if ((%(y_idx)s->descr->type_num != PyArray_INT64)
+            && (%(y_idx)s->descr->type_num != PyArray_INT32)
+            && (%(y_idx)s->descr->type_num != PyArray_INT16)
+            && (%(y_idx)s->descr->type_num != PyArray_INT8))
         {
-            PyErr_SetString(PyExc_TypeError, "y_idx not int64");
+            PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64");
             %(fail)s;
         }
         if ((%(x)s->dimensions[1] != %(b)s->dimensions[0])
@@ -219,7 +223,7 @@
 
             const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i);
             const double* __restrict__ b_i = (double*)(%(b)s->data);
-            const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
+            const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
             double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
             double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i);
 
@@ -305,15 +309,24 @@
     def grad(self, *args):
         raise NotImplementedError()
     def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
+        y_idx_type = node.inputs[2].type.dtype_specs()[1]
         return """
 
         if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)
             || (%(sm)s->descr->type_num != PyArray_DOUBLE)
-            || (%(y_idx)s->descr->type_num != PyArray_INT64))
+            )
         {
             PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64");
             %(fail)s;
         }
+        if ((%(y_idx)s->descr->type_num != PyArray_INT64)
+            && (%(y_idx)s->descr->type_num != PyArray_INT32)
+            && (%(y_idx)s->descr->type_num != PyArray_INT16)
+            && (%(y_idx)s->descr->type_num != PyArray_INT8))
+        {
+            PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64");
+            %(fail)s;
+        }
         if ((%(dnll)s->nd != 1)
             || (%(sm)s->nd != 2)
             || (%(y_idx)s->nd != 1))
@@ -343,7 +356,7 @@
         {
             const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0];
 
-            const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
+            const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
 
             const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
             npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);