view learner.py @ 93:a62c79ec7c8a

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 05 May 2008 18:14:44 -0400
parents c4726e19b8ec
children c4916445e025
line wrap: on
line source


from dataset import *
    
class Learner(object):
    """Base class for learning algorithms, provides an interface
    that allows various algorithms to be applicable to generic learning
    tasks.

    A Learner can be seen as a learning algorithm, a function that when
    applied to training data returns a learned function, an object that
    can be applied to other data and return some output data.
    """

    def __init__(self):
        pass

    def forget(self):
        """
        Reset the state of the learner to a blank slate, before seeing
        training data. The operation may be non-deterministic if the
        learner has a random number generator that is set to use a
        different seed each time forget() is called.

        Sub-classes must override this; the base version raises
        NotImplementedError.
        """
        raise NotImplementedError

    def update(self, training_set, train_stats_collector=None):
        """
        Continue training a learner, with the evidence provided by the given training set.
        Hence update can be called multiple times. This is particularly useful in the
        on-line setting or the sequential (Bayesian or not) settings.
        The result is a function that can be applied on data, with the same
        semantics of the Learner.use method.

        The user may optionally provide a training StatsCollector that is used to record
        some statistics of the outputs computed during training. It is update(d) during
        training.

        The base implementation ignores both arguments and returns self.use.
        """
        # default behavior is 'non-adaptive', i.e. update does not do anything
        return self.use

    def __call__(self, training_set, train_stats_collector=None):
        """
        Train a learner from scratch using the provided training set,
        and return the learned function.

        Equivalent to calling forget() and then
        update(training_set, train_stats_collector).
        """
        self.forget()
        # Forward the caller's training set (the original referenced an
        # undefined name here and always raised NameError).
        return self.update(training_set, train_stats_collector)

    def use(self, input_dataset, output_fields=None, copy_inputs=True):
        """Once a Learner has been trained by one or more call to 'update', it can
        be used with one or more calls to 'use'. The argument is a DataSet (possibly
        containing a single example) and the result is a DataSet of the same length.
        If output_fields is specified, it may be used to indicate which fields should
        be constructed in the output DataSet (for example ['output','classification_error']).
        Optionally, if copy_inputs, the input fields (of the input_dataset) can be made
        visible in the output DataSet returned by this method.

        Sub-classes must override this; the base version raises
        NotImplementedError.
        """
        raise NotImplementedError

    def attributeNames(self):
        """
        A Learner may have attributes that it wishes to export to other objects. To automate
        such export, sub-classes should define here the names (list of strings) of these attributes.

        The base implementation exports nothing and returns an empty list.
        """
        return []

class TLearner(Learner):
    """
    TLearner is a virtual class of Learners that attempts to factor out of the definition
    of a learner the steps that are common to many implementations of learning algorithms,
    so as to leave only "the equations" to define in particular sub-classes, using Theano.

    In the default implementations of use and update, it is assumed that the 'use' and 'update' methods
    visit examples in the input dataset sequentially. In the 'use' method only one pass through the dataset is done,
    whereas the sub-learner may wish to iterate over the examples multiple times. Subclasses where this
    basic model is not appropriate can simply redefine update or use.

    Sub-classes must provide the following functions and functionalities:
      - attributeNames(): defines all the names of attributes which can be used as fields or
                          attributes in input/output datasets or in stats collectors.
                          All these attributes are expected to be theano.Result objects
                          (with a .data property and recognized by theano.Function for compilation).
                          The sub-class constructor defines the relations between
                          the Theano variables that may be used by 'use' and 'update'
                          or by a stats collector.
      - defaultOutputFields(input_fields): return a list of default dataset output fields when
                          None are provided by the caller of use.
      - update_start(), update_end(), update_minibatch(minibatch): functions
                          executed at the beginning, the end, and in the middle
                          (for each minibatch) of the update method. This model only
                          works for 'online' or one-shot learning that requires
                          going only once through the training data. For more complicated
                          models, more specialized subclasses of TLearner should be used
                          or a learning-algorithm specific update method should be defined.
      - the attributes 'training_set_input_fields' (fields requested from the
                          training set) and 'minibatch_size', which the generic
                          update method reads but this class does not define.

    The following naming convention is assumed and important.
    Attributes whose names are listed in attributeNames() can be of any type,
    but those that can be referenced as input/output dataset fields or as
    output attributes in 'use' or as input attributes in the stats collector
    should be associated with a Theano Result variable. If the exported attribute
    name is <name>, the corresponding Result name (an internal attribute of
    the TLearner, created in the sub-class constructor) should be _<name>.
    Typically <name> will be numpy ndarray and _<name> will be the corresponding
    Theano Tensor (for symbolic manipulation).
    """

    def __init__(self):
        Learner.__init__(self)
        # Cache of compiled use functions, keyed by
        # (tuple of input field names, tuple of output field names).
        self.use_functions_dictionary = {}

    def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
        """
        Private helper function called by the generic TLearner.use. It returns a function
        that can map the given input fields to the given output fields (along with the
        attributes that the stats collector needs for its computation).

        Compiled functions are cached per (input fields, output fields) combination.
        The caller's output_fields sequence is never modified.
        """
        if output_fields:
            # work on a copy: stats-collector attributes are appended below
            output_fields = list(output_fields)
        else:
            output_fields = list(self.defaultOutputFields(input_fields))
        if stats_collector:
            for attribute in stats_collector.inputUpdateAttributes():
                if attribute not in input_fields and attribute not in output_fields:
                    output_fields.append(attribute)
        # lists are unhashable, so the cache key is built from tuples
        key = (tuple(input_fields), tuple(output_fields))
        try:
            cache = self.use_functions_dictionary
        except AttributeError:
            # a sub-class constructor may not have called TLearner.__init__
            cache = self.use_functions_dictionary = {}
        if key not in cache:
            # Function compiles from the Result objects, not their .data values
            cache[key] = Function(self._names2attributes(input_fields, return_Result=True),
                                  self._names2attributes(output_fields, return_Result=True))
        return cache[key]

    def attributes(self, return_copy=False):
        """
        Return a list with the values of the learner's attributes (or optionally, a deep copy),
        in the order given by attributeNames().
        """
        return self._names2attributes(self.attributeNames(), return_copy=return_copy)

    def _names2attributes(self, names, return_Result=False, return_copy=False):
        """
        Private helper function that maps a list of attribute names to a list
        of (optionally copies) values or of the Result objects that own these values.

        :param names: attribute names to look up on self.
        :param return_Result: if true, return the attribute objects themselves;
            otherwise return their .data values.
        :param return_copy: if true, return deep copies of what would be returned.
        """
        # NOTE(review): the class docstring says the Result for exported name
        # <name> is stored as _<name>; this helper looks up <name> directly.
        # Confirm which convention sub-classes follow.
        if return_Result:
            values = [getattr(self, name) for name in names]
        else:
            values = [getattr(self, name).data for name in names]
        if return_copy:
            import copy
            values = [copy.deepcopy(value) for value in values]
        return values

    def use(self, input_dataset, output_fieldnames=None, output_attributes=None,
            test_stats_collector=None, copy_inputs=True):
        """
        The learner tries to compute in the output dataset the output fields specified.

        :param input_dataset: dataset whose fieldNames() are the inputs of the
            compiled use function.
        :param output_fieldnames: fields to compute; defaultOutputFields(...) when None.
        :param output_attributes: names (a subset of attributeNames()) whose current
            values are deep-copied into the output dataset as attributes.
        :param test_stats_collector: if given, its input attributes are also computed,
            it is updated with the output dataset, and its attributes are copied
            into the output dataset.
        :param copy_inputs: if true, the input fields are made visible in the result.
        :returns: the output dataset.
        """
        minibatchwise_use_function = self._minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                       output_fieldnames,
                                                                       test_stats_collector)
        virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                      minibatchwise_use_function,
                                                      True, DataSet.numpy_vstack,
                                                      DataSet.numpy_hstack)
        # actually force the computation
        output_dataset = CachedDataSet(virtual_output_dataset, True)
        if copy_inputs:
            output_dataset = input_dataset | output_dataset
        # copy the wanted attributes in the dataset
        if output_attributes:
            assert set(output_attributes) <= set(self.attributeNames())
            output_dataset.setAttributes(output_attributes,
                                         self._names2attributes(output_attributes, return_copy=True))
        if test_stats_collector:
            test_stats_collector.update(output_dataset)
            output_dataset.setAttributes(test_stats_collector.attributeNames(),
                                         test_stats_collector.attributes())
        return output_dataset

    def update_start(self):
        """Hook run once at the beginning of update(); does nothing by default."""
        pass

    def update_end(self):
        """Hook run once at the end of update(); does nothing by default."""
        pass

    def update_minibatch(self, minibatch):
        """Process one training minibatch; sub-classes must override this."""
        raise AbstractFunction()

    def update(self, training_set, train_stats_collector=None):
        """
        Make a single pass over training_set and return self.use.

        Calls update_start(), then update_minibatch() on each minibatch of the
        fields self.training_set_input_fields (of size self.minibatch_size),
        then update_end(). If train_stats_collector is given, it is updated
        after each minibatch with that minibatch's examples, annotated with
        the learner's current attribute values.
        """
        self.update_start()
        for minibatch in training_set.minibatches(self.training_set_input_fields,
                                                  minibatch_size=self.minibatch_size):
            self.update_minibatch(minibatch)
            if train_stats_collector:
                minibatch_set = minibatch.examples()
                minibatch_set.setAttributes(self.attributeNames(), self.attributes())
                train_stats_collector.update(minibatch_set)
        self.update_end()
        return self.use