diff learner.py @ 128:ee5507af2c60

minor edits
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 20:51:24 -0400
parents 4efe6d36c061
children 4c2280edcaf5 3d8e40e7ed18
--- a/learner.py	Wed May 07 16:58:06 2008 -0400
+++ b/learner.py	Wed May 07 20:51:24 2008 -0400
@@ -47,17 +47,74 @@
         self.forget()
         return self.update(learning_task,train_stats_collector)
 
-    def use(self,input_dataset,output_fields=None,copy_inputs=True):
-        """Once a Learner has been trained by one or more call to 'update', it can
-        be used with one or more calls to 'use'. The argument is a DataSet (possibly
-        containing a single example) and the result is a DataSet of the same length.
-        If output_fields is specified, it may be use to indicate which fields should
+    def use(self,input_dataset,output_fieldnames=None,
+            test_stats_collector=None,copy_inputs=True,
+            put_stats_in_output_dataset=True,
+            output_attributes=[]):
+        """
+        Once a Learner has been trained by one or more calls to 'update', it can
+        be used with one or more calls to 'use'. The argument is an input DataSet (possibly
+        containing a single example) and the result is an output DataSet of the same length.
+        If output_fieldnames is specified, it may be used to indicate which fields should
         be constructed in the output DataSet (for example ['output','classification_error']).
+        Otherwise, self.defaultOutputFields(input_dataset.fieldNames()) is called to choose
+        the output fields.
         Optionally, if copy_inputs, the input fields (of the input_dataset) can be made
         visible in the output DataSet returned by this method.
+        Note the distinction between fields (which are example-wise quantities, e.g. 'input')
+        and attributes (which are not, e.g. 'regularization_term').
+
+        We provide here a default implementation that does all this using
+        a sub-class-defined method: minibatchwiseUseFunction.
+        
+        @todo check if some of the learner attributes are actually SPECIFIED
+        as attributes of the input_dataset, and if so use their values instead
+        of the ones in the learner.
+
+        Attributes of the learner can also optionally be copied into the output dataset.
+        If output_attributes is None then all of the attributes in self.attributeNames()
+        are copied into the output dataset, but if it is [] (the default), then none are copied.
+        If a test_stats_collector is provided, then its attributes (test_stats_collector.attributeNames())
+        are also copied into the output dataset attributes.
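+
+        A minimal usage sketch (assuming a trained Learner instance 'learner'
+        and a DataSet 'test_set'; the names are illustrative):
+
+            output_ds = learner.use(test_set,
+                                    output_fieldnames=['output','classification_error'])
+            for example in output_ds:
+                # each example carries the requested output fields
+                ...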
         """
-        raise AbstractFunction()
+        minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(),
+                                                                   output_fieldnames,
+                                                                   test_stats_collector)
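+        # lazily map minibatchwise_use_function over the input dataset; the
+        # numpy_vstack/numpy_hstack arguments are assumed to tell the dataset
+        # how to assemble per-minibatch outputs into a single dataset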
+        virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
+                                                      minibatchwise_use_function,
+                                                      True,DataSet.numpy_vstack,
+                                                      DataSet.numpy_hstack)
+        # actually force the computation
+        output_dataset = CachedDataSet(virtual_output_dataset,True)
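+        # optionally expose the input fields alongside the computed outputs
+        # ('|' is assumed to concatenate the fields of the two datasets)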
+        if copy_inputs:
+            output_dataset = input_dataset | output_dataset
+        # copy the wanted learner attributes into the output dataset
+        if output_attributes is None:
+            output_attributes = self.attributeNames()
+        if output_attributes:
+            assert set(output_attributes) <= set(self.attributeNames())
+            output_dataset.setAttributes(output_attributes,
+                                         self.names2attributes(output_attributes,return_copy=True))
+        if test_stats_collector:
+            test_stats_collector.update(output_dataset)
+            if put_stats_in_output_dataset:
+                output_dataset.setAttributes(test_stats_collector.attributeNames(),
+                                             test_stats_collector.attributes())
+        return output_dataset
 
+    def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector):
+        """
+        Return a function that maps the given input fields to the given output fields
+        and to the attributes that the stats collector needs for its computation.
+        The returned function is expected to operate on minibatches.
+        It reads the attributes named by self.useInputAttributes() and
+        sets the attributes named by self.useOutputAttributes().
+        """
+        raise AbstractFunction()
     def attributeNames(self):
         """
         A Learner may have attributes that it wishes to export to other objects. To automate
@@ -67,6 +124,22 @@
         """
         return []
 
+    def attributes(self,return_copy=False):
+        """
+        Return a list with the values of the learner's attributes (or optionally, a deep copy).
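+        For example, if self.attributeNames() returns ['W','b'] (hypothetical
+        names), the result is the current values of W and b, in that order.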
+        """
+        return self.names2attributes(self.attributeNames(),return_copy)
+
+    def names2attributes(self,names,return_copy=False):
+        """
+        Private helper function that maps a list of attribute names to a list
+        of the (optionally deep-copied) values of those attributes.
+        """
+        if return_copy:
+            return [copy.deepcopy(getattr(self,name).data) for name in names]
+        else:
+            return [getattr(self,name).data for name in names]
+
     def updateInputAttributes(self):
         """
         A subset of self.attributeNames() which are the names of attributes needed by update() in order
@@ -145,22 +218,10 @@
         """
         raise AbstractFunction()
 
-    def allocate(self, minibatch):
-        """
-        This function is called at the beginning of each updateMinibatch
-        and should be used to check that all required attributes have been
-        allocated and initialized (usually this function calls forget()
-        when it has to do an initialization).
+    def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector):
         """
-        raise AbstractFunction()
-        
-    def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
-        """
-        Private helper function called by the generic TLearner.use. It returns a function
-        that can map the given input fields to the given output fields (along with the
-        attributes that the stats collector needs for its computation. The function
-        called also automatically makes use of the self.useInputAttributes() and
-        sets the self.useOutputAttributes().
+        Implement minibatchwiseUseFunction by exploiting Theano compilation
+        and the expression graph defined by a sub-class constructor.
         """
         if not output_fields:
             output_fields = self.defaultOutputFields(input_fields)
@@ -186,22 +247,6 @@
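+            # cache the freshly compiled Theano function under its key so that
+            # later calls with the same key reuse it instead of recompiling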
             self.use_functions_dictionary[key]=f
         return self.use_functions_dictionary[key]
 
-    def attributes(self,return_copy=False):
-        """
-        Return a list with the values of the learner's attributes (or optionally, a deep copy).
-        """
-        return self.names2attributes(self.attributeNames(),return_copy)
-
-    def names2attributes(self,names,return_copy=False):
-        """
-        Private helper function that maps a list of attribute names to a list
-        of (optionally copies) values of attributes.
-        """
-        if return_copy:
-            return [copy.deepcopy(self.__getattr__(name).data) for name in names]
-        else:
-            return [self.__getattr__(name).data for name in names]
-
     def names2OpResults(self,names):
         """
         Private helper function that maps a list of attribute names to a list
@@ -209,50 +254,6 @@
         """
         return [self.__getattr__('_'+name).data for name in names]
 
-    def use(self,input_dataset,output_fieldnames=None,output_attributes=[],
-            test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True):
-        """
-        The learner tries to compute in the output dataset the output fields specified
-
-        @todo check if some of the learner attributes are actually SPECIFIED
-        as attributes of the input_dataset, and if so use their values instead
-        of the ones in the learner.
-
-        The learner tries to compute in the output dataset the output fields specified.
-        If None is specified then self.defaultOutputFields(input_dataset.fieldNames())
-        is called to determine the output fields.
-
-        Attributes of the learner can also optionally be copied into the output dataset.
-        If output_attributes is None then all of the attributes in self.AttributeNames()
-        are copied in the output dataset, but if it is [] (the default), then none are copied.
-        If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames())
-        are also copied into the output dataset attributes.
-        """
-        minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(),
-                                                                      output_fieldnames,
-                                                                      test_stats_collector)
-        virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
-                                                      minibatchwise_use_function,
-                                                      True,DataSet.numpy_vstack,
-                                                      DataSet.numpy_hstack)
-        # actually force the computation
-        output_dataset = CachedDataSet(virtual_output_dataset,True)
-        if copy_inputs:
-            output_dataset = input_dataset | output_dataset
-        # copy the wanted attributes in the dataset
-        if output_attributes is None:
-            output_attributes = self.attributeNames()
-        if output_attributes:
-            assert set(attribute_names) <= set(self.attributeNames())
-            output_dataset.setAttributes(output_attributes,
-                                         self.names2attributes(output_attributes,return_copy=True))
-        if test_stats_collector:
-            test_stats_collector.update(output_dataset)
-            if put_stats_in_output_dataset:
-                output_dataset.setAttributes(test_stats_collector.attributeNames(),
-                                             test_stats_collector.attributes())
-        return output_dataset
-
 
 class MinibatchUpdatesTLearner(TLearner):
     """
@@ -281,6 +282,15 @@
         (self.names2OpResults(self.updateEndInputAttributes()),
          self.names2OpResults(self.updateEndOutputAttributes()))
 
+    def allocate(self, minibatch):
+        """
+        This function is called at the beginning of each updateMinibatch
+        and should be used to check that all required attributes have been
+        allocated and initialized (usually this function calls forget()
+        when it has to do an initialization).
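+
+        A typical override (a sketch; attribute names vary by sub-class)
+        would compare the minibatch's input dimensions against the shapes
+        of the currently allocated parameters, and call self.forget() when
+        nothing has been allocated yet or the dimensions have changed.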
+        """
+        raise AbstractFunction()
+        
     def updateMinibatchInputFields(self):
         raise AbstractFunction()