
pylearn/algorithms/logistic_regression.py: Added back in unsupervised softmax computation.
author Joseph Turian <turian@iro.umontreal.ca>
date Fri, 30 Jan 2009 16:38:46 -0500

import sys, copy
import theano
from theano import tensor as T
from theano.tensor import nnet
from theano.compile import module
from theano import printing, pprint
from theano import compile

import numpy as N

class LogRegN(module.FancyModule):
    """
    A symbolic module for performing N-class logistic regression.

    Notable variables
    -----------------

    self.input
    self.target 
    self.softmax
    self.argmax
    self.regularized_cost
    self.unregularized_cost
    """

    def __init__(self, 
            n_in=None, n_out=None,
            input=None, target=None, 
            w=None, b=None, 
            l2=None, l1=None):
        super(LogRegN, self).__init__() #boilerplate

        self.n_in = n_in
        self.n_out = n_out

        self.input = input if input is not None else T.matrix()
        self.target = target if target is not None else T.lvector()

        self.w = w if w is not None else module.Member(T.dmatrix())
        self.b = b if b is not None else module.Member(T.dvector())

        #the params of the model are the ones we fit to the data
        self.params = [p for p in [self.w, self.b] if p.owner is None]
        
        #the hyper-parameters of the model are not fit to the data
        self.l2 = l2 if l2 is not None else module.Member(T.dscalar())
        self.l1 = l1 if l1 is not None else module.Member(T.dscalar())

        #here we actually build the model
        self.linear_output = T.dot(self.input, self.w) + self.b
        if 0:
            # TODO: pending support for target being a sparse matrix
            self.softmax = nnet.softmax(self.linear_output)

            self._max_pr, self.argmax = T.max_and_argmax(self.linear_output)
            # cross-entropy as a negative log-likelihood, matching the sign
            # convention of the branch below
            self._xent = -self.target * T.log(self.softmax)
        else:
            # TODO: when above is fixed, remove this hack (need an argmax
            # which is independent of targets)
            self.argmax_standalone = T.argmax(self.linear_output)
            (self._xent, self.softmax, self._max_pr, self.argmax) =\
                    nnet.crossentropy_softmax_max_and_argmax_1hot(
                    self.linear_output, self.target)

        self.unregularized_cost = T.sum(self._xent)
        self.l1_cost = self.l1 * T.sum(abs(self.w))
        self.l2_cost = self.l2 * T.sum(self.w**2)
        self.regularized_cost = self.unregularized_cost + self.l1_cost + self.l2_cost
        self._loss_zero_one = T.mean(T.neq(self.argmax, self.target))

        # Softmax computed directly (unsupervised: no target needed).
        # TODO: move this somewhere cleaner.
        self.softmax_unsupervised = nnet.softmax(self.linear_output)

        # METHODS
        if 0:  # TODO: pending the better implementation above
            self.predict = module.Method([self.input], self.argmax)
            self.label_probs = module.Method([self.input], self.softmax)
        self.validate = module.Method([self.input, self.target], 
                [self._loss_zero_one, self.regularized_cost, self.unregularized_cost])

    def _instance_initialize(self, obj):
        obj.w = N.zeros((self.n_in, self.n_out))
        obj.b = N.zeros(self.n_out)
        obj.__pp_hide__ = ['params']
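
# The helper below is an illustrative numpy sketch, not part of the original
# module: it spells out the quantities LogRegN wires up symbolically above
# (linear output, row-wise softmax, one-hot cross-entropy, argmax and zero-one
# loss) and returns them in the same order as the `validate` Method.  It uses
# the module-level `numpy as N` import and is a reference for the math only;
# the compiled symbolic graph is what actually runs.
def _logregn_forward_sketch(x, y, w, b, l1=0.0, l2=0.0):
    """Numpy reference for LogRegN's symbolic outputs.

    x : (n_examples, n_in) inputs
    y : (n_examples,) integer class labels
    w : (n_in, n_out) weights
    b : (n_out,) biases
    """
    linear_output = N.dot(x, w) + b
    # row-wise softmax, shifted by the row max for numerical stability
    shifted = linear_output - linear_output.max(axis=1)[:, None]
    e = N.exp(shifted)
    softmax = e / e.sum(axis=1)[:, None]
    # negative log-probability of the correct class (one-hot cross-entropy)
    xent = -N.log(softmax[N.arange(len(y)), y])
    unregularized_cost = xent.sum()
    regularized_cost = unregularized_cost + l1 * N.sum(abs(w)) + l2 * N.sum(w ** 2)
    argmax = linear_output.argmax(axis=1)
    loss_zero_one = N.mean(argmax != y)
    return loss_zero_one, regularized_cost, unregularized_cost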

def logistic_regression(n_in, n_out, l1, l2, minimizer=None):
    if n_out == 2:
        raise NotImplementedError()
    else:
        rval = LogRegN(n_in=n_in, n_out=n_out, l1=l1, l2=l2)
        print 'RVAL input target', rval.input, rval.target
        rval.minimizer = minimizer([rval.input, rval.target], rval.regularized_cost,
                rval.params)
        return rval.make(mode='FAST_RUN')
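
# The `minimizer` argument to logistic_regression is expected to be a factory
# following the interface used in this file: called with (inputs, cost, params)
# it returns a submodule whose compiled instance exposes `step_cost`, as in
# `fit_logistic_regression_online` below.  The class here is only a hedged
# sketch of such a factory (plain gradient descent with an arbitrary fixed
# stepsize); it is not pylearn's actual make_minimizer.
class _sgd_minimizer_sketch(module.FancyModule):
    def __init__(self, args, cost, params, stepsize=0.01):
        super(_sgd_minimizer_sketch, self).__init__()
        gparams = T.grad(cost, params)
        # one gradient-descent step on `cost` w.r.t. `params`, returning the cost
        self.step_cost = module.Method(args, cost,
                updates=dict((p, p - stepsize * g) for p, g in zip(params, gparams)))
# Example (illustrative): logistic_regression(n_in, n_out, l1=0.0, l2=0.0,
#                                             minimizer=_sgd_minimizer_sketch)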

#TODO: grouping parameters by prefix does not play well with providing defaults. Think...
#FIX : Guillaume suggested a convention: plugin handlers (dataset_factory, minimizer_factory,
#      etc.) should never provide default arguments for parameters, and accept **kwargs to catch
#      irrelevant parameters.
class _fit_logreg_defaults(object):
    minimizer_algo = 'dummy'
    #minimizer_lr = 0.001
    dataset = 'MNIST_1k'
    l1 = 0.0
    l2 = 0.0
    batchsize = 8
    verbose = 1

# consider pre-importing each file in algorithms, datasets (possibly with try/catch around each
# import so that this import failure is ignored)

def fit_logistic_regression_online(state, channel=lambda *args, **kwargs:None):
    #use stochastic gradient descent
    state.use_defaults(_fit_logreg_defaults)

    # NOTE: `make`, `make_minimizer`, and `make_stopper` are not defined or
    # imported in this file; they are expected to come from pylearn's
    # dataset / minimizer / stopper factory machinery.
    dataset = make(state.dataset)
    train = dataset.train
    valid = dataset.valid
    test = dataset.test

    logreg = logistic_regression(
            n_in=train.x.shape[1],
            n_out=dataset.n_classes,
            l2=state.l2,
            l1=state.l1,
            minimizer=make_minimizer(**state.subdict(prefix='minimizer_')))

    batchsize = state.batchsize
    verbose = state.verbose
    iter = [0]  # one-element list so the nested check() closure can update the count

    def step():
        # step by making a pass through the training set
        for j in xrange(0,len(train.x)-batchsize+1,batchsize):
            cost_j = logreg.minimizer.step_cost(train.x[j:j+batchsize], train.y[j:j+batchsize])
            if verbose > 1:
                print 'estimated train cost', cost_j
        #TODO: consult iter[0] for periodic saving to cwd (model, minimizer, and stopper)

    def check():
        validate = logreg.validate(valid.x, valid.y)
        if verbose > 0: 
            print 'iter', iter[0], 'validate', validate
            sys.stdout.flush()
        iter[0] += 1
        return validate[0]

    def save():
        return copy.deepcopy(logreg)

    stopper = make_stopper(**state.subdict(prefix='stopper_'))
    stopper.find_min(step, check, save)

    state.train_01, state.train_rcost, state.train_cost = logreg.validate(train.x, train.y)
    state.valid_01, state.valid_rcost, state.valid_cost = logreg.validate(valid.x, valid.y)
    state.test_01, state.test_rcost, state.test_cost = logreg.validate(test.x, test.y)

    state.n_train = len(train.y)
    state.n_valid = len(valid.y)
    state.n_test = len(test.y)
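
# `fit_logistic_regression_online` assumes a jobman/pylearn-style `state`
# object.  The class below is only a sketch written from the attributes and
# methods the function above actually touches (`use_defaults`, `subdict`,
# `dataset`, `l1`, `l2`, `batchsize`, `verbose`, plus the result fields it
# assigns); the real state object comes from pylearn's experiment machinery,
# not from this file.
class _state_sketch(object):
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def use_defaults(self, defaults):
        # copy class-level defaults for any attribute not already set
        for key, value in defaults.__dict__.items():
            if not key.startswith('_') and not hasattr(self, key):
                setattr(self, key, value)

    def subdict(self, prefix=''):
        # collect attributes sharing a prefix, with the prefix stripped,
        # e.g. minimizer_algo -> {'algo': ...}
        return dict((key[len(prefix):], value)
                for key, value in self.__dict__.items()
                if key.startswith(prefix))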

class LogReg2(module.FancyModule):
    """
    A symbolic module for binary (2-class) logistic regression.

    Note: the `regularize` flag is accepted but currently unused.
    """
    def __init__(self, input=None, targ=None, w=None, b=None, lr=None, regularize=False):
        super(LogReg2, self).__init__() #boilerplate

        self.input = module.Member(input) if input is not None else T.matrix('input')
        self.targ = module.Member(targ) if targ is not None else T.lcol()

        self.w = module.Member(w) if w is not None else module.Member(T.dmatrix())
        self.b = module.Member(b) if b is not None else module.Member(T.dvector())
        self.lr = module.Member(lr) if lr is not None else module.Member(T.dscalar())

        self.params = [p for p in [self.w, self.b] if p.owner is None]

        output = nnet.sigmoid(T.dot(self.input, self.w) + self.b)
        xent = -self.targ * T.log(output) - (1.0 - self.targ) * T.log(1.0 - output)
        sum_xent = T.sum(xent)

        self.output = output
        self.xent = xent
        self.sum_xent = sum_xent
        self.cost = sum_xent

        #define the apply method
        self.pred = (T.dot(self.input, self.w) + self.b) > 0.0
        self.apply = module.Method([self.input], self.pred)

        #if this module has any internal parameters, define an update function for them
        if self.params:
            gparams = T.grad(sum_xent, self.params)
            self.update = module.Method([self.input, self.targ], sum_xent,
                                        updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))
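
# An illustrative numpy sketch, not part of the original module, of the math
# behind LogReg2's `update` Method: the sigmoid output, the summed binary
# cross-entropy, and the gradients that T.grad derives for w and b, applied as
# one gradient-descent step.  It uses the module-level `numpy as N` import and
# is a reference for the formulas only.
def _logreg2_step_sketch(x, targ, w, b, lr):
    """One gradient step of binary logistic regression, in plain numpy.

    x    : (n_examples, n_in) inputs
    targ : (n_examples, 1) 0/1 targets
    w    : (n_in, 1) weights
    b    : (1,) bias
    lr   : learning rate
    """
    output = 1.0 / (1.0 + N.exp(-(N.dot(x, w) + b)))
    sum_xent = N.sum(-targ * N.log(output) - (1.0 - targ) * N.log(1.0 - output))
    # d(sum_xent)/d(linear output) simplifies to (output - targ)
    delta = output - targ
    grad_w = N.dot(x.T, delta)
    grad_b = delta.sum(axis=0)
    return w - lr * grad_w, b - lr * grad_b, sum_xent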