diff learner.py @ 107:c4916445e025

Comments from Pascal V.
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 06 May 2008 19:54:43 -0400
parents c4726e19b8ec
children d97f6fe6bdf9
--- a/learner.py	Tue May 06 10:53:28 2008 -0400
+++ b/learner.py	Tue May 06 19:54:43 2008 -0400
@@ -61,6 +61,8 @@
         """
         A Learner may have attributes that it wishes to export to other objects. To automate
         such export, sub-classes should define here the names (list of strings) of these attributes.
+
+        @todo By default, attributeNames should return all instance dictionary entries whose name does not start with '_'.
         """
         return []
 
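A rough sketch of the default behaviour this @todo suggests (assuming the
exportable attributes live in the instance's __dict__; the underscore rule
is the one stated in the @todo, nothing more):

    def attributeNames(self):
        # export every instance attribute whose name does not start
        # with '_' (underscore-prefixed names are private by convention)
        return [name for name in self.__dict__ if not name.startswith('_')]
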
@@ -68,7 +70,7 @@
     """
     TLearner is a virtual class of Learners that attempts to factor out of the definition
     of a learner the steps that are common to many implementations of learning algorithms,
-    so as to leave only "the equations" to define in particular sub-classes, using Theano.
+    so as to leave only 'the equations' to define in particular sub-classes, using Theano.
 
     In the default implementations of use and update, it is assumed that the 'use' and 'update' methods
     visit examples in the input dataset sequentially. In the 'use' method only one pass through the dataset is done,
@@ -85,14 +87,6 @@
                           or by a stats collector.
       - defaultOutputFields(input_fields): return a list of default dataset output fields when
                           None are provided by the caller of use.
-      - update_start(), update_end(), update_minibatch(minibatch): functions
-                          executed at the beginning, the end, and in the middle
-                          (for each minibatch) of the update method. This model only
-                          works for 'online' or one-short learning that requires
-                          going only once through the training data. For more complicated
-                          models, more specialized subclasses of TLearner should be used
-                          or a learning-algorithm specific update method should be defined.
-
     The following naming convention is assumed and important.
     Attributes whose names are listed in attributeNames() can be of any type,
     but those that can be referenced as input/output dataset fields or as
@@ -102,6 +96,9 @@
     the TLearner, created in the sub-class constructor) should be _<name>.
     Typically <name> will be numpy ndarray and _<name> will be the corresponding
     Theano Tensor (for symbolic manipulation).
+
+    @todo Move into Learner all the machinery that can live there without
+    depending on Theano.
     """
 
     def __init__(self):
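A minimal illustration of the <name>/_<name> convention described above (the
subclass, attribute name, and shape are made up for the example; the Theano
import path is assumed):

    import numpy
    import theano.tensor as t

    class MyTLearner(TLearner):
        def __init__(self):
            TLearner.__init__(self)
            # <name>: the concrete value, a numpy ndarray
            self.theta = numpy.zeros((3, 2))
            # _<name>: the corresponding Theano Tensor, used when building
            # the symbolic 'use' and 'update' computations
            self._theta = t.matrix('theta')
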
@@ -148,10 +145,14 @@
             else:
                 return [self.__getattr__(name).data for name in names]
 
-    def use(self,input_dataset,output_fieldnames=None,output_attributes=None,
+    def use(self,input_dataset,output_fieldnames=None,output_attributes=[],
             test_stats_collector=None,copy_inputs=True):
         """
-        The learner tries to compute in the output dataset the output fields specified 
+        The learner tries to compute in the output dataset the output fields
+        specified by the caller (defaultOutputFields when None are provided).
+
+        @todo check if some of the learner attributes are actually SPECIFIED
+        as attributes of the input_dataset, and if so use their values instead
+        of the ones in the learner.
         """
         minibatchwise_use_function = _minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                   output_fieldnames,
@@ -165,6 +166,8 @@
         if copy_inputs:
             output_dataset = input_dataset | output_dataset
         # copy the wanted attributes in the dataset
+        if output_attributes is None:
+            output_attributes = self.attributeNames()
         if output_attributes:
             assert set(output_attributes) <= set(self.attributeNames())
             output_dataset.setAttributes(output_attributes,
@@ -175,21 +178,46 @@
                                          test_stats_collector.attributes())
         return output_dataset
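
With this change the output_attributes argument acts as a three-way switch;
a hypothetical usage sketch (learner, dataset and 'theta' are made-up names):

    out = learner.use(dataset)                          # default []: copy no attributes
    out = learner.use(dataset, output_attributes=None)  # None: copy all attributeNames()
    out = learner.use(dataset, output_attributes=['theta'])   # copy a chosen subset
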
 
+
+class OneShotTLearner(TLearner):
+    """
+    This adds to TLearner the following hooks:
+      - update_start(), update_end(), update_minibatch(minibatch), end_epoch():
+                          functions executed at the beginning and the end of the
+                          update method, for each minibatch within it, and at the
+                          end of each epoch. This model only works for 'online'
+                          or one-shot learning that requires going only once
+                          through the training data. For more complicated models,
+                          more specialized subclasses of TLearner should be used,
+                          or a learning-algorithm specific update method should
+                          be defined.
+    """
+
+    def __init__(self):
+        TLearner.__init__(self)
+        
     def update_start(self): pass
     def update_end(self): pass
     def update_minibatch(self,minibatch):
         raise AbstractFunction()
     
     def update(self,training_set,train_stats_collector=None):
-        
+        """
+        @todo check if some of the learner attributes are actually SPECIFIED
+        as attributes of the training_set, and if so use their values instead.
+        """
         self.update_start()
-        for minibatch in training_set.minibatches(self.training_set_input_fields,
-                                                  minibatch_size=self.minibatch_size):
-            self.update_minibatch(minibatch)
+        stop=False
+        while not stop:
             if train_stats_collector:
-                minibatch_set = minibatch.examples()
-                minibatch_set.setAttributes(self.attributeNames(),self.attributes())
-                train_stats_collector.update(minibatch_set)
+                train_stats_collector.forget() # restart stats collection at the beginning of each epoch
+            for minibatch in training_set.minibatches(self.training_set_input_fields,
+                                                      minibatch_size=self.minibatch_size):
+                self.update_minibatch(minibatch)
+                if train_stats_collector:
+                    minibatch_set = minibatch.examples()
+                    minibatch_set.setAttributes(self.attributeNames(),self.attributes())
+                    train_stats_collector.update(minibatch_set)
+            stop = self.end_epoch()
         self.update_end()
         return self.use
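
A minimal sketch of a subclass plugging into these hooks (all names below,
e.g. MyOneShotLearner and n_epochs, are made up, and the actual learning
step is left out):

    class MyOneShotLearner(OneShotTLearner):
        def __init__(self, minibatch_size=32):
            OneShotTLearner.__init__(self)
            self.minibatch_size = minibatch_size
            self.training_set_input_fields = ['input', 'target']
            self.n_epochs = 0

        def update_start(self):
            self.n_epochs = 0           # reset bookkeeping before training

        def update_minibatch(self, minibatch):
            pass                        # apply one learning step to the minibatch

        def end_epoch(self):
            self.n_epochs += 1
            return self.n_epochs >= 1   # returning True stops update()'s epoch loop

        def update_end(self):
            pass                        # e.g. finalize parameters before 'use'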