pylearn: linear_regression.py @ 75:90e4c0784d6e
Added draft of LinearRegression learner
author      bengioy@bengiomac.local
date        Sat, 03 May 2008 21:59:26 -0400
parents
children    1e2bb5bad636
comparison: 73:69f97aad3faf → 75:90e4c0784d6e

from learner import *
from theano import tensor as t
from compile import Function
from theano.scalar import as_scalar

# This is one of the simplest examples of a learner, and it illustrates
# the use of theano.
class LinearRegression(Learner):
    """
    Implement linear regression, with or without L2 regularization
    (the former is called Ridge Regression and the latter Ordinary Least Squares).

    The predictor parameters (b and W below) are obtained analytically from the
    training set; the L2 regularization coefficient lambda is a hyperparameter,
    given to the constructor or as a dataset attribute.
    For each (input_t, output_t) pair in a minibatch::

       output_t = b + W * input_t

    where b and W are obtained by minimizing::

       lambda * sum_{ij} W_{ij}^2  +  sum_t ||output_t - target_t||^2

    Let X be the whole training set inputs matrix (one input example per row),
    with the first column full of 1's, and let Y be the whole training set
    targets matrix (one example's target vector per row).
    Let theta be the matrix whose first row is b and whose remaining rows hold
    W' (the transpose of W), so that X * theta approximates Y; then each
    theta[:,i] is the solution of the linear system::

       XtX * theta[:,i] = XtY[:,i]

    where XtX is a (n_inputs+1)x(n_inputs+1) matrix containing X'*X
    plus lambda on the diagonal except at (0,0),
    and XtY is a (n_inputs+1)xn_outputs matrix containing X'*Y.

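    For illustration only (a hypothetical sketch assuming numpy, not part of
    this changeset), given dense arrays XtX and XtY the solution could be
    recovered as::

       import numpy
       theta = numpy.linalg.solve(XtX, XtY)  # solves all columns at once
       b = theta[0, :]                       # first row of theta
       W = theta[1:, :].T                    # remaining rows, transposed
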
    The fields and attributes expected and produced by use and update are the following:

    - Input and output fields (example-wise quantities):

      - 'input' (always expected by use and update as an input_dataset field)
      - 'target' (optionally expected by use and update as an input_dataset field)
      - 'output' (optionally produced by use as an output dataset field)
      - 'squared_error' (optionally produced by use as an output dataset field,
        needs 'target') = example-wise squared error

    - Optional input attributes (optionally expected as input_dataset attributes):

      - 'lambda' (only used by update)
      - 'b' (only used by use)
      - 'W' (only used by use)

    - Optional output attributes (available in self and optionally in the output dataset):

      - 'b' (only set by update)
      - 'W' (only set by update)
      - 'total_squared_error' (set by use and by update) = sum over examples of the example-wise squared error
      - 'total_loss' (set by use and by update) = regularizer + total_squared_error
      - 'XtX' (only set by update)
      - 'XtY' (only set by update)

    """

    def __init__(self, lambda_coef=0.):
        """
        @type lambda_coef: float
        @param lambda_coef: L2 regularization coefficient (named lambda_coef
                            because 'lambda' is a reserved word in Python)
        """

        W = t.matrix('W')
        # b is a broadcastable row vector (can be replicated into
        # as many rows as there are examples in the minibatch)
        b = t.row('b')
        minibatch_input = t.matrix('input')    # n_examples x n_inputs
        minibatch_target = t.matrix('target')  # n_examples x n_outputs
        minibatch_output = t.dot(minibatch_input, W.T) + b  # n_examples x n_outputs
        self.lambda_coef = as_scalar(lambda_coef)
        # L2 penalty: lambda * sum_{ij} W_{ij}^2, as in the class docstring
        self.regularizer = self.lambda_coef * t.sum(t.sqr(W))
        # example-wise (row-wise) sum of squared prediction errors
        self.example_squared_error = t.sum(t.sqr(minibatch_output - minibatch_target), axis=1)
        self.output_function = Function([W, b, minibatch_input], [minibatch_output])
        self.squared_error_function = Function([minibatch_output, minibatch_target],
                                               [self.example_squared_error])
        self.loss_function = Function([W, self.example_squared_error],
                                      [self.regularizer + t.sum(self.example_squared_error)])
        self.W = None
        self.b = None
        self.XtX = None
        self.XtY = None

    def forget(self):
        if self.W is not None:
            self.XtX *= 0
            self.XtY *= 0

    def use(self, input_dataset, output_fieldnames=None, copy_inputs=True):
        input_fieldnames = input_dataset.fieldNames()
        assert "input" in input_fieldnames
        if not output_fieldnames:
            output_fieldnames = ["output"]
            if "target" in input_fieldnames:
                output_fieldnames += ["squared_error"]
        else:
            if "squared_error" in output_fieldnames or "total_loss" in output_fieldnames:
                assert "target" in input_fieldnames

        use_functions = []
        for output_fieldname in output_fieldnames:
            if output_fieldname == "output":
                use_functions.append(self.output_function)
            elif output_fieldname == "squared_error":
                # squared_error is computed from the predicted output and the target
                use_functions.append(self.squared_error_function)
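
The update step described in the class docstring (accumulating XtX and XtY over
the training data and solving the regularized normal equations) is not included
in this draft changeset. The following is only a rough numpy-based sketch of
that computation; the function name, its plain-array signature, and the use of
numpy.linalg.solve are illustrative assumptions, not part of the pylearn API.

import numpy

def update_from_minibatch(XtX, XtY, inputs, targets, lambda_coef=0.):
    """Hypothetical sketch: accumulate X'X and X'Y for one minibatch and
    return (b, W) from the regularized normal equations XtX * theta = XtY."""
    n_examples, n_inputs = inputs.shape
    # prepend a column of 1's so the bias ends up in the first row of theta
    X = numpy.hstack([numpy.ones((n_examples, 1)), inputs])
    XtX += numpy.dot(X.T, X)
    XtY += numpy.dot(X.T, targets)
    # add lambda on the diagonal except at (0,0), as in the class docstring
    regularized = XtX + lambda_coef * numpy.eye(n_inputs + 1)
    regularized[0, 0] = XtX[0, 0]
    theta = numpy.linalg.solve(regularized, XtY)
    return theta[0, :], theta[1:, :].T  # b and W

A caller would keep XtX (shape (n_inputs+1) x (n_inputs+1)) and XtY (shape
(n_inputs+1) x n_outputs) as zero-initialized accumulators, call this once per
minibatch, and retain the most recent (b, W).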