Mercurial > pylearn
diff pylearn/algorithms/linear_regression.py @ 1505:723e2d761985
auto white space fix.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Mon, 12 Sep 2011 10:49:15 -0400 |
parents | bf5c0f797161 |
children |
line wrap: on
line diff
--- a/pylearn/algorithms/linear_regression.py Mon Sep 12 10:48:33 2011 -0400 +++ b/pylearn/algorithms/linear_regression.py Mon Sep 12 10:49:15 2011 -0400 @@ -26,8 +26,8 @@ outputs, errors = linear_predictor.compute_outputs_and_errors(inputs,targets) errors = linear_predictor.compute_errors(inputs,targets) mse = linear_predictor.compute_mse(inputs,targets) - - + + The training_set must have fields "input" and "target". The test_set must have field "input", and needs "target" if @@ -36,7 +36,7 @@ The predictor parameters are obtained analytically from the training set. For each (input[t],output[t]) pair in a minibatch,:: - + output_t = b + W * input_t where b and W are obtained by minimizing:: @@ -109,7 +109,7 @@ def __init__(self): self.compile() - + class LinearRegressionEquations(LinearPredictorEquations): P = LinearPredictorEquations XtX = T.matrix() # (n_inputs+1) x (n_inputs+1) @@ -119,7 +119,7 @@ new_XtY = T.add(XtY,T.dot(extended_input.T,P.targets)) __compiled = False - + @classmethod def compile(cls, mode="FAST_RUN"): if cls.__compiled: @@ -156,7 +156,7 @@ def compute_mse(self,inputs,targets): errors = self.compute_errors(inputs,targets) return numpy.sum(errors)/errors.size - + def __call__(self,dataset,output_fieldnames=None,cached_output_dataset=False): assert dataset.hasFields(["input"]) if output_fieldnames is None: @@ -173,17 +173,17 @@ f = self.compute_outputs_and_errors else: raise ValueError("unknown field(s) in output_fieldnames: "+str(output_fieldnames)) - + ds=ApplyFunctionDataSet(dataset,f,output_fieldnames) if cached_output_dataset: return CachedDataSet(ds) else: return ds - + def linear_predictor(inputs,params,*otherargs): - p = LinearPredictor(params) - return p.compute_outputs(inputs) + p = LinearPredictor(params) + return p.compute_outputs(inputs) #TODO : an online version class OnlineLinearRegression():#OnlineLearningAlgorithm): @@ -195,7 +195,3 @@ the last call to forget). """ pass - - - -