comparison nnet_ops.py @ 32:039c0f249859

added C impl for softmax dx
author bergstrj@iro.umontreal.ca
date Sun, 13 Apr 2008 00:21:11 -0400
parents bf0145fa73e8
children 1b152f46ad0c
comparison
equal deleted inserted replaced
31:2d6c49ec5749 32:039c0f249859
99 # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx 99 # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
100 100
101 #TODO: put this into a templated function, in the support code 101 #TODO: put this into a templated function, in the support code
102 #TODO: declare the max of each row as an Op output 102 #TODO: declare the max of each row as an Op output
103 103
104 #TODO: set error messages for failures in this code
105
104 return """ 106 return """
105 npy_intp* Nx = %(x)s->dimensions; 107 npy_intp* Nx = %(x)s->dimensions;
106 assert(%(x)s->dimensions[1] == %(b)s->dimensions[0]); 108 if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s}
107 assert(%(sm)s->dimensions[0] == %(x)s->dimensions[0]); 109 if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s}
108 assert(%(sm)s->dimensions[1] == %(x)s->dimensions[1]); 110 if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}
109 111
110 for (size_t i = 0; i < Nx[0]; ++i) 112 for (size_t i = 0; i < Nx[0]; ++i)
111 { 113 {
112 size_t j; 114 size_t j;
113 double sum = 0.0; 115 double sum = 0.0;
147 149
148 double sm_ij = exp(row_ij - row_max); 150 double sm_ij = exp(row_ij - row_max);
149 sum += sm_ij; 151 sum += sm_ij;
150 sm_i[j * Ssm] = sm_ij; 152 sm_i[j * Ssm] = sm_ij;
151 } 153 }
152 assert( (0.0 != sum) && (!isinf(sum))); //that was our best... 154 if ( (0.0 == sum) || (isinf(sum)))
155 {
156 //that was our best...
157 %(fail)s;
158 }
153 //if we still can't sum it up, we're screwed. 159 //if we still can't sum it up, we're screwed.
154 //So far, this assertion has never failed... 160 //So far, this assertion has never failed...
155 } 161 }
156 162
157 //cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n); 163 //cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
159 for (j = 0; j < Nx[1]; ++j) 165 for (j = 0; j < Nx[1]; ++j)
160 { 166 {
161 sm_i[j * Ssm] *= sum_inv; 167 sm_i[j * Ssm] *= sum_inv;
162 } 168 }
163 169
164 assert(y_i < Nx[1]); 170 if (y_i >= Nx[1])
171 {
172 %(fail)s;
173 }
165 174
166 nll_i[0] = - x_i[y_i*Sx] 175 nll_i[0] = - x_i[y_i*Sx]
167 - b_i[y_i*Sb] 176 - b_i[y_i*Sb]
168 + (discount_max ? row_max : 0.0) 177 + (discount_max ? row_max : 0.0)
169 + log(sum); 178 + log(sum);
193 dx[i] = dy[i] * sm[i] #vector scale 202 dx[i] = dy[i] * sm[i] #vector scale
194 dx[i, y_idx[i]] -= dy[i] #scalar decrement 203 dx[i, y_idx[i]] -= dy[i] #scalar decrement
195 self.outputs[0].data = dx 204 self.outputs[0].data = dx
196 def grad(self, *args): 205 def grad(self, *args):
197 raise NotImplementedError() 206 raise NotImplementedError()
198 207 def c_validate_update(self, (dnll, sm, y_idx), (dx,), sub):
199 #TODO: write a version of CrossentropySoftmax1Hot that accepts a bias for x, if 208 """Allocate output storage"""
200 # this op needs to be faster. 209 return """
201 210 if (%(dnll)s->nd != 1) { %(fail)s }
211 if (%(sm)s->nd != 2) { %(fail)s }
212 if (%(y_idx)s->nd != 1) { %(fail)s }
213 if (%(dnll)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
214 if (%(sm)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
215 if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
216
217 %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s);
218 if(!%(dx)s){%(fail)s}
219
220 """ % dict(locals(), **sub)
221 def c_validate_cleanup(self, inputs, outputs, sub):
222 """Not sure..."""
223 return ""
224 def c_support_code(self):
225 return """
226 """
227 def c_code(self, (dnll, sm, y_idx), (dx,), sub):
228 return """
229 npy_intp* shape = %(dx)s->dimensions;
230 if (%(dnll)s->dimensions[0] != %(sm)s->dimensions[0]) {%(fail)s}
231 if (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]) {%(fail)s}
232 if (%(dnll)s->dimensions[0] != %(dx)s->dimensions[0]) {%(fail)s}
233
234 if (%(sm)s->dimensions[1] != %(dx)s->dimensions[1]) {%(fail)s}
235
236 for (size_t i = 0; i < shape[0]; ++i)
237 {
238 const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0];
239
240 const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
241
242 const double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
243 npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);
244
245 double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i);
246 npy_intp Sdx = %(dx)s->strides[1]/sizeof(double);
247
248 for (size_t j = 0; j < shape[1]; ++j)
249 {
250 dx_i[j * Sdx] = dnll_i * sm_i[j * Ssm];
251 }
252 if (y_i >= shape[1])
253 {
254 %(fail)s;
255 }
256 dx_i[y_i * Sdx] -= dnll_i;
257 }
258 """ % dict(locals(), **sub)