comparison nnet_ops.py @ 30:bf0145fa73e8

added c implementation for CrossentropySoftmax1Hot

author   bergstrj@iro.umontreal.ca
date     Fri, 11 Apr 2008 21:41:09 -0400
parents  b63e8c0bf21b
children 039c0f249859

comparing 27:e6c550cb2896 with 30:bf0145fa73e8
     y[i] is an integer index, encoding a 1-hot distribution

     """
-    nin=2
+    nin=3
     nout=2
-    def __init__(self, x, y_idx, **kwargs):
+    def __init__(self, x, b, y_idx, **kwargs):
         x = tensor._as_tensor(x)
+        b = tensor._as_tensor(b)
         y_idx = tensor._as_tensor(y_idx)
+        if len(x.broadcastable) != 2 \
+                or x.dtype not in ['float32', 'float64']:
+            raise ValueError('x must be a 2-d tensor of floats')
+        if len(b.broadcastable) != 1 \
+                or b.dtype not in ['float32', 'float64']:
+            raise ValueError('b must be a 1-d tensor of floats')
+        if len(y_idx.broadcastable) != 1 \
+                or y_idx.dtype not in ['int32', 'int64']:
+            raise ValueError('y_idx must be a 1-d tensor of ints')
+
         # TODO: Is this correct? It used to be y, not y_idx
         nll = tensor.Tensor(x.dtype, y_idx.broadcastable)
         # nll = Tensor(x.dtype, y.broadcastable)
         sm = tensor.Tensor(x.dtype, x.broadcastable)
-        self.inputs = [x, y_idx]
-        self.outputs = [nll, sm]
+        self.inputs = [x, b, y_idx]
+        self.outputs = [nll, sm]
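In plain numpy terms, the new checks pin the signature down to x: (batch, n_classes) floats, b: (n_classes,) floats, y_idx: (batch,) ints. A minimal illustration of the contract (variable names are ours, not part of the changeset):

    import numpy

    batch, n_classes = 4, 10
    x = numpy.zeros((batch, n_classes), dtype='float64')  # 2-d float: passes
    b = numpy.zeros(n_classes, dtype='float64')           # 1-d float: passes
    y_idx = numpy.zeros(batch, dtype='int64')             # 1-d int: passes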
     def perform(self):
-        x, y_idx = [i.data for i in self.inputs]
+        x, b, y_idx = [i.data for i in self.inputs]
+        if b.shape[0] != x.shape[1]:
+            raise ValueError('b must have same shape as x[0]')
+
         sm = numpy.zeros_like(x) # softmax
         nll = numpy.zeros(x.shape[0]) # nll(y | softmax(x))
         for i in xrange(sm.shape[0]):
-            sm[i] = numpy.exp(x[i] - numpy.max(x[i])) # softmax
+            row = x[i] + b
+            sm[i] = numpy.exp(row - numpy.max(row)) # softmax
             sm[i] *= 1.0 / numpy.sum(sm[i]) # vector scale
             nll[i] = -numpy.log(sm[i, y_idx[i]]) # cross-entropy
         self.outputs[0].data = nll
         self.outputs[1].data = sm
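The max subtraction in perform is the standard guard against overflow in exp: since exp(row - max(row)) peaks at exp(0) = 1, the row sum stays finite. A self-contained numpy sketch of the same per-row computation (not part of the changeset):

    import numpy

    def stable_softmax_row(row):
        e = numpy.exp(row - numpy.max(row))  # largest term is exp(0) == 1
        return e / numpy.sum(e)

    print(stable_softmax_row(numpy.array([1000.0, 1001.0, 1002.0])))
    # -> [ 0.09003057  0.24472847  0.66524096 ]; naive exp() overflows here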
-    def grad(self, (x, y_idx), (g_nll, g_sm)):
+    def grad(self, (x, b, y_idx), (g_nll, g_sm)):
         if g_sm is not None:
             raise NotImplementedError()
-        nll, sm = crossentropy_softmax_1hot(x, y_idx)
-        dx = CrossentropySoftmax1Hot.Dx(g_nll, sm, y_idx).outputs[0]
-        return dx, None
+        nll, sm = crossentropy_softmax_1hot(x, b, y_idx)
+        dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0]
+        db = tensor.Sum(dx, axis=[0]).outputs[0]
+        return dx, db, None
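The (dx, db) pair follows from the standard identity: for nll_i = -log(softmax(x_i + b)[y_i]), the gradient with respect to the logits is the softmax minus the 1-hot target, scaled by the incoming g_nll, and db is that summed over the batch (hence tensor.Sum over axis 0). A vectorized numpy sketch of what CrossentropySoftmax1HotDx computes row by row (illustrative, not the Op itself):

    import numpy

    def dx_from_softmax(g_nll, sm, y_idx):
        dx = g_nll[:, None] * sm                      # vector scale per row
        dx[numpy.arange(len(y_idx)), y_idx] -= g_nll  # scalar decrement at the target
        return dx                                     # db would be dx.sum(axis=0)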
 
-    class Dx (gof.op.Op):
-        nin=3
-        nout=1
-        """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
-        def __init__(self, dy, sm, y_idx, **kwargs):
-            dy = tensor._as_tensor(dy)
-            sm = tensor._as_tensor(sm)
-            y_idx = tensor._as_tensor(y_idx)
-            self.inputs = [dy, sm, y_idx]
-            self.outputs = [tensor.Tensor(sm.dtype, sm.broadcastable)]
-        def perform(self):
-            dy, sm, y_idx = [i.data for i in self.inputs]
-            dx = numpy.zeros_like(sm)
-            for i in xrange(sm.shape[0]):
-                dx[i] = dy[i] * sm[i] # vector scale
-                dx[i, y_idx[i]] -= dy[i] # scalar decrement
-            self.outputs[0].data = dx
-        def grad(self, *args):
-            raise NotImplementedError()
+    def c_validate_update(self, (x, b, y_idx), (nll, sm), sub):
+        """Allocate output storage"""
+        return """
+        if (%(x)s->nd != 2) { %(fail)s }
+        if (%(b)s->nd != 1) { %(fail)s }
+        if (%(y_idx)s->nd != 1) { %(fail)s }
+        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s }
+        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s }
+        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s }
+
+        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
+        if (!%(nll)s) { %(fail)s }
+
+        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
+        if (!%(sm)s) { Py_XDECREF(%(nll)s); %(fail)s }
+        """ % dict(locals(), **sub)
+    def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
+        """Not sure..."""
+        return ""
+    def c_support_code(self):
+        return """
+        """
+    def c_code(self, (x, b, y_idx), (nll, sm), sub):
+        # this implementation was lifted from
+        # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
+
+        # TODO: put this into a templated function, in the support code
+        # TODO: declare the max of each row as an Op output
+
+        return """
+        npy_intp* Nx = %(x)s->dimensions;
+        assert(%(x)s->dimensions[1] == %(b)s->dimensions[0]);
+        assert(%(sm)s->dimensions[0] == %(x)s->dimensions[0]);
+        assert(%(sm)s->dimensions[1] == %(x)s->dimensions[1]);
+
+        for (size_t i = 0; i < Nx[0]; ++i)
+        {
+            size_t j;
+            double sum = 0.0;
+            bool discount_max = false;
+
+            const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i);
+            const double* __restrict__ b_i = (double*)(%(b)s->data);
+            const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
+            double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
+            double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i);
+
+            npy_intp Sx = %(x)s->strides[1] / sizeof(double);
+            npy_intp Sb = %(b)s->strides[0] / sizeof(double);
+            npy_intp Ssm = %(sm)s->strides[1] / sizeof(double);
+
+            size_t row_max_j = 0;
+            double row_max = x_i[0] + b_i[0];
+            // try to compute sum and sm the easy way
+            for (j = 0; j < Nx[1]; ++j)
+            {
+                double row_ij = x_i[j * Sx] + b_i[j * Sb];
+                row_max_j = (row_ij > row_max) ? j : row_max_j;
+                row_max = (row_ij > row_max) ? row_ij : row_max;
+
+                double sm_ij = exp(row_ij);
+                sum += sm_ij;
+                sm_i[j * Ssm] = sm_ij;
+            }
+            if ((0.0 == sum) || (isinf(sum)))
+            {
+                // our cheap trick didn't work... try again and do it better.
+                discount_max = true;
+                sum = 0.0; // reset sum and recompute....
+                for (j = 0; j < Nx[1]; ++j)
+                {
+                    double row_ij = x_i[j * Sx] + b_i[j * Sb];
+
+                    double sm_ij = exp(row_ij - row_max);
+                    sum += sm_ij;
+                    sm_i[j * Ssm] = sm_ij;
+                }
+                assert((0.0 != sum) && (!isinf(sum))); // that was our best...
+                // if we still can't sum it up, we're screwed.
+                // So far, this assertion has never failed...
+            }
+
+            // cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
+            double sum_inv = 1.0 / sum;
+            for (j = 0; j < Nx[1]; ++j)
+            {
+                sm_i[j * Ssm] *= sum_inv;
+            }
+
+            assert(y_i < Nx[1]);
+
+            nll_i[0] = - x_i[y_i * Sx]
+                       - b_i[y_i * Sb]
+                       + (discount_max ? row_max : 0.0)
+                       + log(sum);
+            // mat_at(y,i,0) = -log(mat_at(s,i,t[i])); // less accurate?
+            // mat_at(y,i,0) = -mat_at(x,i,t[i]) - mat_at(b,0,t[i]) + (discount_max ? maxi : 0.0) + log(sum);
+        }
+        """ % dict(locals(), **sub)
+
+
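The generated C loop is a two-pass variant of the same softmax: it first tries exp(row) with no max subtraction (one pass over the row, and usually enough), and only when the sum underflows to 0 or overflows to inf does it redo the row as exp(row - row_max). The nll is then computed directly as log(sum) minus the selected logit (plus row_max when it was discounted), avoiding a log of the already-rounded sm value. The control flow, sketched in numpy (an illustration of the idea, not the generated code):

    import numpy

    def softmax_row_two_pass(row):
        e = numpy.exp(row)                       # cheap first attempt
        s = numpy.sum(e)
        if s == 0.0 or numpy.isinf(s):           # the cheap trick failed
            e = numpy.exp(row - numpy.max(row))  # numerically safe retry
            s = numpy.sum(e)
        return e / s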
 crossentropy_softmax_1hot = gof.op.constructor(CrossentropySoftmax1Hot)
+
+class CrossentropySoftmax1HotDx(gof.op.Op):
+    nin=3
+    nout=1
+    """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
+    def __init__(self, dy, sm, y_idx, **kwargs):
+        dy = tensor._as_tensor(dy)
+        sm = tensor._as_tensor(sm)
+        y_idx = tensor._as_tensor(y_idx)
+        self.inputs = [dy, sm, y_idx]
+        self.outputs = [tensor.Tensor(sm.dtype, sm.broadcastable)]
+    def perform(self):
+        dy, sm, y_idx = [i.data for i in self.inputs]
+        dx = numpy.zeros_like(sm)
+        for i in xrange(sm.shape[0]):
+            dx[i] = dy[i] * sm[i] # vector scale
+            dx[i, y_idx[i]] -= dy[i] # scalar decrement
+        self.outputs[0].data = dx
+    def grad(self, *args):
+        raise NotImplementedError()
 
 #TODO: write a version of CrossentropySoftmax1Hot that accepts a bias for x, if
 #      this op needs to be faster.
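A quick way to sanity-check the gradient formula in CrossentropySoftmax1HotDx against the forward pass is a finite-difference probe. A standalone numpy sketch (fwd re-implements perform just for this check; all names are ours):

    import numpy

    def fwd(x, b, y_idx):
        sm = numpy.zeros_like(x)
        for i in range(x.shape[0]):
            row = x[i] + b
            sm[i] = numpy.exp(row - numpy.max(row))
            sm[i] /= numpy.sum(sm[i])
        return -numpy.log(sm[numpy.arange(len(y_idx)), y_idx]), sm

    x = numpy.random.randn(3, 5); b = numpy.random.randn(5)
    y = numpy.array([0, 4, 1]); eps = 1e-6
    nll, sm = fwd(x, b, y)
    dx = sm.copy(); dx[numpy.arange(3), y] -= 1.0  # gradient with g_nll == 1
    x[0, 2] += eps                                 # perturb one logit
    approx = (fwd(x, b, y)[0][0] - nll[0]) / eps
    assert abs(approx - dx[0, 2]) < 1e-4           # matches the analytic gradient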