diff learner.py @ 111:88257dfedf8c
Added another work in progress, for MLPs
| author | bengioy@bengiomac.local |
| --- | --- |
| date | Wed, 07 May 2008 09:16:04 -0400 |
| parents | 8fa1ef2411a0 |
| children | d0a1bd0378c6 |
--- a/learner.py	Tue May 06 22:24:55 2008 -0400
+++ b/learner.py	Wed May 07 09:16:04 2008 -0400
@@ -1,5 +1,6 @@
 from dataset import *
+from compile import Function
 
 
 class Learner(AttributesHolder):
     """Base class for learning algorithms, provides an interface
@@ -84,8 +85,10 @@
         """
         A subset of self.attributeNames() which are the names of attributes
         modified/created by update() in order to do its work.
+
+        By default these are inferred from the various update output attributes:
         """
-        raise AbstractFunction()
+        return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes()
 
     def useOutputAttributes(self):
         """
@@ -251,7 +254,7 @@
         return output_dataset
 
 
-class OneShotTLearner(TLearner):
+class MinibatchUpdatesTLearner(TLearner):
     """
     This adds to TLearner a
       - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch():
@@ -262,6 +265,10 @@
     going only once through the training data. For more complicated models, more
     specialized subclasses of TLearner should be used or a learning-algorithm
     specific update method should be defined.
+
+    - a 'parameters' attribute which is a list of parameters (whose names are
+      specified by the user's subclass with the parameterAttributes() method)
+
     """
 
     def __init__(self):
@@ -288,12 +295,16 @@
     def updateEndOutputAttributes(self):
         raise AbstractFunction()
 
-    def updateStart(self): pass
+    def parameterAttributes(self):
+        raise AbstractFunction()
+
+    def updateStart(self): pass
 
     def updateEnd(self):
         self.setAttributes(self.updateEndOutputAttributes(),
                            self.update_end_function
                            (self.names2attributes(self.updateEndInputAttributes())))
+        self.parameters = self.names2attributes(self.parameterAttributes())
 
     def updateMinibatch(self,minibatch):
         # make sure all required fields are allocated and initialized
@@ -331,3 +342,22 @@
         self.updateEnd()
         return self.use
 
+class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner):
+    """
+    Specialization of MinibatchUpdatesTLearner in which the minibatch updates
+    are obtained by performing an online (minibatch-based) gradient step.
+
+    Sub-classes must define the following methods:
+
+    """
+    def __init__(self,truly_online=False):
+        """
+        If truly_online then only one pass is made through the training set passed to update().
+
+        """
+        self.truly_online=truly_online
+
+    def isLastEpoch(self):
+        return self.truly_online
+
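For readers skimming the diff, the control flow these hooks imply can be sketched in a few lines. The toy class below is purely illustrative: `ToyMinibatchLearner`, its `update()` loop, and the `max_epochs` fallback are hypothetical stand-ins (the real `update()` in learner.py is not shown in the hunks above). It only demonstrates how `updateStart()`, `updateMinibatch()`, `isLastEpoch()`, and `updateEnd()` are meant to interact, and why `OnlineGradientBasedTLearner.isLastEpoch()` returning `truly_online` yields a single pass over the training data.

```python
class ToyMinibatchLearner:
    """Hypothetical stand-in for a MinibatchUpdatesTLearner subclass;
    only the epoch-loop protocol from the diff is reproduced here."""

    def __init__(self, truly_online=False, max_epochs=10):
        # truly_online mirrors OnlineGradientBasedTLearner.__init__;
        # max_epochs is an invented fallback stopping criterion.
        self.truly_online = truly_online
        self.max_epochs = max_epochs
        self.n_epochs = 0

    def updateStart(self):
        self.n_epochs = 0

    def updateMinibatch(self, minibatch):
        # A real subclass would run its minibatch update function here,
        # e.g. one (online) gradient step on its parameter attributes.
        pass

    def isLastEpoch(self):
        # As in OnlineGradientBasedTLearner: a truly online learner makes
        # exactly one pass through the training set passed to update().
        return self.truly_online or self.n_epochs >= self.max_epochs

    def updateEnd(self):
        # The real updateEnd() also refreshes self.parameters from
        # parameterAttributes(); omitted in this toy version.
        pass

    def update(self, training_set):
        # Plausible reconstruction of the loop that ends with the
        # 'self.updateEnd()' / 'return self.use' lines in the last hunk
        # (this toy version does not implement 'use', so it returns None).
        self.updateStart()
        while True:
            for minibatch in training_set:
                self.updateMinibatch(minibatch)
            self.n_epochs += 1
            if self.isLastEpoch():
                break
        self.updateEnd()


# One pass when truly_online=True, up to max_epochs otherwise:
learner = ToyMinibatchLearner(truly_online=True)
learner.update([[1.0, 2.0], [3.0, 4.0]])
assert learner.n_epochs == 1
```

Keeping the stopping test inside `isLastEpoch()` lets subclasses swap stopping criteria (single pass, fixed epoch budget, validation-based early stopping) without touching the shared update loop.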