diff learner.py @ 110:8fa1ef2411a0

Worked on OneShotTLearner and implementation of LinearRegression
author bengioy@bengiomac.local
date Tue, 06 May 2008 22:24:55 -0400
parents d97f6fe6bdf9
children 88257dfedf8c
line wrap: on
line diff
--- a/learner.py	Tue May 06 20:01:34 2008 -0400
+++ b/learner.py	Tue May 06 22:24:55 2008 -0400
@@ -1,7 +1,7 @@
 
 from dataset import *
     
-class Learner(object):
+class Learner(AttributesHolder):
     """Base class for learning algorithms, provides an interface
     that allows various algorithms to be applicable to generic learning
     algorithms.
@@ -66,6 +66,35 @@
         """
         return []
 
+    def updateInputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes needed by update() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    def useInputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes needed by use() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    def updateOutputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes modified/created by update() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    def useOutputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes modified/created by use() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    
 class TLearner(Learner):
     """
     TLearner is a virtual class of Learners that attempts to factor out of the definition
@@ -103,50 +132,82 @@
 
     def __init__(self):
         Learner.__init__(self)
+
+    def defaultOutputFields(self, input_fields):
+        """
+        Return a default list of output field names (to put in the output dataset).
+        This will be used when None are provided (as output_fields) by the caller of the 'use' method.
+        This may involve looking at the input_fields (names) available in the
+        input_dataset.
+        """
+        raise AbstractFunction()
+
+    def allocate(self, minibatch):
+        """
+        This function is called at the beginning of each updateMinibatch
+        and should be used to check that all required attributes have been
+        allocated and initialized (usually this function calls forget()
+        when it has to do an initialization).
+        """
+        raise AbstractFunction()
         
-    def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
+    def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
         """
         Private helper function called by the generic TLearner.use. It returns a function
         that can map the given input fields to the given output fields (along with the
-        attributes that the stats collector needs for its computation.
+        attributes that the stats collector needs for its computation). The function
+        called also automatically makes use of the self.useInputAttributes() and
+        sets the self.useOutputAttributes().
         """
         if not output_fields:
             output_fields = self.defaultOutputFields(input_fields)
         if stats_collector:
-            stats_collector_inputs = stats_collector.inputUpdateAttributes()
+            stats_collector_inputs = stats_collector.input2UpdateAttributes()
             for attribute in stats_collector_inputs:
                 if attribute not in input_fields:
                     output_fields.append(attribute)
         key = (input_fields,output_fields)
         if key not in self.use_functions_dictionary:
-            self.use_functions_dictionary[key]=Function(self._names2attributes(input_fields),
-                                                   self._names2attributes(output_fields))
+            use_input_attributes = self.useInputAttributes()
+            use_output_attributes = self.useOutputAttributes()
+            complete_f = Function(self.names2OpResults(input_fields+use_input_attributes),
+                                  self.names2OpResults(output_fields+use_output_attributes))
+            def f(*input_field_values):
+                input_attribute_values = self.names2attributes(use_input_attributes)
+                results = complete_f(*(input_field_values + tuple(input_attribute_values)))
+                output_field_values = results[0:len(output_fields)]
+                output_attribute_values = results[len(output_fields):len(results)]
+                if use_output_attributes:
+                    self.setAttributes(use_output_attributes,output_attribute_values)
+                return output_field_values
+            self.use_functions_dictionary[key]=f
         return self.use_functions_dictionary[key]
 
     def attributes(self,return_copy=False):
         """
         Return a list with the values of the learner's attributes (or optionally, a deep copy).
         """
-        return self.names2attributes(self.attributeNames())
-            
-    def _names2attributes(self,names,return_Result=False, return_copy=False):
+        return self.names2attributes(self.attributeNames(),return_copy)
+
+    def names2attributes(self,names,return_copy=False):
         """
         Private helper function that maps a list of attribute names to a list
-        of (optionally copies) values or of the Result objects that own these values.
+        of (optionally copies) values of attributes.
         """
-        if return_Result:
-            if return_copy:
-                return [copy.deepcopy(self.__getattr__(name)) for name in names]
-            else:
-                return [self.__getattr__(name) for name in names]
+        if return_copy:
+            return [copy.deepcopy(self.__getattr__(name).data) for name in names]
         else:
-            if return_copy:
-                return [copy.deepcopy(self.__getattr__(name).data) for name in names]
-            else:
-                return [self.__getattr__(name).data for name in names]
+            return [self.__getattr__(name).data for name in names]
+
+    def names2OpResults(self,names):
+        """
+        Private helper function that maps a list of attribute names to a list
+        of corresponding Op Results (with the same name but with a '_' prefix).
+        """
+        return [self.__getattr__('_'+name).data for name in names]
 
     def use(self,input_dataset,output_fieldnames=None,output_attributes=[],
-            test_stats_collector=None,copy_inputs=True):
+            test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True):
         """
         The learner tries to compute in the output dataset the output fields specified
 
@@ -164,7 +225,7 @@
         If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames())
         are also copied into the output dataset attributes.
         """
-        minibatchwise_use_function = _minibatchwise_use_functions(input_dataset.fieldNames(),
+        minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                   output_fieldnames,
                                                                   test_stats_collector)
         virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
@@ -179,20 +240,21 @@
         if output_attributes is None:
             output_attributes = self.attributeNames()
         if output_attributes:
-            assert set(output_attributes) <= set(self.attributeNames())
+            assert set(output_attributes) <= set(self.attributeNames())
             output_dataset.setAttributes(output_attributes,
-                                         self._names2attributes(output_attributes,return_copy=True))
+                                         self.names2attributes(output_attributes,return_copy=True))
         if test_stats_collector:
             test_stats_collector.update(output_dataset)
-            output_dataset.setAttributes(test_stats_collector.attributeNames(),
-                                         test_stats_collector.attributes())
+            if put_stats_in_output_dataset:
+                output_dataset.setAttributes(test_stats_collector.attributeNames(),
+                                             test_stats_collector.attributes())
         return output_dataset
 
 
 class OneShotTLearner(TLearner):
     """
     This adds to TLearner a 
-      - update_start(), update_end(), update_minibatch(minibatch), end_epoch():
+      - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch():
                           functions executed at the beginning, the end, in the middle
                           (for each minibatch) of the update method, and at the end
                           of each epoch. This model only
@@ -204,18 +266,56 @@
 
     def __init__(self):
         TLearner.__init__(self)
+        self.update_minibatch_function = \
+        Function(self.names2OpResults(self.updateMinibatchInputAttributes()+
+                                      self.updateMinibatchInputFields()),
+                 self.names2OpResults(self.updateMinibatchOutputAttributes()))
+        self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()),
+                                            self.names2OpResults(self.updateEndOutputAttributes()))
+
+    def updateMinibatchInputFields(self):
+        raise AbstractFunction()
+    
+    def updateMinibatchInputAttributes(self):
+        raise AbstractFunction()
+    
+    def updateMinibatchOutputAttributes(self):
+        raise AbstractFunction()
+    
+    def updateEndInputAttributes(self):
+        raise AbstractFunction()
+
+    def updateEndOutputAttributes(self):
+        raise AbstractFunction()
+
+    def updateStart(self,training_set): pass
+
+    def updateEnd(self):
+        self.setAttributes(self.updateEndOutputAttributes(),
+                           self.update_end_function
+                           (*self.names2attributes(self.updateEndInputAttributes())))
         
-    def update_start(self): pass
-    def update_end(self): pass
-    def update_minibatch(self,minibatch):
-        raise AbstractFunction()
+    def updateMinibatch(self,minibatch):
+        # make sure all required fields are allocated and initialized
+        self.allocate(minibatch)
+        self.setAttributes(self.updateMinibatchOutputAttributes(),
+                           self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes()))
+                                                          + minibatch(self.updateMinibatchInputFields())))
+        
+    def isLastEpoch(self):
+        """
+        This method is called at the end of each epoch (cycling over the training set).
+        It returns a boolean to indicate if this is the last epoch.
+        By default just do one epoch.
+        """
+        return True
     
     def update(self,training_set,train_stats_collector=None):
         """
         @todo check if some of the learner attributes are actually SPECIFIED
         in as attributes of the training_set.
         """
-        self.update_start()
+        self.updateStart(training_set)
         stop=False
         while not stop:
             if train_stats_collector:
@@ -227,7 +327,7 @@
                     minibatch_set = minibatch.examples()
                     minibatch_set.setAttributes(self.attributeNames(),self.attributes())
                     train_stats_collector.update(minibatch_set)
-            stop = self.end_epoch()
-        self.update_end()
+            stop = self.isLastEpoch()
+        self.updateEnd()
         return self.use