diff learner.py @ 111:88257dfedf8c

Added another work in progress, for mlp's
author bengioy@bengiomac.local
date Wed, 07 May 2008 09:16:04 -0400
parents 8fa1ef2411a0
children d0a1bd0378c6
line wrap: on
line diff
--- a/learner.py	Tue May 06 22:24:55 2008 -0400
+++ b/learner.py	Wed May 07 09:16:04 2008 -0400
@@ -1,5 +1,6 @@
 
 from dataset import *
+from compile import Function
     
 class Learner(AttributesHolder):
     """Base class for learning algorithms, provides an interface
@@ -84,8 +85,10 @@
         """
         A subset of self.attributeNames() which are the names of attributes modified/created by update() in order
         to do its work.
+
+        By default these are inferred from the various update output attributes:
         """
-        raise AbstractFunction()
+        return ["parameters"] + self.updateMinibatchOutputAttributes() + self.updateEndOutputAttributes()
 
     def useOutputAttributes(self):
         """
@@ -251,7 +254,7 @@
         return output_dataset
 
 
-class OneShotTLearner(TLearner):
+class MinibatchUpdatesTLearner(TLearner):
     """
     This adds to TLearner a 
       - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch():
@@ -262,6 +265,10 @@
                           going only once through the training data. For more complicated
                           models, more specialized subclasses of TLearner should be used
                           or a learning-algorithm specific update method should be defined.
+
+      - a 'parameters' attribute which is a list of parameters (whose names are
+      specified by the user's subclass with the parameterAttributes() method)
+      
     """
 
     def __init__(self):
@@ -288,12 +295,16 @@
     def updateEndOutputAttributes(self):
         raise AbstractFunction()
 
-    def updateStart(self): pass
+    def parameterAttributes(self):
+        raise AbstractFunction()
+
+        def updateStart(self): pass
 
     def updateEnd(self):
         self.setAttributes(self.updateEndOutputAttributes(),
                            self.update_end_function
                            (self.names2attributes(self.updateEndInputAttributes())))
+        self.parameters = self.names2attributes(self.parameterAttributes())
         
     def updateMinibatch(self,minibatch):
         # make sure all required fields are allocated and initialized
@@ -331,3 +342,22 @@
         self.updateEnd()
         return self.use
 
+class OnlineGradientBasedTLearner(MinibatchUpdatesTLearner):
+    """
+    Specialization of MinibatchUpdatesTLearner in which the minibatch updates
+    are obtained by performing an online (minibatch-based) gradient step.
+
+    Sub-classes must define the following methods:
+    
+    """
+    def __init__(self,truly_online=False):
+        """
+        If truly_online then only one pass is made through the training set passed to update().
+        
+        """
+        self.truly_online=truly_online
+
+    def isLastEpoch(self):
+        return self.truly_online
+
+