# linear_regression.py @ 78:3499918faa9d
# "In the middle of designing TLearner"
# author: bengioy@bengiomac.local
# date:   Mon, 05 May 2008 09:35:30 -0400


from learner import *
from theano import tensor as t
from compile import Function
from theano.scalar import as_scalar
import numpy   # needed by initialize()
import copy    # needed by names2attributes()

# This is one of the simplest examples of a learner, and it illustrates
# the use of theano.
class LinearRegression(Learner):
    """
    Implement linear regression, with or without L2 regularization
    (the former is called Ridge Regression and the latter Ordinary Least Squares).

    The predictor parameters (b and W) are obtained analytically from the
    training set. The L2 regularization coefficient (lambda_) is a
    hyperparameter provided by the user, not estimated from the data.
    For each (input[t],target[t]) pair in a minibatch, the prediction is::
    
       output_t = b + W * input_t

    where b and W are obtained by minimizing::

       lambda sum_{ij} W_{ij}^2  + sum_t ||output_t - target_t||^2

    Minimizing this loss with respect to b and W (setting the gradient to
    zero) yields the normal equations. Let X be the whole training set
    inputs matrix (one input example per row), with the first column full
    of 1's, and let Y be the whole training set targets matrix (one
    example's target vector per row). Let theta = the matrix with b in its
    first column and W in the others (one row per output); then each row
    theta[i,:] is the solution of the linear system::

       XtX * theta[i,:]' = XtY[:,i]

    where XtX is the (n_inputs+1) x (n_inputs+1) matrix containing X'*X
    plus lambda on the diagonal except at entry (0,0) (so that the bias b
    is not regularized), and XtY is the (n_inputs+1) x n_outputs matrix
    containing X'*Y.
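
    For illustration, the same solution can be sketched directly with numpy
    (hypothetical names: `inputs` is n_examples x n_inputs and `targets` is
    n_examples x n_outputs)::

       X = numpy.hstack([numpy.ones((len(inputs), 1)), inputs])
       XtX = numpy.dot(X.T, X)
       ridge = numpy.arange(1, X.shape[1])     # skip (0,0): bias not regularized
       XtX[ridge, ridge] += lambda_
       XtY = numpy.dot(X.T, targets)
       theta = numpy.linalg.solve(XtX, XtY).T  # one row per output
       b, W = theta[:, 0], theta[:, 1:]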

    The fields and attributes expected and produced by use and update are the following:

     - Input and output fields (example-wise quantities):

       - 'input' (always expected by use and update as an input_dataset field)
       - 'target' (optionally expected by use and update as an input_dataset field)
       - 'output' (optionally produced by use as an output dataset field)
       - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error

     - optional input attributes (optionally expected as input_dataset attributes)

       - 'lambda_' (only used by update; 'lambda' is a reserved word in Python)
       - 'b' (only used by use)
       - 'W' (only used by use)

     - optional output attributes (available in self and optionally in output dataset)

       - 'b' (only set by update)
       - 'W' (only set by update)
       - 'regularization_term' (only set by update)
       - 'XtX' (only set by update)
       - 'XtY' (only set by update)
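
    A sketch of the intended usage (the Learner/DataSet API is still being
    designed in this revision)::

       learner = LinearRegression(lambda_=0.1)
       learner.update(training_set)           # accumulate XtX, XtY; solve for theta
       test_outputs = learner.use(test_set)   # adds 'output' (and 'squared_error'
                                              # if 'target' is present)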
       
    """

# definitions specific to linear regression:

    def global_inputs(self):
        # 'lambda' is a reserved word in Python, hence the trailing underscore
        self.lambda_ = as_scalar(0., 'lambda')
        self.theta = t.matrix('theta')
        self.W = self.theta[:,1:]
        self.b = self.theta[:,0]
        self.XtX = t.matrix('XtX')
        self.XtY = t.matrix('XtY')

    def global_outputs(self):
        # sum of squared weights; the bias b is not regularized
        self.regularization_term = self.lambda_ * t.sum(t.sqr(self.W))
        # this loss only makes sense if the whole training set fits in memory in a minibatch
        self.loss = self.regularization_term + t.sum(self.squared_error)
        self.loss_function = Function([self.W, self.lambda_, self.squared_error], [self.loss])

    def initialize(self):
        self.XtX.resize((1+self.n_inputs, 1+self.n_inputs))
        self.XtY.resize((1+self.n_inputs, self.n_outputs))
        self.XtX.data[:,:] = 0
        self.XtY.data[:,:] = 0
        # put lambda_ on the diagonal of XtX except at (0,0), so the bias is not
        # regularized (numpy.diag returns a copy, so index the diagonal directly)
        diag = numpy.arange(1, 1+self.n_inputs)
        self.XtX.data[diag, diag] = self.lambda_.data
        
    def updated_variables(self):
        # accumulate the sufficient statistics X'X and X'Y over minibatches
        self.new_XtX = self.XtX + t.dot(self.extended_input.T, self.extended_input)
        self.new_XtY = self.XtY + t.dot(self.extended_input.T, self.target)
        # solve the normal equations with the updated statistics; transpose so that
        # theta has one row per output, matching b = theta[:,0] and W = theta[:,1:]
        self.new_theta = t.solve(self.new_XtX, self.new_XtY).T
    
    def minibatch_wise_inputs(self):
        self.input = t.matrix('input') # n_examples x n_inputs
        self.target = t.matrix('target') # n_examples x n_outputs
        
    def minibatch_wise_outputs(self):
        # self.input is a (n_examples, n_inputs) minibatch matrix
        self.extended_input = t.prepend_one_to_each_row(self.input)
        self.output = t.dot(self.input, self.W.T) + self.b  # (n_examples, n_outputs) matrix
        self.squared_error = t.sum_within_rows(t.sqr(self.output - self.target))  # (n_examples,) vector

    def attributeNames(self):
        return ["lambda_", "b", "W", "regularization_term", "XtX", "XtY"]

    def defaultOutputFields(self, input_fields):
        output_fields = ["output"]
        if "target" in input_fields:
            output_fields.append("squared_error")
        return output_fields
        
# general plumbing based on the functions above

    def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
        if not output_fields:
            output_fields = self.defaultOutputFields(input_fields)
        if stats_collector:
            stats_collector_inputs = stats_collector.inputUpdateAttributes()
            for attribute in stats_collector_inputs:
                if attribute not in input_fields:
                    output_fields.append(attribute)
        # cache the compiled Function for each (inputs, outputs) signature;
        # tuples are used because lists are not hashable
        key = (tuple(input_fields), tuple(output_fields))
        if key not in self.use_functions_dictionary:
            # Function needs the symbolic Results, hence return_Result=True
            self.use_functions_dictionary[key] = Function(self.names2attributes(input_fields, return_Result=True),
                                                          self.names2attributes(output_fields, return_Result=True))
        return self.use_functions_dictionary[key]

    def attributes(self, return_copy=False):
        return self.names2attributes(self.attributeNames(), return_copy=return_copy)
            
    def names2attributes(self, names, return_Result=False, return_copy=False):
        # map attribute names to either the symbolic Results or their numpy .data
        attributes = [getattr(self, name) for name in names]
        if not return_Result:
            attributes = [attribute.data for attribute in attributes]
        if return_copy:
            return [copy.deepcopy(attribute) for attribute in attributes]
        return attributes

    def use(self, input_dataset, output_fieldnames=None, test_stats_collector=None, copy_inputs=True):
        minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                      output_fieldnames,
                                                                      test_stats_collector)
        virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                      minibatchwise_use_function,
                                                      True, DataSet.numpy_vstack,
                                                      DataSet.numpy_hstack)
        # actually force the computation
        output_dataset = CachedDataSet(virtual_output_dataset,True)
        if copy_inputs:
            output_dataset = input_dataset | output_dataset
        # compute the attributes that should be copied in the dataset
        output_dataset.setAttributes(self.attributeNames(),self.attributes(return_copy=True))
        if test_stats_collector:
            test_stats_collector.update(output_dataset)
            for attribute in test_stats_collector.attributeNames():
                output_dataset[attribute] = copy.deepcopy(test_stats_collector[attribute])
        return output_dataset

    def update(self,training_set,train_stats_collector=None):
        self.update_start()
        for minibatch in training_set.minibatches(self.training_set_input_fields, minibatch_size=self.minibatch_size):
            self.update_minibatch(minibatch)
            if train_stats_collector:
                minibatch_set = minibatch.examples()
                minibatch_set.setAttributes(self.attributeNames(),self.attributes())
                train_stats_collector.update(minibatch_set)
        self.update_end()
        # the "model" resulting from training is the (bound) use method itself
        return self.use
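    # Note: update() may be called several times with successive training sets;
    # XtX and XtY accumulate across calls, so the resulting theta is intended to
    # be the exact solution for all the data seen since the last forget().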
    
    def __init__(self, lambda_=0., max_memory_use=500):
        """
        @type lambda_: float
        @param lambda_: regularization coefficient ('lambda' itself is a Python keyword)
        """
        W = t.matrix('W')
        # b is a broadcastable row vector (can be replicated into
        # as many rows as there are examples in the minibatch)
        b = t.row('b')
        minibatch_input = t.matrix('input')    # n_examples x n_inputs
        minibatch_target = t.matrix('target')  # n_examples x n_outputs
        minibatch_output = t.dot(minibatch_input, W.T) + b  # n_examples x n_outputs
        self.lambda_ = as_scalar(lambda_, 'lambda')
        regularizer = self.lambda_ * t.sum(t.sqr(W))  # sum of squared weights
        example_squared_error = t.sum_within_rows(t.sqr(minibatch_output - minibatch_target))
        self.output_function = Function([W, b, minibatch_input], [minibatch_output])
        self.squared_error_function = Function([minibatch_output, minibatch_target], [example_squared_error])
        self.loss_function = Function([W, example_squared_error], [regularizer + t.sum(example_squared_error)])
        self.use_functions_dictionary = {}  # cache used by minibatchwise_use_functions
        # TODO: derive self.minibatch_size and self.training_set_input_fields
        # (used by update) from max_memory_use
        self.W = None
        self.b = None
        self.XtX = None
        self.XtY = None
        
    def forget(self):
        # reset the accumulated statistics, undoing all previous updates
        if self.W is not None:
            self.XtX *= 0
            self.XtY *= 0

    # second, work-in-progress version of use (this revision is in the middle
    # of redesigning the TLearner interface)
    def use(self, input_dataset, output_fieldnames=None, copy_inputs=True):
        input_fieldnames = input_dataset.fieldNames()
        assert "input" in input_fieldnames
        if not output_fieldnames:
            output_fieldnames = ["output"]
            if "target" in input_fieldnames:
                output_fieldnames += ["squared_error"]
        else:
            if "squared_error" in output_fieldnames or "total_loss" in output_fieldnames:
                assert "target" in input_fieldnames

        use_functions = []
        for output_fieldname in output_fieldnames:
            if output_fieldname == "output":
                use_functions.append(self.output_function)
            elif output_fieldname == "squared_error":
                # presumably meant to compose the two compiled functions:
                # compute the output first, then the example-wise squared error
                use_functions.append(lambda inputs, targets:
                                     self.squared_error_function(
                                         self.output_function(self.W, self.b, inputs), targets))
    
        n_examples = len(input_dataset)

        for minibatch in input_dataset.minibatches(minibatch_size=n_examples,
                                                   allow_odd_last_minibatch=True):
            # TODO: apply each of the use_functions to the minibatch and assemble
            # the results into an output dataset (left unfinished in this revision)
            pass