Mercurial > pylearn
comparison pylearn/algorithms/linear_regression.py @ 1505:723e2d761985
auto white space fix.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Mon, 12 Sep 2011 10:49:15 -0400 |
parents | bf5c0f797161 |
children |
comparison
equal
deleted
inserted
replaced
1504:bf5c0f797161 | 1505:723e2d761985 |
---|---|
24 all_results_dataset=linear_predictor(test_set) # creates a dataset with "output" and "squared_error" fields | 24 all_results_dataset=linear_predictor(test_set) # creates a dataset with "output" and "squared_error" fields |
25 outputs = linear_predictor.compute_outputs(inputs) # inputs and outputs are numpy arrays | 25 outputs = linear_predictor.compute_outputs(inputs) # inputs and outputs are numpy arrays |
26 outputs, errors = linear_predictor.compute_outputs_and_errors(inputs,targets) | 26 outputs, errors = linear_predictor.compute_outputs_and_errors(inputs,targets) |
27 errors = linear_predictor.compute_errors(inputs,targets) | 27 errors = linear_predictor.compute_errors(inputs,targets) |
28 mse = linear_predictor.compute_mse(inputs,targets) | 28 mse = linear_predictor.compute_mse(inputs,targets) |
29 | 29 |
30 | 30 |
31 | 31 |
32 The training_set must have fields "input" and "target". | 32 The training_set must have fields "input" and "target". |
33 The test_set must have an "input" field, and also needs a "target" field if | 33 The test_set must have an "input" field, and also needs a "target" field if |
34 we want to compute the squared errors. | 34 we want to compute the squared errors. |
35 | 35 |
36 The predictor parameters are obtained analytically from the training set. | 36 The predictor parameters are obtained analytically from the training set. |
37 | 37 |
38 For each (input[t], output[t]) pair in a minibatch:: | 38 For each (input[t], output[t]) pair in a minibatch:: |
39 | 39 |
40 output_t = b + W * input_t | 40 output_t = b + W * input_t |
41 | 41 |
42 where b and W are obtained by minimizing:: | 42 where b and W are obtained by minimizing:: |
43 | 43 |
44 L2_regularizer * sum_{ij} W_{ij}^2 + sum_t ||output_t - target_t||^2 | 44 L2_regularizer * sum_{ij} W_{ij}^2 + sum_t ||output_t - target_t||^2 |
107 | 107 |
108 cls.__compiled = True | 108 cls.__compiled = True |
109 | 109 |
110 def __init__(self): | 110 def __init__(self): |
111 self.compile() | 111 self.compile() |
112 | 112 |
113 class LinearRegressionEquations(LinearPredictorEquations): | 113 class LinearRegressionEquations(LinearPredictorEquations): |
114 P = LinearPredictorEquations | 114 P = LinearPredictorEquations |
115 XtX = T.matrix() # (n_inputs+1) x (n_inputs+1) | 115 XtX = T.matrix() # (n_inputs+1) x (n_inputs+1) |
116 XtY = T.matrix() # (n_inputs+1) x n_outputs | 116 XtY = T.matrix() # (n_inputs+1) x n_outputs |
117 extended_input = prepend_1_to_each_row(P.inputs) | 117 extended_input = prepend_1_to_each_row(P.inputs) |
118 new_XtX = T.add(XtX,T.dot(extended_input.T,extended_input)) | 118 new_XtX = T.add(XtX,T.dot(extended_input.T,extended_input)) |
119 new_XtY = T.add(XtY,T.dot(extended_input.T,P.targets)) | 119 new_XtY = T.add(XtY,T.dot(extended_input.T,P.targets)) |
120 | 120 |
121 __compiled = False | 121 __compiled = False |
122 | 122 |
123 @classmethod | 123 @classmethod |
124 def compile(cls, mode="FAST_RUN"): | 124 def compile(cls, mode="FAST_RUN"): |
125 if cls.__compiled: | 125 if cls.__compiled: |
126 return | 126 return |
127 def fn(input_vars,output_vars): | 127 def fn(input_vars,output_vars): |
154 outputs = self.compute_outputs(inputs) | 154 outputs = self.compute_outputs(inputs) |
155 return [outputs,self.equations.compute_errors(outputs,targets)] | 155 return [outputs,self.equations.compute_errors(outputs,targets)] |
156 def compute_mse(self,inputs,targets): | 156 def compute_mse(self,inputs,targets): |
157 errors = self.compute_errors(inputs,targets) | 157 errors = self.compute_errors(inputs,targets) |
158 return numpy.sum(errors)/errors.size | 158 return numpy.sum(errors)/errors.size |
159 | 159 |
160 def __call__(self,dataset,output_fieldnames=None,cached_output_dataset=False): | 160 def __call__(self,dataset,output_fieldnames=None,cached_output_dataset=False): |
161 assert dataset.hasFields(["input"]) | 161 assert dataset.hasFields(["input"]) |
162 if output_fieldnames is None: | 162 if output_fieldnames is None: |
163 if dataset.hasFields(["target"]): | 163 if dataset.hasFields(["target"]): |
164 output_fieldnames = ["output","squared_error"] | 164 output_fieldnames = ["output","squared_error"] |
171 f = self.compute_outputs | 171 f = self.compute_outputs |
172 elif output_fieldnames == ["output","squared_error"]: | 172 elif output_fieldnames == ["output","squared_error"]: |
173 f = self.compute_outputs_and_errors | 173 f = self.compute_outputs_and_errors |
174 else: | 174 else: |
175 raise ValueError("unknown field(s) in output_fieldnames: "+str(output_fieldnames)) | 175 raise ValueError("unknown field(s) in output_fieldnames: "+str(output_fieldnames)) |
176 | 176 |
177 ds=ApplyFunctionDataSet(dataset,f,output_fieldnames) | 177 ds=ApplyFunctionDataSet(dataset,f,output_fieldnames) |
178 if cached_output_dataset: | 178 if cached_output_dataset: |
179 return CachedDataSet(ds) | 179 return CachedDataSet(ds) |
180 else: | 180 else: |
181 return ds | 181 return ds |
182 | 182 |
183 | 183 |
184 def linear_predictor(inputs,params,*otherargs): | 184 def linear_predictor(inputs,params,*otherargs): |
185 p = LinearPredictor(params) | 185 p = LinearPredictor(params) |
186 return p.compute_outputs(inputs) | 186 return p.compute_outputs(inputs) |
187 | 187 |
188 # TODO: implement an online version | 188 # TODO: implement an online version |
189 class OnlineLinearRegression():#OnlineLearningAlgorithm): | 189 class OnlineLinearRegression():#OnlineLearningAlgorithm): |
190 """ | 190 """ |
191 Training can proceed sequentially (with multiple calls to update with | 191 Training can proceed sequentially (with multiple calls to update with |
193 update the predictor is ready to be used (and optimized for the union | 193 update the predictor is ready to be used (and optimized for the union |
194 of all the training sets passed to update since construction or since | 194 of all the training sets passed to update since construction or since |
195 the last call to forget). | 195 the last call to forget). |
196 """ | 196 """ |
197 pass | 197 pass |
198 | |
199 | |
200 | |
201 |