Mercurial > pylearn
diff linear_regression.py @ 75:90e4c0784d6e
Added draft of LinearRegression learner
author | bengioy@bengiomac.local |
---|---|
date | Sat, 03 May 2008 21:59:26 -0400 |
parents | |
children | 1e2bb5bad636 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/linear_regression.py Sat May 03 21:59:26 2008 -0400 @@ -0,0 +1,109 @@ + +from learner import * +from theano import tensor as t +from compile import Function +from theano.scalar import as_scalar + +# this is one of the simplest example of learner, and illustrates +# the use of theano +class LinearRegression(Learner): + """ + Implement linear regression, with or without L2 regularization + (the former is called Ridge Regression and the latter Ordinary Least Squares). + + The predictor is obtained analytically. + + The L2 regularization coefficient is obtained analytically. + For each (input[t],output[t]) pair in a minibatch,:: + + output_t = b + W * input_t + + where b and W are obtained by minimizing:: + + lambda sum_{ij} W_{ij}^2 + sum_t ||output_t - target_t||^2 + + Let X be the whole training set inputs matrix (one input example per row), + with the first column full of 1's, and Let Y the whole training set + targets matrix (one example's target vector per row). + Let theta = the matrix with b in its first column and W in the others, + then each theta[:,i] is the solution of the linear system:: + + XtX * theta[:,i] = XtY[:,i] + + where XtX is a (n_inputs+1)x(n_inputs+1) matrix containing X'*X + plus lambda on the diagonal except at (0,0), + and XtY is a (n_inputs+1)*n_outputs matrix containing X'*Y. + + The fields and attributes expected and produced by use and update are the following: + + - Input and output fields (example-wise quantities): + + - 'input' (always expected by use and update as an input_dataset field) + - 'target' (optionally expected by use and update as an input_dataset field) + - 'output' (optionally produced by use as an output dataset field) + - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error + + - optional input attributes (optionally expected as input_dataset attributes) + + - 'lambda' (only used by update) + - 'b' (only used by use) + - 'W' (only used by use) + + - optional output attributes (available in self and optionally in output dataset) + + - 'b' (only set by update) + - 'W' (only set by update) + - 'total_squared_error' (set by use and by update) = sum over examples of example_wise_squared_error + - 'total_loss' (set by use and by update) = regularizer + total_squared_error + - 'XtX' (only set by update) + - 'XtY' (only set by update) + + """ + + def __init__(self,lambda=0.): + """ + @type lambda: float + @param lambda: regularization coefficient + """ + + W=t.matrix('W') + # b is a broadcastable row vector (can be replicated into + # as many rows as there are examples in the minibach) + b=t.row('b') + minibatch_input = t.matrix('input') # n_examples x n_inputs + minibatch_target = t.matrix('target') # n_examples x n_outputs + minibatch_output = t.dot(minibatch_input,W.T) + b # n_examples x n_outputs + lambda = as_scalar(lambda) + regularizer = self.lambda * t.dot(W,W) + example_squared_error = t.sum_within_rows(t.sqr(minibatch_output-minibatch_target)) + self.output_function = Function([W,b,minibatch_input],[minibatch_output]) + self.squared_error_function = Function([minibatch_output,minibatch_target],[self.example_squared_error]) + self.loss_function = Function([W,squared_error],[self.regularizer + t.sum(self.example_squared_error)]) + self.W=None + self.b=None + self.XtX=None + self.XtY=None + + def forget(self): + if self.W: + self.XtX *= 0 + self.XtY *= 0 + + def use(self,input_dataset,output_fieldnames=None,copy_inputs=True): + input_fieldnames = input_dataset.fieldNames() + assert "input" in input_fieldnames + if not output_fields: + output_fields = ["output"] + if "target" in input_fieldnames: + output_fields += ["squared_error"] + else: + if "squared_error" in output_fields or "total_loss" in output_fields: + assert "target" in input_fieldnames + + use_functions = [] + for output_fieldname in output_fieldnames: + if output_fieldname=="output": + use_functions.append(self.output_function) + elif output_fieldname=="squared_error": + use_functions.append(lambda self.output_function) +