Mercurial > pylearn
diff nnet_ops.py @ 34:1b152f46ad0c
consolidated code
author | bergstrj@iro.umontreal.ca |
---|---|
date | Thu, 17 Apr 2008 12:49:48 -0400 |
parents | 039c0f249859 |
children | 810a8e3c85e1 |
line wrap: on
line diff
--- a/nnet_ops.py Thu Apr 17 12:49:33 2008 -0400 +++ b/nnet_ops.py Thu Apr 17 12:49:48 2008 -0400 @@ -71,23 +71,6 @@ db = tensor.Sum(dx, axis = [0]).outputs[0] return dx, db, None - def c_validate_update(self, (x, b, y_idx), (nll, sm), sub): - """Allocate output storage""" - return """ - if (%(x)s->nd != 2) { %(fail)s } - if (%(b)s->nd != 1) { %(fail)s } - if (%(y_idx)s->nd != 1) { %(fail)s } - if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} - if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} - if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s} - - %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s); - if(!%(nll)s){%(fail)s} - - %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s); - if(!%(sm)s){Py_XDECREF(%(nll)s); %(fail)s} - - """ % dict(locals(), **sub) def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub): """Not sure...""" return "" @@ -105,6 +88,23 @@ return """ npy_intp* Nx = %(x)s->dimensions; + + if (%(x)s->nd != 2) { %(fail)s } + if (%(b)s->nd != 1) { %(fail)s } + if (%(y_idx)s->nd != 1) { %(fail)s } + if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} + if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} + if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s} + + %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s); + if(!%(nll)s){%(fail)s} + + %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s); + if(!%(sm)s) { + // The normal cleanup code will take care of %(nll)s + // Py_XDECREF(%(nll)s); %(nll)s=NULL; + %(fail)s + } if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s} if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s} if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}