comparison nnet_ops.py @ 34:1b152f46ad0c

consolidated code
author bergstrj@iro.umontreal.ca
date Thu, 17 Apr 2008 12:49:48 -0400
parents 039c0f249859
children 810a8e3c85e1
comparison
equal deleted inserted replaced
33:bb92087cb0f6 34:1b152f46ad0c
69 nll, sm = crossentropy_softmax_1hot(x, b, y_idx) 69 nll, sm = crossentropy_softmax_1hot(x, b, y_idx)
70 dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0] 70 dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0]
71 db = tensor.Sum(dx, axis = [0]).outputs[0] 71 db = tensor.Sum(dx, axis = [0]).outputs[0]
72 return dx, db, None 72 return dx, db, None
73 73
74 def c_validate_update(self, (x, b, y_idx), (nll, sm), sub): 74 def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
75 """Allocate output storage""" 75 """Not sure..."""
76 return """ 76 return ""
77 def c_support_code(self):
78 return """
79 """
80 def c_code(self, (x, b, y_idx), (nll, sm), sub):
81 # this implementation was lifted from
82 # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
83
84 #TODO: put this into a templated function, in the support code
85 #TODO: declare the max of each row as an Op output
86
87 #TODO: set error messages for failures in this code
88
89 return """
90 npy_intp* Nx = %(x)s->dimensions;
91
77 if (%(x)s->nd != 2) { %(fail)s } 92 if (%(x)s->nd != 2) { %(fail)s }
78 if (%(b)s->nd != 1) { %(fail)s } 93 if (%(b)s->nd != 1) { %(fail)s }
79 if (%(y_idx)s->nd != 1) { %(fail)s } 94 if (%(y_idx)s->nd != 1) { %(fail)s }
80 if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} 95 if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
81 if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s} 96 if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
83 98
84 %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s); 99 %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
85 if(!%(nll)s){%(fail)s} 100 if(!%(nll)s){%(fail)s}
86 101
87 %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s); 102 %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
88 if(!%(sm)s){Py_XDECREF(%(nll)s); %(fail)s} 103 if(!%(sm)s) {
89 104 // The normal cleanup code will take care of %(nll)s
90 """ % dict(locals(), **sub) 105 // Py_XDECREF(%(nll)s); %(nll)s=NULL;
91 def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub): 106 %(fail)s
92 """Not sure...""" 107 }
93 return ""
94 def c_support_code(self):
95 return """
96 """
97 def c_code(self, (x, b, y_idx), (nll, sm), sub):
98 # this implementation was lifted from
99 # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx
100
101 #TODO: put this into a templated function, in the support code
102 #TODO: declare the max of each row as an Op output
103
104 #TODO: set error messages for failures in this code
105
106 return """
107 npy_intp* Nx = %(x)s->dimensions;
108 if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s} 108 if (%(x)s->dimensions[1] != %(b)s->dimensions[0]) {%(fail)s}
109 if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s} 109 if (%(sm)s->dimensions[0] != %(x)s->dimensions[0]) {%(fail)s}
110 if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s} 110 if (%(sm)s->dimensions[1] != %(x)s->dimensions[1]) {%(fail)s}
111 111
112 for (size_t i = 0; i < Nx[0]; ++i) 112 for (size_t i = 0; i < Nx[0]; ++i)