diff linear_regression.py @ 75:90e4c0784d6e

Added draft of LinearRegression learner
author bengioy@bengiomac.local
date Sat, 03 May 2008 21:59:26 -0400
parents
children 1e2bb5bad636
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/linear_regression.py	Sat May 03 21:59:26 2008 -0400
@@ -0,0 +1,109 @@
+
+from learner import *
+from theano import tensor as t
+from compile import Function
+from theano.scalar import as_scalar
+
+# this is one of the simplest example of learner, and illustrates
+# the use of theano 
+class LinearRegression(Learner):
+    """
+    Implement linear regression, with or without L2 regularization
+    (the former is called Ridge Regression and the latter Ordinary Least Squares).
+
+    The predictor is obtained analytically.
+
+    The L2 regularization coefficient is obtained analytically.
+    For each (input[t],output[t]) pair in a minibatch,::
+    
+       output_t = b + W * input_t
+
+    where b and W are obtained by minimizing::
+
+       lambda sum_{ij} W_{ij}^2  + sum_t ||output_t - target_t||^2
+
+    Let X be the whole training set inputs matrix (one input example per row),
+    with the first column full of 1's, and Let Y the whole training set
+    targets matrix (one example's target vector per row).
+    Let theta = the matrix with b in its first column and W in the others,
+    then each theta[:,i] is the solution of the linear system::
+
+       XtX * theta[:,i] = XtY[:,i]
+
+    where XtX is a (n_inputs+1)x(n_inputs+1) matrix containing X'*X
+    plus lambda on the diagonal except at (0,0),
+    and XtY is a (n_inputs+1)*n_outputs matrix containing X'*Y.
+
+    The fields and attributes expected and produced by use and update are the following:
+
+     - Input and output fields (example-wise quantities):
+
+       - 'input' (always expected by use and update as an input_dataset field)
+       - 'target' (optionally expected by use and update as an input_dataset field)
+       - 'output' (optionally produced by use as an output dataset field)
+       - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error
+
+     - optional input attributes (optionally expected as input_dataset attributes)
+
+       - 'lambda' (only used by update)
+       - 'b' (only used by use)
+       - 'W' (only used by use)
+
+     - optional output attributes (available in self and optionally in output dataset)
+
+       - 'b' (only set by update)
+       - 'W' (only set by update)
+       - 'total_squared_error' (set by use and by update) = sum over examples of example_wise_squared_error 
+       - 'total_loss' (set by use and by update) = regularizer + total_squared_error
+       - 'XtX' (only set by update)
+       - 'XtY' (only set by update)
+       
+    """
+
+    def __init__(self,lambda=0.):
+        """
+        @type lambda: float
+        @param lambda: regularization coefficient
+        """
+        
+        W=t.matrix('W')
+        # b is a broadcastable row vector (can be replicated into
+        # as many rows as there are examples in the minibach)
+        b=t.row('b')
+        minibatch_input = t.matrix('input') # n_examples x n_inputs
+        minibatch_target = t.matrix('target') # n_examples x n_outputs
+        minibatch_output = t.dot(minibatch_input,W.T) + b  # n_examples x n_outputs
+        lambda = as_scalar(lambda)
+        regularizer = self.lambda * t.dot(W,W)
+        example_squared_error = t.sum_within_rows(t.sqr(minibatch_output-minibatch_target))
+        self.output_function = Function([W,b,minibatch_input],[minibatch_output])
+        self.squared_error_function = Function([minibatch_output,minibatch_target],[self.example_squared_error])
+        self.loss_function = Function([W,squared_error],[self.regularizer + t.sum(self.example_squared_error)])
+        self.W=None
+        self.b=None
+        self.XtX=None
+        self.XtY=None
+        
+    def forget(self):
+        if self.W:
+            self.XtX *= 0
+            self.XtY *= 0
+
+    def use(self,input_dataset,output_fieldnames=None,copy_inputs=True):
+        input_fieldnames = input_dataset.fieldNames()
+        assert "input" in input_fieldnames
+        if not output_fields:
+            output_fields = ["output"]
+            if "target" in input_fieldnames:
+                output_fields += ["squared_error"]
+        else:
+            if "squared_error" in output_fields or "total_loss" in output_fields:
+                assert "target" in input_fieldnames
+
+        use_functions = []
+        for output_fieldname in output_fieldnames:
+            if output_fieldname=="output":
+                use_functions.append(self.output_function)
+            elif output_fieldname=="squared_error":
+                use_functions.append(lambda self.output_function)
+