# HG changeset patch
# User James Bergstra
# Date 1209766775 14400
# Node ID 5b699b31770a5e8153b93e2b882dbb54480493b8
# Parent  76e5c0f371651b9053c7389b532402d1d8c8dcd0
# Parent  dde1fb1b63ba516c2d48768603ab140f63c90049
merge

diff -r dde1fb1b63ba -r 5b699b31770a _nnet_ops.py
--- a/_nnet_ops.py	Fri May 02 11:24:17 2008 -0400
+++ b/_nnet_ops.py	Fri May 02 18:19:35 2008 -0400
@@ -11,6 +11,11 @@
     def test_elemwise(self):
         TT.verify_grad(self, Sigmoid, [numpy.random.rand(3,4)])
 
+class T_softplus(unittest.TestCase):
+    def setUp(self):
+        numpy.random.seed(9999)
+    def test_elemwise(self):
+        TT.verify_grad(self, Softplus, [numpy.random.rand(3,4)])
 
 class T_CrossentropySoftmax1Hot(unittest.TestCase):
     def setUp(self):
@@ -18,10 +23,16 @@
     def test0(self):
         y_idx = [0,1,3]
         def output1(a,b):
-            return crossentropy_softmax_1hot(a, b, y_idx)[0:1]
+            return crossentropy_softmax_1hot_with_bias(a, b, y_idx)[0:1]
         TT.verify_grad(self, output1, [numpy.random.rand(3,4),
             numpy.random.rand(4)])
 
+    def test1(self):
+        y_idx = [0,1,3]
+        def output1(a):
+            return crossentropy_softmax_1hot(a, y_idx)[0:1]
+        TT.verify_grad(self, output1, [numpy.random.rand(3,4)])
+
 
 
 if __name__ == '__main__':
diff -r dde1fb1b63ba -r 5b699b31770a nnet_ops.py
--- a/nnet_ops.py	Fri May 02 11:24:17 2008 -0400
+++ b/nnet_ops.py	Fri May 02 18:19:35 2008 -0400
@@ -2,32 +2,93 @@
 from theano import tensor, gof, scalar
 import numpy
 
-class ScalarSigmoid(scalar.UnaryScalarOp):
+############
+#
+# SCALAR OPS
+#
+
+class ScalarSigmoid(scalar.FloatUnaryScalarOp):
+    @staticmethod
+    def st_impl(x):
+        if x < -30.0:
+            return 0.0
+        if x > 30.0:
+            return 1.0
+        return 1.0 / (1.0 + numpy.exp(-x))
     def impl(self, x):
-        return 1.0 / (1 + numpy.exp(-x))
+        return ScalarSigmoid.st_impl(x)
     def grad(self, (x,), (gz,)):
-        return gz * scalar_sigmoid(x) * (1.0 - scalar_sigmoid(x)),
-    def c_foreach(self, (x,), (z,)):
-        return "%(z)s = 1.0 / (1 + exp(-%(x)s));" % locals()
+        y = scalar_sigmoid(x)
+        return [gz * y * (1.0 - y)]
+    def c_foreach(self, (x,), (z,), sub):
+        if 'float' in self.inputs[0].dtype:
+            return """%(z)s =
+                %(x)s < -30.0
+                ? 0.0
+                : %(x)s > 30.0
+                ? 1.0
+                : 1.0 /(1.0+exp(-%(x)s));""" % locals()
+        raise NotImplementedError('only floating point is implemented')
 scalar_sigmoid = gof.op.constructor(ScalarSigmoid)
-Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace \
-    = theano.tensor.broadcast(ScalarSigmoid, 'Sigmoid')
+Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace =\
+    tensor.broadcast(ScalarSigmoid, 'Sigmoid')
 
+class ScalarSoftplus(scalar.FloatUnaryScalarOp):
+    @staticmethod
+    def static_impl(x):
+        if x < -30.0:
+            return 0.0
+        if x > 30.0:
+            return x
+        return numpy.log1p(numpy.exp(x))
+    def impl(self, x):
+        return ScalarSoftplus.static_impl(x)
+    def grad(self, (x,), (gz,)):
+        return [gz * scalar_sigmoid(x)]
+    def c_foreach(self, (x,), (z,), sub):
+        if 'float' in self.inputs[0].dtype:
+            return """%(z)s =
+                %(x)s < -30.0
+                ? 0.0
+                : %(x)s > 30.0
+                ? %(x)s
+                : log1p(exp(%(x)s));""" % locals()
+        raise NotImplementedError('only floating point x is implemented')
+scalar_softplus = gof.op.constructor(ScalarSoftplus)
+Softplus, softplus, SoftplusInplace, softplus_inplace =\
+    tensor.broadcast(ScalarSoftplus, 'Softplus')
 
-class CrossentropySoftmax1Hot(gof.op.Op):
-    """A special compound Op for the output of neural-net classifiers.
+############
+#
+# TENSOR OPS
+#
+
+
+class CrossentropySoftmax1HotWithBias(gof.op.Op):
+    """A special compound L{Op} for the output of neural-net classifiers.
+
+    @type x: a matrix of floats (32 or 64)
+    @type b: a [row] vector of floats (32 or 64), length is number of cols in x
+    @type y_idx: a [column] vector of ints (32 or 64), length is number of rows in x
 
-    This Op has two outputs:
-     - KL(softmax(x), y)
-     - softmax(x)
+    @precondition: every entry in y_idx is a valid (non-negative) column index into x
+
+    This L{Op} has two outputs:
+     - KL(softmax(x+b), y)
+     - softmax(x+b)
 
-    x[i] is assumed to be a dense vector
+    softmax(x[i]) is the i'th distribution over len(x[i]) options
-    y[i] is an integer index, encoding a 1-hot distribution
+
+    y_idx[i] is an integer index, encoding a 1-hot distribution.
+
+    In practice, when we're trying to do classification, we have one row in x
+    and y_idx per example, and y_idx[i] is the index of the (correct) class of
+    the i'th example.
 
     """
-    nin=2
+    nin=3
     nout=2
     def __init__(self, x, b, y_idx, **kwargs):
         x = tensor._as_tensor(x)
@@ -52,7 +113,9 @@
     def perform(self):
         x, b, y_idx = [i.data for i in self.inputs]
         if b.shape[0] != x.shape[1]:
-            raise ValueError('b must have same shape as x[0]')
+            raise ValueError('b must have same number of columns as x')
+        if y_idx.shape[0] != x.shape[0]:
+            raise ValueError('y_idx must have same number of rows as x')
 
         sm = numpy.zeros_like(x) # softmax
         nll = numpy.zeros(x.shape[0]) #nll(y | softmax(x))
@@ -66,17 +129,12 @@
     def grad(self, (x, b, y_idx), (g_nll, g_sm)):
         if g_sm is not None:
             raise NotImplementedError()
-        nll, sm = crossentropy_softmax_1hot(x, b, y_idx)
-        dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0]
+        nll, sm = crossentropy_softmax_1hot_with_bias(x, b, y_idx)
+        dx = CrossentropySoftmax1HotWithBiasDx(g_nll, sm, y_idx).outputs[0]
         db = tensor.Sum(dx, axis = [0]).outputs[0]
         return dx, db, None
-    def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
-        """Not sure..."""
-        return ""
-    def c_support_code(self):
-        return """
-        """
+    def c_headers(self): return ['']
     def c_code(self, (x, b, y_idx), (nll, sm), sub):
         # this implementation was lifted from
         # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
@@ -89,25 +147,67 @@
         return """
         npy_intp* Nx = %(x)s->dimensions;
 
-        if (%(x)s->nd != 2) { %(fail)s }
-        if (%(b)s->nd != 1) { %(fail)s }
-        if (%(y_idx)s->nd != 1) { %(fail)s }
-        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
-
-        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
-        if(!%(nll)s){%(fail)s}
+        if (%(x)s->nd != 2)
+        {
+            PyErr_SetString(PyExc_ValueError, "x not 2d tensor");
+            %(fail)s;
+        }
+        if (%(b)s->nd != 1)
+        {
+            PyErr_SetString(PyExc_ValueError, "b not 1d tensor");
+            %(fail)s;
+        }
+        if (%(y_idx)s->nd != 1)
+        {
+            PyErr_SetString(PyExc_ValueError, "y_idx not 1d tensor");
+            %(fail)s;
+        }
+        if (%(x)s->descr->type_num != PyArray_DOUBLE)
+        {
+            PyErr_SetString(PyExc_TypeError, "x not float64");
+            %(fail)s;
+        }
+        if (%(b)s->descr->type_num != PyArray_DOUBLE)
+        {
+            PyErr_SetString(PyExc_TypeError, "b not float64");
+            %(fail)s;
+        }
+        if (%(y_idx)s->descr->type_num != PyArray_INT64)
+        {
+            PyErr_SetString(PyExc_TypeError, "y_idx not int64");
+            %(fail)s;
+        }
+        if ((%(x)s->dimensions[1] != %(b)s->dimensions[0])
+            || (%(x)s->dimensions[0] != %(y_idx)s->dimensions[0]))
+        {
+            PyErr_SetString(PyExc_ValueError, "dimension mismatch in arguments");
+            %(fail)s;
+        }
 
-        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
-        if(!%(sm)s) {
-            // The normal cleanup code will take care of %(nll)s
-            // Py_XDECREF(%(nll)s); %(nll)s=NULL;
-            %(fail)s
+        if ((NULL == %(nll)s) //initial condition
+            || (%(nll)s->dimensions[0] != %(y_idx)s->dimensions[0]))
+        {
+            if (NULL != %(nll)s) Py_XDECREF(%(nll)s);
+            %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
+            if(!%(nll)s)
+            {
+                PyErr_SetString(PyExc_MemoryError, "failed to alloc nll output");
+                %(fail)s;
+            }
         }
-        if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s}
-        if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s}
-        if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}
+        if ((NULL == %(sm)s)
+            || (%(sm)s->dimensions[0] != %(x)s->dimensions[0])
+            || (%(sm)s->dimensions[1] != %(x)s->dimensions[1]))
+        {
+            if (NULL != %(sm)s) Py_XDECREF(%(sm)s);
+            %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
+            if(!%(sm)s) {
+                // The normal cleanup code will take care of %(nll)s
+                // Py_XDECREF(%(nll)s); %(nll)s=NULL;
+                PyErr_SetString(PyExc_MemoryError, "failed to alloc sm output");
+                %(fail)s
+            }
+        }
 
         for (size_t i = 0; i < Nx[0]; ++i)
         {
@@ -181,11 +281,10 @@
         }
         """ % dict(locals(), **sub)
 
-
+crossentropy_softmax_1hot_with_bias = \
+    gof.op.constructor(CrossentropySoftmax1HotWithBias)
 
-crossentropy_softmax_1hot = gof.op.constructor(CrossentropySoftmax1Hot)
-
-class CrossentropySoftmax1HotDx (gof.op.Op):
+class CrossentropySoftmax1HotWithBiasDx (gof.op.Op):
     nin=3
     nout=1
     """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
@@ -204,36 +303,42 @@
         self.outputs[0].data = dx
     def grad(self, *args):
         raise NotImplementedError()
-    def c_validate_update(self, (dnll, sm, y_idx), (dx,), sub):
-        """Allocate output storage"""
-        return """
-        if (%(dnll)s->nd != 1) { %(fail)s }
-        if (%(sm)s->nd != 2) { %(fail)s }
-        if (%(y_idx)s->nd != 1) { %(fail)s }
-        if (%(dnll)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(sm)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
-        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}
-
-        %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s);
-        if(!%(dx)s){%(fail)s}
-
-        """ % dict(locals(), **sub)
-    def c_validate_cleanup(self, inputs, outputs, sub):
-        """Not sure..."""
-        return ""
-    def c_support_code(self):
-        return """
-        """
     def c_code(self, (dnll, sm, y_idx), (dx,), sub):
         return """
-        npy_intp* shape = %(dx)s->dimensions;
-        if (%(dnll)s->dimensions[0] != %(sm)s->dimensions[0]) {%(fail)s}
-        if (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]) {%(fail)s}
-        if (%(dnll)s->dimensions[0] != %(dx)s->dimensions[0]) {%(fail)s}
-        if (%(sm)s->dimensions[1] != %(dx)s->dimensions[1]) {%(fail)s}
+        if ((%(dnll)s->descr->type_num != PyArray_DOUBLE)
+            || (%(sm)s->descr->type_num != PyArray_DOUBLE)
+            || (%(y_idx)s->descr->type_num != PyArray_INT64))
+        {
+            PyErr_SetString(PyExc_TypeError, "types should be float64, float64, int64");
+            %(fail)s;
+        }
+        if ((%(dnll)s->nd != 1)
+            || (%(sm)s->nd != 2)
+            || (%(y_idx)s->nd != 1))
+        {
+            PyErr_SetString(PyExc_ValueError, "rank error");
+            %(fail)s;
+        }
+        if ((%(dnll)s->dimensions[0] != %(sm)s->dimensions[0])
+            || (%(dnll)s->dimensions[0] != %(y_idx)s->dimensions[0]))
+        {
+            PyErr_SetString(PyExc_ValueError, "dimension mismatch");
+            %(fail)s;
+        }
+        if ((NULL == %(dx)s)
+            || (%(dx)s->dimensions[0] != %(sm)s->dimensions[0])
+            || (%(dx)s->dimensions[1] != %(sm)s->dimensions[1]))
+        {
+            if (NULL != %(dx)s) Py_XDECREF(%(dx)s);
+            %(dx)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(sm)s), type_num_%(sm)s);
+            if(!%(dx)s) {
+                PyErr_SetString(PyExc_MemoryError, "failed to alloc dx output");
+                %(fail)s
+            }
+        }
 
-        for (size_t i = 0; i < shape[0]; ++i)
+        for (size_t i = 0; i < %(dx)s->dimensions[0]; ++i)
         {
            const double dnll_i = ((double*)(%(dnll)s->data + %(dnll)s->strides[0] * i))[0];
@@ -245,14 +350,19 @@
             double* __restrict__ dx_i = (double*)(%(dx)s->data + %(dx)s->strides[0] * i);
             npy_intp Sdx = %(dx)s->strides[1]/sizeof(double);
 
-            for (size_t j = 0; j < shape[1]; ++j)
+            for (size_t j = 0; j < %(dx)s->dimensions[1]; ++j)
             {
                 dx_i[j * Sdx] = dnll_i * sm_i[j * Ssm];
             }
-            if (y_i >= shape[1])
+            if (y_i >= %(dx)s->dimensions[1])
             {
                 %(fail)s;
            }
             dx_i[y_i * Sdx] -= dnll_i;
         }
         """ % dict(locals(), **sub)
+
+def crossentropy_softmax_1hot(x, y_idx, **kwargs):
+    b = tensor.zeros_like(x[0,:])
+    return crossentropy_softmax_1hot_with_bias(x, b, y_idx, **kwargs)
+
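
A note on the softplus gradient used above: ScalarSoftplus.grad returns
gz * scalar_sigmoid(x), relying on the identity d/dx log(1 + exp(x)) = sigmoid(x).
The following standalone numpy check (an illustration, not part of the changeset)
confirms the identity numerically with a central finite difference:

    import numpy

    # d/dx log(1 + exp(x)) = exp(x)/(1 + exp(x)) = 1/(1 + exp(-x)) = sigmoid(x)
    x = numpy.linspace(-5.0, 5.0, 11)
    eps = 1e-6
    numeric = (numpy.log1p(numpy.exp(x + eps))
               - numpy.log1p(numpy.exp(x - eps))) / (2 * eps)
    analytic = 1.0 / (1.0 + numpy.exp(-x))
    assert numpy.allclose(numeric, analytic)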
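
For readers following the C implementation of CrossentropySoftmax1HotWithBias,
this pure-numpy sketch (the _ref name is made up for illustration; it is not part
of the changeset) spells out what the op's docstring promises: the row-wise softmax
of x+b, and the negative log-likelihood of each row's 1-hot target, which is what
KL(softmax(x+b), y) reduces to when y is 1-hot:

    import numpy

    def crossentropy_softmax_1hot_with_bias_ref(x, b, y_idx):
        a = x + b                          # bias broadcasts across rows
        a = a - a.max(axis=1)[:, None]     # max-shift for numerical stability
        e = numpy.exp(a)
        sm = e / e.sum(axis=1)[:, None]    # sm[i] is a distribution over classes
        # with a 1-hot target, KL(sm[i], y[i]) reduces to -log sm[i, y_idx[i]]
        nll = -numpy.log(sm[numpy.arange(x.shape[0]), y_idx])
        return nll, sm

The bias-free crossentropy_softmax_1hot(x, y_idx) defined at the end of the patch
is then simply the same computation with b = 0.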