comparison linear_regression.py @ 107:c4916445e025

Comments from Pascal V.
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 06 May 2008 19:54:43 -0400
parents c4726e19b8ec
children 8fa1ef2411a0
comparing 97:05cfe011ca20 with 107:c4916445e025

--- linear_regression.py (97:05cfe011ca20)
+++ linear_regression.py (107:c4916445e025)
@@ -4,11 +4,11 @@
 from compile import Function
 from theano.scalar import as_scalar

 # this is one of the simplest example of learner, and illustrates
 # the use of theano
-class LinearRegression(Learner):
+class LinearRegression(OneShotTLearner):
     """
     Implement linear regression, with or without L2 regularization
     (the former is called Ridge Regression and the latter Ordinary Least Squares).

     The predictor parameters are obtained analytically from the training set.
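For reference, the analytic solution the docstring refers to is the ridge regression / ordinary least squares one (lambda = 0 gives OLS). Writing X_e for the input matrix with a column of ones prepended (the extended_input built further down), Y for the (n_examples, n_outputs) target matrix, and theta = [b W] for the parameters, the normal equations give, as a sketch of the math only:

    \theta^\top = \left( X_e^\top X_e + \lambda \tilde{I} \right)^{-1} X_e^\top Y

where \tilde{I} is the identity with its (0,0) entry zeroed, so that, as in initialize() below, the bias column is left unregularized.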
@@ -46,19 +46,16 @@
     - 'input' (always expected by use and update as an input_dataset field)
     - 'target' (optionally expected by use and update as an input_dataset field)
     - 'output' (optionally produced by use as an output dataset field)
     - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error

-    - optional input attributes (optionally expected as input_dataset attributes)
-
     - optional attributes (optionally expected as input_dataset attributes)
       (warning, this may be dangerous, the 'use' method will use those provided in the
       input_dataset rather than those learned during 'update'; currently no support
       for providing these to update):

       - 'lambda'
       - 'b'
       - 'W'
       - 'regularization_term'
-      - 'XtX'
-      - 'XtY'
+
     """
@@ -65,6 +62,5 @@

     def attributeNames(self):
-        return ["lambda","b","W","regularization_term","XtX","XtY"]
+        return ["lambda","b","W","regularization_term"]

-    # definitions specific to linear regression:

@@ -71,13 +67,16 @@
-
-    def global_inputs(self):
+    def __init__(self):
+        self.input = t.matrix('input') # n_examples x n_inputs
+        self.target = t.matrix('target') # n_examples x n_outputs
         self.lambda = as_scalar(0.,'lambda')
         self.theta = t.matrix('theta')
         self.W = self.theta[:,1:]
         self.b = self.theta[:,0]
         self.XtX = t.matrix('XtX')
         self.XtY = t.matrix('XtY')
-
-    def global_outputs(self):
         self.regularizer = self.lambda * t.dot(self.W,self.W)
+        self.squared_error =
         self.loss = self.regularizer + t.sum(self.squared_error) # this only makes sense if the whole training set fits in memory in a minibatch
         self.loss_function = Function([self.W,self.lambda,self.squared_error],[self.loss])
+        self.new_XtX = self.XtX + t.dot(self.extended_input.T,self.extended_input)
+        self.new_XtY = self.XtY + t.dot(self.extended_input.T,self.target)
+        self.new_theta = t.solve(self.XtX,self.XtY)
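The reworked __init__ builds the whole estimator as one symbolic graph: XtX and XtY accumulate the sufficient statistics of the training set and new_theta solves the regularized normal equations (the bare `self.squared_error =` on new line 77 appears incomplete as shown; the per-example squared error is defined in minibatch_wise_outputs further down). A minimal plain-NumPy sketch of the same update, with hypothetical names and no Theano, to make the arithmetic concrete:

    import numpy as np

    # Hypothetical standalone sketch of the XtX/XtY accumulation and the
    # analytic solve; not the pylearn Learner API.
    def init_stats(n_inputs, n_outputs, lam):
        XtX = np.zeros((1 + n_inputs, 1 + n_inputs))
        XtY = np.zeros((1 + n_inputs, n_outputs))
        # L2 penalty on the weights only: the (0,0) bias entry stays at zero
        np.fill_diagonal(XtX[1:, 1:], lam)
        return XtX, XtY

    def update_stats(XtX, XtY, X, Y):
        # "extended input": prepend a column of ones so the bias is part of theta
        Xe = np.hstack([np.ones((X.shape[0], 1)), X])
        XtX += Xe.T.dot(Xe)
        XtY += Xe.T.dot(Y)
        # solve the normal equations; theta is (n_outputs, 1 + n_inputs)
        theta = np.linalg.solve(XtX, XtY).T
        b, W = theta[:, 0], theta[:, 1:]
        return theta, b, W

Accumulating X_e^T X_e and X_e^T Y instead of keeping the examples around is what makes it possible to recompute the parameters analytically after every minibatch.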
@@ -84,22 +83,16 @@

     def initialize(self):
         self.XtX.resize((1+self.n_inputs,1+self.n_inputs))
         self.XtY.resize((1+self.n_inputs,self.n_outputs))
         self.XtX.data[:,:]=0
         self.XtY.data[:,:]=0
         numpy.diag(self.XtX.data)[1:]=self.lambda.data

     def updated_variables(self):
-        self.new_XtX = self.XtX + t.dot(self.extended_input.T,self.extended_input)
-        self.new_XtY = self.XtY + t.dot(self.extended_input.T,self.target)
-        self.new_theta = t.solve(self.XtX,self.XtY)

     def minibatch_wise_inputs(self):
-        self.input = t.matrix('input') # n_examples x n_inputs
-        self.target = t.matrix('target') # n_examples x n_outputs
-
     def minibatch_wise_outputs(self):
         # self.input is a (n_examples, n_inputs) minibatch matrix
         self.extended_input = t.prepend_one_to_each_row(self.input)
         self.output = t.dot(self.input,self.W.T) + self.b # (n_examples , n_outputs) matrix
         self.squared_error = t.sum_within_rows(t.sqr(self.output-self.target)) # (n_examples ) vector
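For completeness, the same kind of sketch for the use-time computation in minibatch_wise_outputs, again plain NumPy with hypothetical names rather than the symbolic graph above:

    import numpy as np

    def predict(X, W, b):
        # (n_examples, n_outputs): X W^T with the bias broadcast across rows
        return X.dot(W.T) + b

    def example_wise_squared_error(output, target):
        # one squared error per example: sum of squares within each row
        return np.sum((output - target) ** 2, axis=1)

The training loss assembled in __init__ is then the sum of these per-example errors plus the lambda-weighted penalty on W.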