comparison nnet_ops.py @ 185:3d953844abd3

support for more int types in crossentropysoftmax1hot
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 13 May 2008 19:37:29 -0400
parents 9a2aecc57a79
children 44dd9b6448c5
comparison
equal deleted inserted replaced
184:9a2aecc57a79 185:3d953844abd3
99 raise ValueError('x must be 2-d tensor of floats') 99 raise ValueError('x must be 2-d tensor of floats')
100 if b.type.ndim != 1 \ 100 if b.type.ndim != 1 \
101 or x.type.dtype not in ['float32', 'float64']: 101 or x.type.dtype not in ['float32', 'float64']:
102 raise ValueError('b must be 1-d tensor of floats') 102 raise ValueError('b must be 1-d tensor of floats')
103 if y_idx.type.ndim != 1 \ 103 if y_idx.type.ndim != 1 \
104 or y_idx.type.dtype not in ['int32', 'int64']: 104 or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64']:
105 raise ValueError('y_idx must be 1-d tensor of ints') 105 raise ValueError('y_idx must be 1-d tensor of ints')
106 106
107 # TODO: Is this correct? It used to be y, not y_idx 107 # TODO: Is this correct? It used to be y, not y_idx
108 nll = tensor.Tensor(x.type.dtype, 108 nll = tensor.Tensor(x.type.dtype,
109 y_idx.type.broadcastable).make_result() 109 y_idx.type.broadcastable).make_result()
110 # nll = Tensor(x.dtype, y.broadcastable) 110 # nll = Tensor(x.dtype, y.broadcastable)
111 sm = x.type.make_result() 111 sm = x.type.make_result()
112 return theano.Apply(self, [x, b, y_idx],[nll, sm]) 112 return theano.Apply(self, [x, b, y_idx], [nll, sm])
113 def perform(self, node, input_storage, output_storage): 113 def perform(self, node, input_storage, output_storage):
114 x, b, y_idx = input_storage 114 x, b, y_idx = input_storage
115 if b.shape[0] != x.shape[1]: 115 if b.shape[0] != x.shape[1]:
116 raise ValueError('b must have same number of columns as x') 116 raise ValueError('b must have same number of columns as x')
117 if y_idx.shape[0] != x.shape[0]: 117 if y_idx.shape[0] != x.shape[0]:
143 #TODO: declare the max of each row as an Op output 143 #TODO: declare the max of each row as an Op output
144 144
145 #TODO: set error messages for failures in this code 145 #TODO: set error messages for failures in this code
146 146
147 #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] 147 #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1]
148 y_idx_type = node.inputs[2].type.dtype_specs()[1]
148 149
149 return """ 150 return """
150 npy_intp* Nx = %(x)s->dimensions; 151 npy_intp* Nx = %(x)s->dimensions;
151 152
152 if (%(x)s->nd != 2) 153 if (%(x)s->nd != 2)
172 if (%(b)s->descr->type_num != PyArray_DOUBLE) 173 if (%(b)s->descr->type_num != PyArray_DOUBLE)
173 { 174 {
174 PyErr_SetString(PyExc_TypeError, "b not float64"); 175 PyErr_SetString(PyExc_TypeError, "b not float64");
175 %(fail)s; 176 %(fail)s;
176 } 177 }
177 if (%(y_idx)s->descr->type_num != PyArray_INT64) 178 if ((%(y_idx)s->descr->type_num != PyArray_INT64)
178 { 179 && (%(y_idx)s->descr->type_num != PyArray_INT32)
179 PyErr_SetString(PyExc_TypeError, "y_idx not int64"); 180 && (%(y_idx)s->descr->type_num != PyArray_INT16)
181 && (%(y_idx)s->descr->type_num != PyArray_INT8))
182 {
183 PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64");
180 %(fail)s; 184 %(fail)s;
181 } 185 }
182 if ((%(x)s->dimensions[1] != %(b)s->dimensions[0]) 186 if ((%(x)s->dimensions[1] != %(b)s->dimensions[0])
183 || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0])) 187 || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0]))
184 { 188 {
217 double sum = 0.0; 221 double sum = 0.0;
218 bool discount_max = false; 222 bool discount_max = false;
219 223
220 const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i); 224 const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i);
221 const double* __restrict__ b_i = (double*)(%(b)s->data); 225 const double* __restrict__ b_i = (double*)(%(b)s->data);
222 const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; 226 const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
223 double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); 227 double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
224 double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i); 228 double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i);
225 229
226 npy_intp Sx = %(x)s->strides[1]/sizeof(double); 230 npy_intp Sx = %(x)s->strides[1]/sizeof(double);
227 npy_intp Sb = %(b)s->strides[0]/sizeof(double); 231 npy_intp Sb = %(b)s->strides[0]/sizeof(double);
303 dx[i, y_idx[i]] -= dy[i] #scalar decrement 307 dx[i, y_idx[i]] -= dy[i] #scalar decrement
304 output_storage[0][0] = dx 308 output_storage[0][0] = dx
305 def grad(self, *args): 309 def grad(self, *args):
306 raise NotImplementedError() 310 raise NotImplementedError()
307 def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): 311 def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
312 y_idx_type = node.inputs[2].type.dtype_specs()[1]
308 return """ 313 return """
309 314
310 if ((%(dnll)s->descr->type_num != PyArray_DOUBLE) 315 if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)
311 || (%(sm)s->descr->type_num != PyArray_DOUBLE) 316 || (%(sm)s->descr->type_num != PyArray_DOUBLE)
312 || (%(y_idx)s->descr->type_num != PyArray_INT64)) 317 )
313 { 318 {
314 PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64"); 319 PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64");
320 %(fail)s;
321 }
322 if ((%(y_idx)s->descr->type_num != PyArray_INT64)
323 && (%(y_idx)s->descr->type_num != PyArray_INT32)
324 && (%(y_idx)s->descr->type_num != PyArray_INT16)
325 && (%(y_idx)s->descr->type_num != PyArray_INT8))
326 {
327 PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64");
315 %(fail)s; 328 %(fail)s;
316 } 329 }
317 if ((%(dnll)s->nd != 1) 330 if ((%(dnll)s->nd != 1)
318 || (%(sm)s->nd != 2) 331 || (%(sm)s->nd != 2)
319 || (%(y_idx)s->nd != 1)) 332 || (%(y_idx)s->nd != 1))
341 354
342 for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i) 355 for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i)
343 { 356 {
344 const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0]; 357 const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0];
345 358
346 const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; 359 const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
347 360
348 const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); 361 const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
349 npy_intp Ssm = %(sm)s->strides[1]/sizeof(double); 362 npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);
350 363
351 double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i); 364 double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i);