Mercurial > pylearn
comparison nnet_ops.py @ 185:3d953844abd3
support for more int types in crossentropysoftmax1hot
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 13 May 2008 19:37:29 -0400 |
parents | 9a2aecc57a79 |
children | 44dd9b6448c5 |
comparison legend: equal | deleted | inserted | replaced
184:9a2aecc57a79 | 185:3d953844abd3 |
---|---|
99 raise ValueError('x must be 2-d tensor of floats') | 99 raise ValueError('x must be 2-d tensor of floats') |
100 if b.type.ndim != 1 \ | 100 if b.type.ndim != 1 \ |
101 or x.type.dtype not in ['float32', 'float64']: | 101 or x.type.dtype not in ['float32', 'float64']: |
102 raise ValueError('b must be 1-d tensor of floats') | 102 raise ValueError('b must be 1-d tensor of floats') |
103 if y_idx.type.ndim != 1 \ | 103 if y_idx.type.ndim != 1 \ |
104 or y_idx.type.dtype not in ['int32', 'int64']: | 104 or y_idx.type.dtype not in ['int8', 'int16', 'int32', 'int64']: |
105 raise ValueError('y_idx must be 1-d tensor of ints') | 105 raise ValueError('y_idx must be 1-d tensor of ints') |
106 | 106 |
107 # TODO: Is this correct? It used to be y, not y_idx | 107 # TODO: Is this correct? It used to be y, not y_idx |
108 nll = tensor.Tensor(x.type.dtype, | 108 nll = tensor.Tensor(x.type.dtype, |
109 y_idx.type.broadcastable).make_result() | 109 y_idx.type.broadcastable).make_result() |
110 # nll = Tensor(x.dtype, y.broadcastable) | 110 # nll = Tensor(x.dtype, y.broadcastable) |
111 sm = x.type.make_result() | 111 sm = x.type.make_result() |
112 return theano.Apply(self, [x, b, y_idx],[nll, sm]) | 112 return theano.Apply(self, [x, b, y_idx], [nll, sm]) |
113 def perform(self, node, input_storage, output_storage): | 113 def perform(self, node, input_storage, output_storage): |
114 x, b, y_idx = input_storage | 114 x, b, y_idx = input_storage |
115 if b.shape[0] != x.shape[1]: | 115 if b.shape[0] != x.shape[1]: |
116 raise ValueError('b must have same number of columns as x') | 116 raise ValueError('b must have same number of columns as x') |
117 if y_idx.shape[0] != x.shape[0]: | 117 if y_idx.shape[0] != x.shape[0]: |
143 #TODO: declare the max of each row as an Op output | 143 #TODO: declare the max of each row as an Op output |
144 | 144 |
145 #TODO: set error messages for failures in this code | 145 #TODO: set error messages for failures in this code |
146 | 146 |
147 #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] | 147 #TODO: use this to accept float32 and int32: node.inputs[0].type.dtype_specs()[1] |
148 y_idx_type = node.inputs[2].type.dtype_specs()[1] | |
148 | 149 |
149 return """ | 150 return """ |
150 npy_intp* Nx = %(x)s->dimensions; | 151 npy_intp* Nx = %(x)s->dimensions; |
151 | 152 |
152 if (%(x)s->nd != 2) | 153 if (%(x)s->nd != 2) |
172 if (%(b)s->descr->type_num != PyArray_DOUBLE) | 173 if (%(b)s->descr->type_num != PyArray_DOUBLE) |
173 { | 174 { |
174 PyErr_SetString(PyExc_TypeError, "b not float64"); | 175 PyErr_SetString(PyExc_TypeError, "b not float64"); |
175 %(fail)s; | 176 %(fail)s; |
176 } | 177 } |
177 if (%(y_idx)s->descr->type_num != PyArray_INT64) | 178 if ((%(y_idx)s->descr->type_num != PyArray_INT64) |
178 { | 179 && (%(y_idx)s->descr->type_num != PyArray_INT32) |
179 PyErr_SetString(PyExc_TypeError, "y_idx not int64"); | 180 && (%(y_idx)s->descr->type_num != PyArray_INT16) |
181 && (%(y_idx)s->descr->type_num != PyArray_INT8)) | |
182 { | |
183 PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64"); | |
180 %(fail)s; | 184 %(fail)s; |
181 } | 185 } |
182 if ((%(x)s->dimensions[1] != %(b)s->dimensions[0]) | 186 if ((%(x)s->dimensions[1] != %(b)s->dimensions[0]) |
183 || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0])) | 187 || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0])) |
184 { | 188 { |
217 double sum = 0.0; | 221 double sum = 0.0; |
218 bool discount_max = false; | 222 bool discount_max = false; |
219 | 223 |
220 const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i); | 224 const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i); |
221 const double* __restrict__ b_i = (double*)(%(b)s->data); | 225 const double* __restrict__ b_i = (double*)(%(b)s->data); |
222 const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; | 226 const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; |
223 double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); | 227 double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); |
224 double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i); | 228 double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i); |
225 | 229 |
226 npy_intp Sx = %(x)s->strides[1]/sizeof(double); | 230 npy_intp Sx = %(x)s->strides[1]/sizeof(double); |
227 npy_intp Sb = %(b)s->strides[0]/sizeof(double); | 231 npy_intp Sb = %(b)s->strides[0]/sizeof(double); |
303 dx[i, y_idx[i]] -= dy[i] #scalar decrement | 307 dx[i, y_idx[i]] -= dy[i] #scalar decrement |
304 output_storage[0][0] = dx | 308 output_storage[0][0] = dx |
305 def grad(self, *args): | 309 def grad(self, *args): |
306 raise NotImplementedError() | 310 raise NotImplementedError() |
307 def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): | 311 def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub): |
312 y_idx_type = node.inputs[2].type.dtype_specs()[1] | |
308 return """ | 313 return """ |
309 | 314 |
310 if ((%(dnll)s->descr->type_num != PyArray_DOUBLE) | 315 if ((%(dnll)s->descr->type_num != PyArray_DOUBLE) |
311 || (%(sm)s->descr->type_num != PyArray_DOUBLE) | 316 || (%(sm)s->descr->type_num != PyArray_DOUBLE) |
312 || (%(y_idx)s->descr->type_num != PyArray_INT64)) | 317 ) |
313 { | 318 { |
314 PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64"); | 319 PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64"); |
320 %(fail)s; | |
321 } | |
322 if ((%(y_idx)s->descr->type_num != PyArray_INT64) | |
323 && (%(y_idx)s->descr->type_num != PyArray_INT32) | |
324 && (%(y_idx)s->descr->type_num != PyArray_INT16) | |
325 && (%(y_idx)s->descr->type_num != PyArray_INT8)) | |
326 { | |
327 PyErr_SetString(PyExc_TypeError, "y_idx not int8, int16, int32, or int64"); | |
315 %(fail)s; | 328 %(fail)s; |
316 } | 329 } |
317 if ((%(dnll)s->nd != 1) | 330 if ((%(dnll)s->nd != 1) |
318 || (%(sm)s->nd != 2) | 331 || (%(sm)s->nd != 2) |
319 || (%(y_idx)s->nd != 1)) | 332 || (%(y_idx)s->nd != 1)) |
341 | 354 |
342 for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i) | 355 for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i) |
343 { | 356 { |
344 const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0]; | 357 const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0]; |
345 | 358 |
346 const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; | 359 const %(y_idx_type)s y_i = ((%(y_idx_type)s*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0]; |
347 | 360 |
348 const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); | 361 const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i); |
349 npy_intp Ssm = %(sm)s->strides[1]/sizeof(double); | 362 npy_intp Ssm = %(sm)s->strides[1]/sizeof(double); |
350 | 363 |
351 double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i); | 364 double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i); |