Mercurial > pylearn
diff linear_regression.py @ 435:eac0a7d44ff0
merge
author | Olivier Breuleux <breuleuo@iro.umontreal.ca> |
---|---|
date | Mon, 04 Aug 2008 16:29:30 -0400 |
parents | 8e4d2ebd816a |
children | 111e547ffa7b |
line wrap: on
line diff
--- a/linear_regression.py Mon Aug 04 16:21:59 2008 -0400 +++ b/linear_regression.py Mon Aug 04 16:29:30 2008 -0400 @@ -4,9 +4,9 @@ the use of theano. """ -from pylearn.learner import OfflineLearningAlgorithm +from pylearn.learner import OfflineLearningAlgorithm,OnlineLearningAlgorithm from theano import tensor as T -from theano.others_ops import prepend_1_to_each_row +from nnet_ops import prepend_1_to_each_row from theano.scalar import as_scalar from common.autoname import AutoName import theano @@ -34,11 +34,6 @@ we want to compute the squared errors. The predictor parameters are obtained analytically from the training set. - Training can proceed sequentially (with multiple calls to update with - different disjoint subsets of the training sets). After each call to - update the predictor is ready to be used (and optimized for the union - of all the training sets passed to update since construction or since - the last call to forget). For each (input[t],output[t]) pair in a minibatch,:: @@ -74,7 +69,7 @@ def __init__(self, L2_regularizer=0,minibatch_size=10000): self.L2_regularizer=L2_regularizer self.equations = LinearRegressionEquations() - self.minibatch_size=1000 + self.minibatch_size=minibatch_size def __call__(self,trainset): first_example = trainset[0] @@ -186,3 +181,21 @@ return ds +def linear_predictor(inputs,params,*otherargs): + p = LinearPredictor(params) + return p.compute_outputs(inputs) + +#TODO : an online version +class OnlineLinearRegression(OnlineLearningAlgorithm): + """ + Training can proceed sequentially (with multiple calls to update with + different disjoint subsets of the training sets). After each call to + update the predictor is ready to be used (and optimized for the union + of all the training sets passed to update since construction or since + the last call to forget). + """ + pass + + + +