view linear_regression.py @ 259:621faba17c60

Created 'dummytests': tests that check the consistency of new, unusual datasets for which we cannot, for instance, compare against actual values in a matrix. Useful as a first debugging step when creating a dataset.
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:41:55 -0400
parents f6505ec32dc3
children c9a89be5cb0a
line wrap: on
line source

"""
Implementation of linear regression, with or without L2 regularization.
This is one of the simplest examples of a L{learner}, and illustrates
the use of theano.
"""

from learner import *
from theano import tensor as t
from theano.scalar import as_scalar

class LinearRegression(MinibatchUpdatesTLearner):
    """
    Implement linear regression, with or without L2 regularization
    (the former is called Ridge Regression and the latter Ordinary Least Squares).

    The predictor parameters are obtained analytically from the training set.
    Training can proceed sequentially (with multiple calls to update with
    different disjoint subsets of the training sets). After each call to
    update the predictor is ready to be used (and optimized for the union
    of all the training sets passed to update since construction or since
    the last call to forget).

    For each (input[t],output[t]) pair in a minibatch,::

       output_t = b + W * input_t

    where b and W are obtained by minimizing::

       L2_regularizer sum_{ij} W_{ij}^2  + sum_t ||output_t - target_t||^2

    Let X be the whole training set inputs matrix (one input example per row),
    with the first column full of 1's, and let Y be the whole training set
    targets matrix (one example's target vector per row).
    Let theta = the matrix with b in its first column and W in the others
    (one row per output, i.e. n_outputs x (n_inputs+1)),
    then each row theta[i,:] is the solution of the linear system::

       XtX * theta[i,:]' = XtY[:,i]

    where XtX is a (n_inputs+1)x(n_inputs+1) matrix containing X'*X
    plus L2_regularizer on the diagonal except at (0,0),
    and XtY is a (n_inputs+1)x(n_outputs) matrix containing X'*Y.

    The fields and attributes expected and produced by use and update are the following:

     - Input and output fields (example-wise quantities):

       - 'input' (always expected by use and update as an input_dataset field)
       - 'target' (optionally expected by use and update as an input_dataset field)
       - 'output' (optionally produced by use as an output dataset field)
       - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error

     - optional attributes (optionally expected as input_dataset attributes)
       (warning, this may be dangerous, the 'use' method will use those provided in the
       input_dataset rather than those learned during 'update'; currently no support
       for providing these to update):

       - 'L2_regularizer'
       - 'b'
       - 'W'
       - 'parameters' = [b, W]
       - 'regularization_term'
       - 'XtX'
       - 'XtY'

    """

    def attributeNames(self):
        """Return the names of all the attributes declared by this learner."""
        return ["L2_regularizer", "parameters", "b", "W",
                "regularization_term", "XtX", "XtY"]

    def useInputAttributes(self):
        """Return the names of the attributes declared as inputs of 'use'."""
        return ["b", "W"]

    def useOutputAttributes(self):
        """Return the names of the attributes declared as outputs of 'use' (none)."""
        return []

    def updateInputAttributes(self):
        """Return the names of the attributes declared as inputs of 'update'."""
        return ["L2_regularizer", "XtX", "XtY"]

    def updateMinibatchInputFields(self):
        """Return the names of the minibatch fields declared as inputs of a minibatch update."""
        return ["input", "target"]

    def updateMinibatchInputAttributes(self):
        """Return the names of the attributes declared as inputs of a minibatch update."""
        return ["XtX", "XtY"]

    def updateMinibatchOutputAttributes(self):
        """Return the names of the attributes declared as outputs of a minibatch update."""
        return ["new_XtX", "new_XtY"]

    def updateEndInputAttributes(self):
        """Return the names of the attributes declared as inputs of the end-of-update step."""
        return ["theta", "XtX", "XtY"]

    def updateEndOutputAttributes(self):
        """Return the names of the attributes declared as outputs of the end-of-update step."""
        # CHECK: WILL b AND W CONTAIN OLD OR NEW THETA? @todo i.e. order of computation = ?
        return ["new_theta", "b", "W", "regularization_term"]

    def parameterAttributes(self):
        """Return the names of the attributes holding the learned parameters."""
        return ["b", "W"]

    def defaultOutputFields(self, input_fields):
        """
        Return the default output field names for the given input fields.

        'output' is always included; 'squared_error' is added only when
        'target' is among input_fields, since it cannot be computed otherwise.
        """
        output_fields = ["output"]
        if "target" in input_fields:
            output_fields.append("squared_error")
        return output_fields

    def __init__(self):
        """
        Build the symbolic (theano) expressions of the model and then
        initialize the MinibatchUpdatesTLearner base class.

        Each expression is stored under the name of the corresponding
        field/attribute prefixed with an underscore (e.g. 'XtX' -> _XtX).
        """
        # The data dimensions are unknown until allocate() sees a minibatch.
        # They are set before the base-class constructor runs so that
        # allocate() and forget() can safely test them.
        self._n_inputs = None
        self._n_outputs = None

        self._input = t.matrix('input') # n_examples x n_inputs
        self._target = t.matrix('target') # n_examples x n_outputs
        self._L2_regularizer = as_scalar(0.,'L2_regularizer')
        self._theta = t.matrix('theta')
        # theta holds b in its first column and W in the remaining ones.
        self._W = self._theta[:,1:]
        self._b = self._theta[:,0]
        self._XtX = t.matrix('XtX')
        self._XtY = t.matrix('XtY')
        self._extended_input = t.prepend_one_to_each_row(self._input)
        self._output = t.dot(self._input,self._W.T) + self._b  # (n_examples , n_outputs) matrix
        self._squared_error = t.sum_within_rows(t.sqr(self._output-self._target)) # (n_examples ) vector
        # L2_regularizer * sum_{ij} W_{ij}^2, as stated in the class docstring
        # (dot(W,W) is not the sum of squares, and is not even defined when
        # n_outputs != n_inputs).
        self._regularization_term = self._L2_regularizer * t.sum(t.sqr(self._W))
        # Former name of the expression above, kept for backward compatibility.
        self._regularizer = self._regularization_term
        # NOTE(review): add_inplace is not qualified with `t.`; presumably it is
        # provided by `from learner import *` -- confirm.
        self._new_XtX = add_inplace(self._XtX,t.dot(self._extended_input.T,self._extended_input))
        self._new_XtY = add_inplace(self._XtY,t.dot(self._extended_input.T,self._target))
        # NOTE(review): the normal equations yield a (1+n_inputs) x n_outputs
        # solution whereas theta is allocated as n_outputs x (1+n_inputs) in
        # allocate() -- confirm that solve_inplace accounts for the transposition.
        self._new_theta = t.solve_inplace(self._theta,self._XtX,self._XtY)

        MinibatchUpdatesTLearner.__init__(self)

    def allocate(self,minibatch):
        """
        Make sure XtX, XtY and theta have shapes matching the given minibatch.

        The number of inputs and outputs is read from the second dimension of
        minibatch["input"] and minibatch["target"]. On the first call the
        arrays are created; if the dimensions later change, the arrays are
        re-shaped and everything learned so far is forgotten.
        """
        minibatch_n_inputs = minibatch["input"].shape[1]
        minibatch_n_outputs = minibatch["target"].shape[1]
        if not self._n_inputs:
            self._n_inputs = minibatch_n_inputs
            self._n_outputs = minibatch_n_outputs
            self.XtX = numpy.zeros((1 + self._n_inputs, 1 + self._n_inputs))
            self.XtY = numpy.zeros((1 + self._n_inputs, self._n_outputs))
            self.theta = numpy.zeros((self._n_outputs, 1 + self._n_inputs))
            self.forget()
        elif self._n_inputs != minibatch_n_inputs or self._n_outputs != minibatch_n_outputs:
            # if the input or target changes dimension on the fly, we resize and forget everything.
            # The stored dimensions must be updated first, otherwise forget()
            # would resize the arrays to their old shapes.
            self._n_inputs = minibatch_n_inputs
            self._n_outputs = minibatch_n_outputs
            self.theta = numpy.zeros((self._n_outputs, 1 + self._n_inputs))
            self.forget()

    def forget(self):
        """
        Reset the sufficient statistics XtX and XtY to their initial state.

        Both matrices are resized to the current dimensions and zeroed, and
        L2_regularizer is then written on the diagonal of XtX except at (0,0),
        so that the bias is not penalized. Nothing is done while the
        dimensions are still unknown (i.e. before the first allocate()).
        """
        if self._n_inputs and self._n_outputs:
            n_rows = 1 + self._n_inputs
            self.XtX.resize((n_rows, n_rows))
            self.XtY.resize((n_rows, self._n_outputs))
            self.XtX[:, :] = 0
            self.XtY[:, :] = 0
            # numpy.diag() returns a copy (or a read-only view), so assigning to
            # its result would not modify XtX; index the diagonal directly.
            # NOTE(review): self.L2_regularizer is not set anywhere in this
            # class -- presumably provided by the base class or the caller.
            penalized = numpy.arange(1, n_rows)
            self.XtX[penalized, penalized] = self.L2_regularizer