comparison nnet_ops.py @ 181:1b06bc2c3ca9

fixed c_code for the ops in nnet_ops.py
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Tue, 13 May 2008 15:49:39 -0400
parents 2ca8dccba270
children 9a2aecc57a79
comparison
equal deleted inserted replaced
180:2698c0feeb54 181:1b06bc2c3ca9
18 def impl(self, x): 18 def impl(self, x):
19 return ScalarSigmoid.st_impl(x) 19 return ScalarSigmoid.st_impl(x)
20 def grad(self, (x,), (gz,)): 20 def grad(self, (x,), (gz,)):
21 y = scalar_sigmoid(x) 21 y = scalar_sigmoid(x)
22 return [gz * y * (1.0 - y)] 22 return [gz * y * (1.0 - y)]
23 def c_code(self, (x,), (z,), sub): 23 def c_code(self, node, name, (x,), (z,), sub):
24 if 'float' in self.inputs[0].dtype: 24 if node.inputs[0].type in [scalar.float32, scalar.float64]:
25 return """%(z)s = 25 return """%(z)s =
26 %(x)s < -30.0 26 %(x)s < -30.0
27 ? 0.0 27 ? 0.0
28 : %(x)s > 30.0 28 : %(x)s > 30.0
29 ? 1.0 29 ? 1.0
30 : 1.0 /(1.0+exp(-%(x)s));""" % locals() 30 : 1.0 /(1.0+exp(-%(x)s));""" % locals()
31 return NotImplemented#Error('only floatingpoint is implemented') 31 raise NotImplementedError('only floatingpoint is implemented')
32 scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid') 32 scalar_sigmoid = ScalarSigmoid(scalar.upgrade_to_float, name='scalar_sigmoid')
33 sigmoid = tensor.Elemwise(scalar_sigmoid, name='sigmoid') 33 sigmoid = tensor.Elemwise(scalar_sigmoid, name='sigmoid')
34 34
35 class ScalarSoftplus(scalar.UnaryScalarOp): 35 class ScalarSoftplus(scalar.UnaryScalarOp):
36 @staticmethod 36 @staticmethod
42 return numpy.log1p(numpy.exp(x)) 42 return numpy.log1p(numpy.exp(x))
43 def impl(self, x): 43 def impl(self, x):
44 return ScalarSoftplus.static_impl(x) 44 return ScalarSoftplus.static_impl(x)
45 def grad(self, (x,), (gz,)): 45 def grad(self, (x,), (gz,)):
46 return [gz * scalar_sigmoid(x)] 46 return [gz * scalar_sigmoid(x)]
47 def c_code(self, (x,), (z,), sub): 47 def c_code(self, name, node, (x,), (z,), sub):
48 if 'float' in self.inputs[0].dtype: 48 if node.inputs[0].type in [scalar.float32, scalar.float64]:
49 return """%(z)s = 49 return """%(z)s =
50 %(x)s < -30.0 50 %(x)s < -30.0
51 ? 0.0 51 ? 0.0
52 : %(x)s > 30.0 52 : %(x)s > 30.0
53 ? %(x)s 53 ? %(x)s
54 : log1p(exp(%(x)s));""" % locals() 54 : log1p(exp(%(x)s));""" % locals()
55 return NotImplemented#Error('only floating point x is implemented') 55 raise NotImplementedError('only floating point x is implemented')
56 scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus') 56 scalar_softplus = ScalarSoftplus(scalar.upgrade_to_float, name='scalar_softplus')
57 softplus = tensor.Elemwise(scalar_softplus, name='softplus') 57 softplus = tensor.Elemwise(scalar_softplus, name='softplus')
58 58
59 59
60 ############ 60 ############
133 dx = CrossentropySoftmax1HotWithBiasDx()(g_nll, sm, y_idx) 133 dx = CrossentropySoftmax1HotWithBiasDx()(g_nll, sm, y_idx)
134 db = tensor.sum(dx, axis = [0]) 134 db = tensor.sum(dx, axis = [0])
135 return dx, db, None 135 return dx, db, None
136 136
137 def c_headers(self): return ['<iostream>'] 137 def c_headers(self): return ['<iostream>']
138 def c_code(self, (x, b, y_idx), (nll, sm), sub): 138 def c_code(self, node, name, (x, b, y_idx), (nll, sm), sub):
139 # this implementation was lifted from 139 # this implementation was lifted from
140 # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx 140 # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
141 141
142 #TODO: put this into a templated function, in the support code 142 #TODO: put this into a templated function, in the support code
143 #TODO: declare the max of each row as an Op output 143 #TODO: declare the max of each row as an Op output
300 dx[i] = dy[i] * sm[i] #vector scale 300 dx[i] = dy[i] * sm[i] #vector scale
301 dx[i, y_idx[i]] -= dy[i] #scalar decrement 301 dx[i, y_idx[i]] -= dy[i] #scalar decrement
302 output_storage[0][0] = dx 302 output_storage[0][0] = dx
303 def grad(self, *args): 303 def grad(self, *args):
304 raise NotImplementedError() 304 raise NotImplementedError()
305 def c_code(self, (dnll, sm, y_idx), (dx,), sub): 305 def c_code(self, node, name, (dnll, sm, y_idx), (dx,), sub):
306 return """ 306 return """
307 307
308 if ((%(dnll)s->descr->type_num != PyArray_DOUBLE) 308 if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)
309 || (%(sm)s->descr->type_num != PyArray_DOUBLE) 309 || (%(sm)s->descr->type_num != PyArray_DOUBLE)
310 || (%(y_idx)s->descr->type_num != PyArray_INT64)) 310 || (%(y_idx)s->descr->type_num != PyArray_INT64))