Mercurial > pylearn
changeset 443:060c12314734
Hopefully last bugfix in Softmax
author | Pascal Lamblin <lamblinp@iro.umontreal.ca> |
---|---|
date | Fri, 22 Aug 2008 17:33:06 -0400 |
parents | b3315b252824 |
children | 9cfc2fc0f4d1 |
files | nnet_ops.py |
diffstat | 1 files changed, 10 insertions(+), 6 deletions(-) [+] |
line wrap: on
line diff
--- a/nnet_ops.py Fri Aug 22 15:53:34 2008 -0400 +++ b/nnet_ops.py Fri Aug 22 17:33:06 2008 -0400 @@ -105,7 +105,7 @@ row = x[i] + b sm[i] = numpy.exp(row - numpy.max(row)) sm[i] *= 1.0 / numpy.sum(sm[i]) - output_storage[0][0] = nll + output_storage[0][0] = sm def grad(self, (x, b), (g_sm,)): sm = softmax_with_bias(x, b) @@ -257,7 +257,7 @@ #dx[i,j] = - (\sum_k dy[i,k] sm[i,k]) sm[i,j] + dy[i,j] sm[i,j] for i in xrange(sm.shape[0]): dy_times_sm_i = dy[i] * sm[i] - dx[i] = dy_times_sm - sum(dy_times_sm_i) * y[i] + dx[i] = dy_times_sm_i - sum(dy_times_sm_i) * sm[i] output_storage[0][0] = dx def grad(self, *args): @@ -318,6 +318,10 @@ } ''' % dict(locals(), **sub) +def softmax(x, **kwargs): + b = tensor.zeros_like(x[0,:]) + return softmax_with_bias(x, b, **kwargs) + class CrossentropySoftmax1HotWithBias(theano.Op): """A special compound L{Op} for the output of neural-net classifiers. @@ -576,7 +580,7 @@ y = tensor.as_tensor(self.val) if x.type.dtype != y.type.dtype: TypeError("the value to prepend don't have the same type as the matrix") - + node = theano.Apply(op=self, inputs=[mat], outputs=[tensor.matrix()]) return node @@ -599,7 +603,7 @@ def grad(self, (mat,), (goutput,)): return goutput[:,1:] -class Prepend_scalar_to_each_row(theano.Op): +class Prepend_scalar_to_each_row(theano.Op): def make_node(self, val, mat): #check type of input if isinstance(val, float): @@ -610,7 +614,7 @@ y = tensor.as_tensor(val) if x.type.dtype != y.type.dtype: TypeError("the value to prepend don't have the same type as the matrix") - + node = theano.Apply(op=self, inputs=[val,mat], outputs=[tensor.matrix()]) return node @@ -659,4 +663,4 @@ def grad(self, (theta, A, B), (gtheta,)): raise NotImplementedError() - +