# nnet_ops.py @ 30:bf0145fa73e8
# added c implementation for CrossentropySoftmax1Hot
# author: bergstrj@iro.umontreal.ca
# date: Fri, 11 Apr 2008 21:41:09 -0400

import theano
from theano import tensor, gof, scalar
import numpy

class ScalarSigmoid(scalar.UnaryScalarOp):
    def impl(self, x):
        return 1.0 / (1 + numpy.exp(-x))
    def grad(self, (x,), (gz,)):
        return gz * scalar_sigmoid(x) * (1.0 - scalar_sigmoid(x)),
    def c_foreach(self, (x,), (z,)): 
        return "%(z)s = 1.0 / (1 + exp(-%(x)s));" % locals()
scalar_sigmoid = gof.op.constructor(ScalarSigmoid)
Sigmoid, sigmoid, SigmoidInplace, sigmoid_inplace \
        = theano.tensor.broadcast(ScalarSigmoid, 'Sigmoid')
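
# Illustrative sketch (an addition for clarity, not part of the original Op
# code): ScalarSigmoid.grad relies on the identity
#   d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)).
# This hypothetical helper does a quick finite-difference check of that identity.
def _check_sigmoid_grad(x=0.5, eps=1e-6):
    sig = lambda v: 1.0 / (1 + numpy.exp(-v))
    numeric = (sig(x + eps) - sig(x - eps)) / (2 * eps)  # central difference
    analytic = sig(x) * (1 - sig(x))                     # closed-form gradient
    assert abs(numeric - analytic) < 1e-8  # both ~0.2350 at x = 0.5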



class CrossentropySoftmax1Hot(gof.op.Op):
    """A special compound Op for the output of neural-net classifiers.

    This Op has two outputs:
    - KL(softmax(x), y)
    - softmax(x)

    x[i] is assumed to be a dense vector
    softmax(x[i]) is the i'th distribution over len(x[i]) options
    y[i] is an integer index, encoding a 1-hot distribution

    """
    nin=2
    nout=2
    def __init__(self, x, b, y_idx, **kwargs):
        x = tensor._as_tensor(x)
        b = tensor._as_tensor(b)
        y_idx = tensor._as_tensor(y_idx)
        if len(x.broadcastable) != 2 \
                or x.dtype not in ['float32', 'float64']:
            raise ValueError('x must be 2-d tensor of floats')
        if len(b.broadcastable) != 1 \
                or b.dtype not in ['float32', 'float64']:
            raise ValueError('b must be 1-d tensor of floats')
        if len(y_idx.broadcastable) != 1 \
                or y_idx.dtype not in ['int32', 'int64']:
            raise ValueError('y_idx must be 1-d tensor of ints')

        # TODO: Is this correct? It used to be y, not y_idx
        nll = tensor.Tensor(x.dtype, y_idx.broadcastable)
        # nll = Tensor(x.dtype, y.broadcastable)
        sm = tensor.Tensor(x.dtype, x.broadcastable)
        self.inputs = [x, b, y_idx]
        self.outputs = [nll, sm]
    def perform(self):
        x, b, y_idx = [i.data for i in self.inputs]
        if b.shape[0] != x.shape[1]:
            raise ValueError('b must have the same length as each row of x')

        sm = numpy.zeros_like(x) # softmax
        nll = numpy.zeros(x.shape[0]) #nll(y | softmax(x))
        for i in xrange(sm.shape[0]):
            row = x[i] + b
            sm[i] = numpy.exp(row - numpy.max(row)) #softmax
            sm[i] *= 1.0 / numpy.sum(sm[i]) #vector scale
            nll[i] = -numpy.log( sm[i, y_idx[i]]) #cross-entropy
        self.outputs[0].data = nll
        self.outputs[1].data = sm
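        # Worked example (illustrative): x = [[1., 2.]], b = [0., 0.],
        # y_idx = [1] gives sm = [[0.2689, 0.7311]] and nll = [0.3133],
        # since -log(exp(2) / (exp(1) + exp(2))) ~= 0.3133.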
    def grad(self, (x, b, y_idx), (g_nll, g_sm)):
        if g_sm is not None:
            raise NotImplementedError()
        nll, sm = crossentropy_softmax_1hot(x, b, y_idx)
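        # dx[i] = g_nll[i] * (sm[i] - 1hot(y_idx[i])) (see the Dx Op below);
        # db is dx summed over rows, since b is added to every row of x.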
        dx = CrossentropySoftmax1HotDx(g_nll, sm, y_idx).outputs[0]
        db = tensor.Sum(dx, axis = [0]).outputs[0]
        return dx, db, None

    def c_validate_update(self, (x, b, y_idx), (nll, sm), sub):
        """Allocate output storage"""
        return """
        if (%(x)s->nd != 2) { %(fail)s }
        if (%(b)s->nd != 1) { %(fail)s }
        if (%(y_idx)s->nd != 1) { %(fail)s }
        if (%(x)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
        if (%(b)s->descr->type_num != PyArray_DOUBLE) { %(fail)s}
        if (%(y_idx)s->descr->type_num != PyArray_INT64) { %(fail)s}

        %(nll)s = (PyArrayObject*)PyArray_SimpleNew(1, PyArray_DIMS(%(y_idx)s), type_num_%(x)s);
        if(!%(nll)s){%(fail)s}

        %(sm)s = (PyArrayObject*)PyArray_SimpleNew(2, PyArray_DIMS(%(x)s), type_num_%(x)s);
        if(!%(sm)s){Py_XDECREF(%(nll)s); %(fail)s}

        """ % dict(locals(), **sub)
    def c_validate_cleanup(self, (x, b, y_idx), (nll, sm), sub):
        """Not sure..."""
        return ""
    def c_support_code(self):
        return """
        """
    def c_code(self,  (x, b, y_idx), (nll, sm), sub):
        # this implementation was lifted from
        # /u/bergstrj/cvs/bergstrj/src/feb07/nn.cxx

        #TODO: put this into a templated function, in the support code
        #TODO: declare the max of each row as an Op output

        return """
        npy_intp* Nx = %(x)s->dimensions;
        assert(%(x)s->dimensions[1] == %(b)s->dimensions[0]);
        assert(%(sm)s->dimensions[0] == %(x)s->dimensions[0]);
        assert(%(sm)s->dimensions[1] == %(x)s->dimensions[1]);

        for (size_t i = 0; i < Nx[0]; ++i)
        {
            size_t j;
            double sum = 0.0;
            bool  discount_max = false;
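            // First pass exponentiates the raw row; if the sum overflows or
            // underflows, a second pass subtracts the row max (log-sum-exp
            // stabilization) and sets discount_max so nll is fixed up below.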

            const double* __restrict__ x_i = (double*)(%(x)s->data + %(x)s->strides[0] * i);
            const double* __restrict__ b_i = (double*)(%(b)s->data);
            const long int y_i = ((long int*)(%(y_idx)s->data + %(y_idx)s->strides[0] * i))[0];
            double* __restrict__ sm_i = (double*)(%(sm)s->data + %(sm)s->strides[0] * i);
            double* __restrict__ nll_i = (double*)(%(nll)s->data + %(nll)s->strides[0] * i);

            npy_intp Sx = %(x)s->strides[1]/sizeof(double);
            npy_intp Sb = %(b)s->strides[0]/sizeof(double);
            npy_intp Ssm = %(sm)s->strides[1]/sizeof(double);

            size_t row_max_j=0;
            double row_max = x_i[0] + b_i[0];
            //try to compute sum and sm the easy way
            for (j = 0; j < Nx[1]; ++j)
            {
                double row_ij = x_i[j * Sx] +  b_i[j * Sb];
                row_max_j = (row_ij > row_max) ? j : row_max_j;
                row_max   = (row_ij > row_max) ? row_ij : row_max;

                double sm_ij = exp(row_ij);
                sum += sm_ij;
                sm_i[j * Ssm] = sm_ij;
            }
            if ((0.0 == sum) || (isinf(sum))) 
            {
                //our cheap trick didn't work... try again and do it better.
                discount_max = true;
                sum = 0.0; //reset sum and recompute....
                for (j = 0; j < Nx[1]; ++j)
                {
                    double row_ij = x_i[j * Sx] +  b_i[j * Sb];

                    double sm_ij = exp(row_ij - row_max);
                    sum += sm_ij;
                    sm_i[j * Ssm] = sm_ij;
                }
                assert( (0.0 != sum) && (!isinf(sum))); //that was our best... 
                //if we still can't sum it up, we're screwed.
                //So far, this assertion has never failed...
            }

            //cblas_dscal(x.N, 1.0 / sum, &mat_at(s,i,0), s.n);
            double sum_inv = 1.0 / sum;
            for (j = 0; j < Nx[1]; ++j)
            {
                sm_i[j * Ssm] *= sum_inv;
            }

            assert((y_i >= 0) && (y_i < Nx[1]));

            nll_i[0] = - x_i[y_i*Sx] 
                       - b_i[y_i*Sb]
                       + (discount_max ? row_max : 0.0)
                       + log(sum);
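              // Computing nll from the logits and log(sum) directly avoids
              // taking the log of the already-rounded, normalized sm value
              // (compare the commented alternative below).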
              //mat_at(y,i,0) = -log( mat_at(s,i,t[i]));  //less accurate?
              //mat_at(y,i,0) =  - mat_at(x,i,t[i]) - mat_at(b,0,t[i]) + (discount_max ? maxi : 0.0) + log(sum);
        }
        """ % dict(locals(), **sub)



crossentropy_softmax_1hot = gof.op.constructor(CrossentropySoftmax1Hot)
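
# Illustrative NumPy-only sketch (added for reference; the Op does not use
# it): the same (nll, sm) computation that perform() and the C code above
# implement, written in vectorized form.
def _reference_crossentropy_softmax_1hot(x, b, y_idx):
    row = x + b                                  # broadcast the bias over rows
    row = row - row.max(axis=1).reshape(-1, 1)   # subtract row max for stability
    sm = numpy.exp(row)
    sm /= sm.sum(axis=1).reshape(-1, 1)          # normalize each row to sum to 1
    nll = -numpy.log(sm[numpy.arange(x.shape[0]), y_idx])
    return nll, sm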

class CrossentropySoftmax1HotDx(gof.op.Op):
    """Gradient wrt x of the CrossentropySoftmax1Hot Op"""
    nin=3
    nout=1
    def __init__(self, dy, sm, y_idx,**kwargs):
        dy = tensor._as_tensor(dy)
        sm = tensor._as_tensor(sm)
        y_idx = tensor._as_tensor(y_idx)
        self.inputs = [dy, sm, y_idx]
        self.outputs = [tensor.Tensor(sm.dtype, sm.broadcastable)]
    def perform(self):
        dy,sm,y_idx = [i.data for i in self.inputs]
        dx = numpy.zeros_like(sm)
        for i in xrange(sm.shape[0]):
            dx[i] = dy[i] * sm[i] #vector scale
            dx[i, y_idx[i]] -= dy[i] #scalar decrement
        self.outputs[0].data = dx
    def grad(self, *args):
        raise NotImplementedError()
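
# For reference: with sm = softmax(x + b) and nll[i] = -log(sm[i, y_idx[i]]),
# d nll[i] / d x[i, j] = sm[i, j] - (1 if j == y_idx[i] else 0); perform()
# above computes exactly this, scaled by the incoming gradient dy[i].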

#TODO: write a version of CrossentropySoftmax1Hot that accepts a bias for x, if
# this op needs to be faster.