comparison linear_regression.py @ 75:90e4c0784d6e

Added draft of LinearRegression learner
author bengioy@bengiomac.local
date Sat, 03 May 2008 21:59:26 -0400
parents
children 1e2bb5bad636
comparison
equal deleted inserted replaced
73:69f97aad3faf 75:90e4c0784d6e
1
2 from learner import *
3 from theano import tensor as t
4 from compile import Function
5 from theano.scalar import as_scalar
6
7 # this is one of the simplest examples of a learner, and it illustrates
8 # the use of theano
9 class LinearRegression(Learner):
10 """
11 Implement linear regression, with or without L2 regularization
12 (the former is called Ridge Regression and the latter Ordinary Least Squares).
13
14 The predictor is obtained analytically.
15
16 The L2 regularization coefficient (lambda) is not learned: it is given by the user.
17 For each (input[t],output[t]) pair in a minibatch,::
18
19 output_t = b + W * input_t
20
21 where b and W are obtained by minimizing::
22
23 lambda sum_{ij} W_{ij}^2 + sum_t ||output_t - target_t||^2
24
25 Let X be the whole training set inputs matrix (one input example per row),
26 with the first column full of 1's, and let Y be the whole training set
27 targets matrix (one example's target vector per row).
28 Let theta = the n_outputs x (n_inputs+1) matrix with b in its first column and W in the others,
29 then each row theta[i,:] (taken as a column vector) is the solution of the linear system::
30
31 XtX * theta[i,:]' = XtY[:,i]
32
33 where XtX is a (n_inputs+1)x(n_inputs+1) matrix containing X'*X
34 plus lambda on the diagonal except at (0,0),
35 and XtY is a (n_inputs+1)x(n_outputs) matrix containing X'*Y.
36
37 The fields and attributes expected and produced by use and update are the following:
38
39 - Input and output fields (example-wise quantities):
40
41 - 'input' (always expected by use and update as an input_dataset field)
42 - 'target' (optionally expected by use and update as an input_dataset field)
43 - 'output' (optionally produced by use as an output dataset field)
44 - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error
45
46 - optional input attributes (optionally expected as input_dataset attributes)
47
48 - 'lambda' (only used by update)
49 - 'b' (only used by use)
50 - 'W' (only used by use)
51
52 - optional output attributes (available in self and optionally in output dataset)
53
54 - 'b' (only set by update)
55 - 'W' (only set by update)
56 - 'total_squared_error' (set by use and by update) = sum over examples of example_wise_squared_error
57 - 'total_loss' (set by use and by update) = regularizer + total_squared_error
58 - 'XtX' (only set by update)
59 - 'XtY' (only set by update)
60
61 """
62
def __init__(self, lambda_=0.):
    """
    Build the symbolic expressions of the model and compile the functions
    that compute the output, the example-wise squared error and the loss.

    The argument is named C{lambda_} because C{lambda} is a reserved word
    in Python and cannot be used as a parameter name; callers passing the
    coefficient positionally are unaffected.

    @type lambda_: float
    @param lambda_: L2 regularization coefficient
    """
    # Keep the raw coefficient on the instance (the draft read it back as an
    # attribute that was never assigned).
    # NOTE(review): presumably update() will need it -- confirm once written.
    self.lambda_ = lambda_

    W = t.matrix('W')
    # b is a broadcastable row vector (can be replicated into
    # as many rows as there are examples in the minibatch)
    b = t.row('b')
    minibatch_input = t.matrix('input')  # n_examples x n_inputs
    minibatch_target = t.matrix('target')  # n_examples x n_outputs
    minibatch_output = t.dot(minibatch_input, W.T) + b  # n_examples x n_outputs

    lambda_scalar = as_scalar(lambda_)
    # The documented penalty is lambda * sum_{ij} W_{ij}^2, i.e. the sum of
    # the squared entries of W (not the matrix product of W with itself).
    regularizer = lambda_scalar * t.sum(t.sqr(W))
    example_squared_error = t.sum_within_rows(
        t.sqr(minibatch_output - minibatch_target))

    self.output_function = Function([W, b, minibatch_input],
                                    [minibatch_output])
    self.squared_error_function = Function(
        [minibatch_output, minibatch_target], [example_squared_error])
    # The loss takes the example-wise squared errors as an input, so the same
    # symbolic variable must appear in both the input list and the expression.
    self.loss_function = Function(
        [W, example_squared_error],
        [regularizer + t.sum(example_squared_error)])

    # Parameters and sufficient statistics; all unset until training.
    self.W = None
    self.b = None
    self.XtX = None
    self.XtY = None
86
def forget(self):
    """
    Reset the accumulated statistics XtX and XtY to zero, in place.

    Nothing is done for a statistic that has not been allocated yet
    (still None). W and b are left untouched, as in the original draft.
    """
    # Compare against None explicitly: the truth value of a matrix is
    # ambiguous (NumPy raises for arrays with more than one element), and the
    # objects being reset are the accumulators themselves, not W.
    if self.XtX is not None:
        self.XtX *= 0
    if self.XtY is not None:
        self.XtY *= 0
91
92 def use(self,input_dataset,output_fieldnames=None,copy_inputs=True):
93 input_fieldnames = input_dataset.fieldNames()
94 assert "input" in input_fieldnames
95 if not output_fields:
96 output_fields = ["output"]
97 if "target" in input_fieldnames:
98 output_fields += ["squared_error"]
99 else:
100 if "squared_error" in output_fields or "total_loss" in output_fields:
101 assert "target" in input_fieldnames
102
103 use_functions = []
104 for output_fieldname in output_fieldnames:
105 if output_fieldname=="output":
106 use_functions.append(self.output_function)
107 elif output_fieldname=="squared_error":
108 use_functions.append(lambda self.output_function)
109