view algorithms/logistic_regression.py @ 515:dc2d93590da0

Small bugfix in hidden weight initialization.
author Joseph Turian <turian@iro.umontreal.ca>
date Fri, 31 Oct 2008 16:37:19 -0400
parents c7ce66b4e8f4
children b267a8000f92
line wrap: on
line source

import theano
from theano import tensor as T
from theano.tensor import nnet
from theano.compile import module
from theano import printing, pprint
from theano import compile

import numpy as N

class LogRegInstanceType(module.FancyModuleInstance):
    """Instance type used by the logistic-regression modules in this file."""

    def initialize(self, n_in, n_out=1, rng=N.random, seed=None):
        """
        Give the instance zero-valued parameters and a default learning rate.

        Note (from the original author): self.component is the
        LogisticRegressionTemplate instance that built this guy.

        @param n_in: input dimension; number of rows of the weight matrix.
        @param n_out: output dimension; number of columns of the weight
            matrix and length of the bias vector.
        @param rng: not used by this method.
        @param seed: not used by this method.
        @todo: Remove seed. Used only to keep Stacker happy.
        """
        weight_shape = (n_in, n_out)
        bias_shape = n_out

        # Parameters start at zero; the learning rate gets a fixed default.
        self.w = N.zeros(weight_shape)
        self.b = N.zeros(bias_shape)
        self.lr = 0.01

        self.__hide__ = ['params']

        # Record the dimensions this instance was initialized with.
        self.input_dimension = n_in
        self.output_dimension = n_out

class Module_Nclass(module.FancyModule):
    """
    Multi-class logistic regression: a linear layer followed by a softmax.

    Exposes the symbolic outputs (softmax, argmax, max_pr, sum_xent, cost,
    output, pred) and the methods apply, validate, softmax_output and,
    when there are free parameters, update.
    """
    InstanceType = LogRegInstanceType

    def __init__(self, x=None, targ=None, w=None, b=None, lr=None, regularize=False):
        """
        Build the symbolic graph and declare the module's methods.

        @param x: symbolic input; defaults to T.matrix('input').
        @param targ: symbolic targets; defaults to T.lvector().
        @param w: weights; defaults to a new Member wrapping T.dmatrix().
        @param b: biases; defaults to a new Member wrapping T.dvector().
        @param lr: learning rate; defaults to a new Member wrapping T.dscalar().
        @param regularize: accepted but not used by this constructor.
        """
        super(Module_Nclass, self).__init__() #boilerplate

        # Inputs: fall back to fresh symbolic variables when none are given.
        if x is None:
            x = T.matrix('input')
        self.x = x
        if targ is None:
            targ = T.lvector()
        self.targ = targ

        # Parameters and learning rate: fall back to fresh module members.
        if w is None:
            w = module.Member(T.dmatrix())
        self.w = w
        if b is None:
            b = module.Member(T.dvector())
        self.b = b
        if lr is None:
            lr = module.Member(T.dscalar())
        self.lr = lr

        # Only parameters that are not the output of some op are trained.
        trainable = []
        for candidate in [self.w, self.b]:
            if candidate.owner is None:
                trainable.append(candidate)
        self.params = trainable

        linear_output = T.dot(self.x, self.w) + self.b

        # Supervised outputs: per-example cross-entropy against targ, the
        # softmax, and the max probability with its argmax.
        xent, softmax, max_pr, argmax = \
                nnet.crossentropy_softmax_max_and_argmax_1hot(linear_output, self.targ)
        sum_xent = T.sum(xent)

        self.softmax = softmax
        self.argmax = argmax
        self.max_pr = max_pr
        self.sum_xent = sum_xent

        # Softmax computed directly, without needing the targets.
        softmax_unsupervised = nnet.softmax(linear_output)
        self.softmax_unsupervised = softmax_unsupervised

        # Aliases kept for compatibility with the current implementation of
        # stacker/daa or something.
        # TODO: remove this, make a wrapper
        self.cost = self.sum_xent
        self.input = self.x
        # TODO: I want to make output = linear_output.
        self.output = self.softmax_unsupervised

        # apply: input -> index of the largest linear output along axis 1.
        self.pred = T.argmax(linear_output, axis=1)
        self.apply = module.Method([self.input], self.pred)

        # validate: (input, targ) -> [cost, argmax, max_pr].
        self.validate = module.Method(
                [self.input, self.targ],
                [self.cost, self.argmax, self.max_pr])
        # softmax_output: input -> softmax probabilities.
        self.softmax_output = module.Method([self.input], self.softmax_unsupervised)

        # With free parameters, declare a gradient step p <- p - lr * dcost/dp.
        if self.params:
            gparams = T.grad(sum_xent, self.params)

            step = {}
            for param, gparam in zip(self.params, gparams):
                step[param] = param - self.lr * gparam

            self.update = module.Method([self.input, self.targ], sum_xent,
                    updates = step)

class Module(module.FancyModule):
    """
    Binary logistic regression: a linear layer followed by a sigmoid.

    Exposes the symbolic outputs (output, xent, sum_xent, cost, pred) and
    the methods apply and, when there are free parameters, update.
    """
    InstanceType = LogRegInstanceType

    def __init__(self, input=None, targ=None, w=None, b=None, lr=None, regularize=False):
        """
        Build the symbolic graph and declare the module's methods.

        @param input: symbolic input; defaults to T.matrix('input').
        @param targ: symbolic targets; defaults to T.lcol().
        @param w: weights; defaults to a new Member wrapping T.dmatrix().
        @param b: biases; defaults to a new Member wrapping T.dvector().
        @param lr: learning rate; defaults to a new Member wrapping T.dscalar().
        @param regularize: accepted but not used by this constructor.
        """
        super(Module, self).__init__() #boilerplate

        self.input = input if input is not None else T.matrix('input')
        self.targ = targ if targ is not None else T.lcol()

        self.w = w if w is not None else module.Member(T.dmatrix())
        self.b = b if b is not None else module.Member(T.dvector())
        self.lr = lr if lr is not None else module.Member(T.dscalar())

        # Only parameters that are not the output of some op are trained.
        self.params = [p for p in [self.w, self.b] if p.owner is None]

        # This class stores its input as self.input; it never defines
        # self.x, so the pre-activation must be built from self.input.
        linear_output = T.dot(self.input, self.w) + self.b

        output = nnet.sigmoid(linear_output)
        # Element-wise binary cross-entropy between targ and the sigmoid output.
        xent = -self.targ * T.log(output) - (1.0 - self.targ) * T.log(1.0 - output)
        sum_xent = T.sum(xent)

        self.output = output
        self.xent = xent
        self.sum_xent = sum_xent
        self.cost = sum_xent

        #define the apply method
        # The prediction is true wherever the linear output is positive.
        self.pred = linear_output > 0.0
        self.apply = module.Method([self.input], self.pred)

        #if this module has any internal parameters, define an update function for them
        if self.params:
            gparams = T.grad(sum_xent, self.params)
            # Gradient step: p <- p - lr * dcost/dp for every free parameter.
            self.update = module.Method([self.input, self.targ], sum_xent,
                                        updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gparams)))

class Learner(object):
    """
    Placeholder class with no behavior yet.

    TODO: Encapsulate the algorithm for finding an optimal regularization
    coefficient.
    """