pylearn: diff learner.py @ 110:8fa1ef2411a0
Worked on OneShotTLearner and implementation of LinearRegression
author:   bengioy@bengiomac.local
date:     Tue, 06 May 2008 22:24:55 -0400
parents:  d97f6fe6bdf9
children: 88257dfedf8c
--- a/learner.py	Tue May 06 20:01:34 2008 -0400
+++ b/learner.py	Tue May 06 22:24:55 2008 -0400
@@ -1,7 +1,7 @@
 from dataset import *
 
-class Learner(object):
+class Learner(AttributesHolder):
     """Base class for learning algorithms, provides an interface
     that allows various algorithms to be applicable to generic learning
     algorithms.
@@ -66,6 +66,35 @@
         """
         return []
 
+    def updateInputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes needed by update() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    def useInputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes needed by use() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    def updateOutputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes modified/created by update() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
+    def useOutputAttributes(self):
+        """
+        A subset of self.attributeNames() which are the names of attributes modified/created by use() in order
+        to do its work.
+        """
+        raise AbstractFunction()
+
 class TLearner(Learner):
     """
     TLearner is a virtual class of Learners that attempts to factor out of the definition
@@ -103,50 +132,82 @@
     def __init__(self):
         Learner.__init__(self)
+
+    def defaultOutputFields(self, input_fields):
+        """
+        Return a default list of output field names (to put in the output dataset).
+        This will be used when None are provided (as output_fields) by the caller of the 'use' method.
+        This may involve looking at the input_fields (names) available in the
+        input_dataset.
+        """
+        raise AbstractFunction()
+
+    def allocate(self, minibatch):
+        """
+        This function is called at the beginning of each updateMinibatch
+        and should be used to check that all required attributes have been
+        allocated and initialized (usually this function calls forget()
+        when it has to do an initialization).
+        """
+        raise AbstractFunction()
 
-    def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
+    def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
         """
         Private helper function called by the generic TLearner.use. It returns a function
         that can map the given input fields to the given output fields (along with the
-        attributes that the stats collector needs for its computation.
+        attributes that the stats collector needs for its computation. The function
+        called also automatically makes use of the self.useInputAttributes() and
+        sets the self.useOutputAttributes().
""" if not output_fields: output_fields = self.defaultOutputFields(input_fields) if stats_collector: - stats_collector_inputs = stats_collector.inputUpdateAttributes() + stats_collector_inputs = stats_collector.input2UpdateAttributes() for attribute in stats_collector_inputs: if attribute not in input_fields: output_fields.append(attribute) key = (input_fields,output_fields) if key not in self.use_functions_dictionary: - self.use_functions_dictionary[key]=Function(self._names2attributes(input_fields), - self._names2attributes(output_fields)) + use_input_attributes = self.useInputAttributes() + use_output_attributes = self.useOutputAttributes() + complete_f = Function(self.names2OpResults(input_fields+use_input_attributes), + self.names2OpResults(output_fields+use_output_attributes)) + def f(*input_field_values): + input_attribute_values = self.names2attributes(use_input_attributes) + results = complete_f(*(input_field_values + input_attribute_values)) + output_field_values = results[0:len(output_fields)] + output_attribute_values = results[len(output_fields):len(results)] + if use_output_attributes: + self.setAttributes(use_output_attributes,output_attribute_values) + return output_field_values + self.use_functions_dictionary[key]=f return self.use_functions_dictionary[key] def attributes(self,return_copy=False): """ Return a list with the values of the learner's attributes (or optionally, a deep copy). """ - return self.names2attributes(self.attributeNames()) - - def _names2attributes(self,names,return_Result=False, return_copy=False): + return self.names2attributes(self.attributeNames(),return_copy) + + def names2attributes(self,names,return_copy=False): """ Private helper function that maps a list of attribute names to a list - of (optionally copies) values or of the Result objects that own these values. + of (optionally copies) values of attributes. """ - if return_Result: - if return_copy: - return [copy.deepcopy(self.__getattr__(name)) for name in names] - else: - return [self.__getattr__(name) for name in names] + if return_copy: + return [copy.deepcopy(self.__getattr__(name).data) for name in names] else: - if return_copy: - return [copy.deepcopy(self.__getattr__(name).data) for name in names] - else: - return [self.__getattr__(name).data for name in names] + return [self.__getattr__(name).data for name in names] + + def names2OpResults(self,names): + """ + Private helper function that maps a list of attribute names to a list + of corresponding Op Results (with the same name but with a '_' prefix). + """ + return [self.__getattr__('_'+name).data for name in names] def use(self,input_dataset,output_fieldnames=None,output_attributes=[], - test_stats_collector=None,copy_inputs=True): + test_stats_collector=None,copy_inputs=True, put_stats_in_output_dataset=True): """ The learner tries to compute in the output dataset the output fields specified @@ -164,7 +225,7 @@ If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) are also copied into the output dataset attributes. 
""" - minibatchwise_use_function = _minibatchwise_use_functions(input_dataset.fieldNames(), + minibatchwise_use_function = minibatchwise_use_functions(input_dataset.fieldNames(), output_fieldnames, test_stats_collector) virtual_output_dataset = ApplyFunctionDataSet(input_dataset, @@ -179,20 +240,21 @@ if output_attributes is None: output_attributes = self.attributeNames() if output_attributes: - assert set(output_attributes) <= set(self.attributeNames()) + assert set(attribute_names) <= set(self.attributeNames()) output_dataset.setAttributes(output_attributes, - self._names2attributes(output_attributes,return_copy=True)) + self.names2attributes(output_attributes,return_copy=True)) if test_stats_collector: test_stats_collector.update(output_dataset) - output_dataset.setAttributes(test_stats_collector.attributeNames(), - test_stats_collector.attributes()) + if put_stats_in_output_dataset: + output_dataset.setAttributes(test_stats_collector.attributeNames(), + test_stats_collector.attributes()) return output_dataset class OneShotTLearner(TLearner): """ This adds to TLearner a - - update_start(), update_end(), update_minibatch(minibatch), end_epoch(): + - updateStart(), updateEnd(), updateMinibatch(minibatch), isLastEpoch(): functions executed at the beginning, the end, in the middle (for each minibatch) of the update method, and at the end of each epoch. This model only @@ -204,18 +266,56 @@ def __init__(self): TLearner.__init__(self) + self.update_minibatch_function = + Function(self.names2OpResults(self.updateMinibatchOutputAttributes()+ + self.updateMinibatchInputFields()), + self.names2OpResults(self.updateMinibatchOutputAttributes())) + self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()), + self.names2OpResults(self.updateEndOutputAttributes())) + + def updateMinibatchInputFields(self): + raise AbstractFunction() + + def updateMinibatchInputAttributes(self): + raise AbstractFunction() + + def updateMinibatchOutputAttributes(self): + raise AbstractFunction() + + def updateEndInputAttributes(self): + raise AbstractFunction() + + def updateEndOutputAttributes(self): + raise AbstractFunction() + + def updateStart(self): pass + + def updateEnd(self): + self.setAttributes(self.updateEndOutputAttributes(), + self.update_end_function + (self.names2attributes(self.updateEndInputAttributes()))) - def update_start(self): pass - def update_end(self): pass - def update_minibatch(self,minibatch): - raise AbstractFunction() + def updateMinibatch(self,minibatch): + # make sure all required fields are allocated and initialized + self.allocate(minibatch) + self.setAttributes(self.updateMinibatchOutputAttributes(), + self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes())) + + minibatch(self.updateMinibatchInputFields()))) + + def isLastEpoch(self): + """ + This method is called at the end of each epoch (cycling over the training set). + It returns a boolean to indicate if this is the last epoch. + By default just do one epoch. + """ + return True def update(self,training_set,train_stats_collector=None): """ @todo check if some of the learner attributes are actually SPECIFIED in as attributes of the training_set. 
""" - self.update_start() + self.updateStart(training_set) stop=False while not stop: if train_stats_collector: @@ -227,7 +327,7 @@ minibatch_set = minibatch.examples() minibatch_set.setAttributes(self.attributeNames(),self.attributes()) train_stats_collector.update(minibatch_set) - stop = self.end_epoch() - self.update_end() + stop = self.isLastEpoch() + self.updateEnd() return self.use