Mercurial > pylearn
diff nnet_ops.py @ 185:3d953844abd3
support for more int types in crossentropysoftmax1hot
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 13 May 2008 19:37:29 -0400 |
parents | 9a2aecc57a79 |
children | 44dd9b6448c5 |
line wrap: on
line diff
--- a/nnet_ops.py Tue May 13 18:39:58 2008 -0400 +++ b/nnet_ops.py Tue May 13 19:37:29 2008 -0400 @@ -101,7 +101,7 @@ or x.type.dtype not in ['float32', 'float64']: raise ValueError('b must be 1-d tensor of floats') if y_idx.type.ndim != 1 \ - or y_idx.type.dtype not in ['int32', 'int64']: + or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64']: raise ValueError('y_idx must be 1-d tensor of ints') # TODO: Is this correct? It used to be y, not y_idx @@ -109,7 +109,7 @@ y_idx.type.broadcastable).make_result() # nll = Tensor(x.dtype, y.broadcastable) sm = x.type.make_result() - return theano.Apply(self, [x, b, y_idx],[nll, sm]) + return theano.Apply(self, [x, b, y_idx], [nll, sm]) def perform(self, node, input_storage, output_storage): x, b, y_idx = input_storage if b.shape[0] != x.shape[1]: @@ -145,6 +145,7 @@ #TODO: set error messages for failures in this code #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] + y_idx_type = node.inputs[2].type.dtype_specs()[1] return """ npy_intp* Nx = %(x)s->dimensions; @@ -174,9 +175,12 @@ PyErr_SetString(PyExc_TypeError, "b not float64"); %(fail)s; } - if (%(y_idx)s->descr->type_num != PyArray_INT64) + if ((%(y_idx)s->descr->type_num != PyArray_INT64) + && (%(y_idx)s->descr->type_num != PyArray_INT32) + && (%(y_idx)s->descr->type_num != PyArray_INT16) + && (%(y_idx)s->descr->type_num != PyArray_INT8)) { - PyErr_SetString(PyExc_TypeError, "y_idx not int64"); + PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64"); %(fail)s; } if ((%(x)s->dimensions[1] != %(b)s->dimensions[0]) @@ -219,7 +223,7 @@ const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i); const double* __restrict__ b_i = (double*)(%(b)s->data); - const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; + const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i); @@ -305,15 +309,24 @@ def grad(self, *args): raise NotImplementedError() def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): + y_idx_type = node.inputs[2].type.dtype_specs()[1] return """ if ((%(dnll)s->descr->type_num != PyArray_DOUBLE) || (%(sm)s->descr->type_num != PyArray_DOUBLE) - || (%(y_idx)s->descr->type_num != PyArray_INT64)) + ) { PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64"); %(fail)s; } + if ((%(y_idx)s->descr->type_num != PyArray_INT64) + && (%(y_idx)s->descr->type_num != PyArray_INT32) + && (%(y_idx)s->descr->type_num != PyArray_INT16) + && (%(y_idx)s->descr->type_num != PyArray_INT8)) + { + PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64"); + %(fail)s; + } if ((%(dnll)s->nd != 1) || (%(sm)s->nd != 2) || (%(y_idx)s->nd != 1)) @@ -343,7 +356,7 @@ { const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0]; - const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; + const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);