Mercurial > pylearn
diff learner.py @ 133:b4657441dd65
Corrected typos
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Fri, 09 May 2008 13:38:54 -0400 |
parents | f6505ec32dc3 |
children | 3f4e5c9bdc5e |
line wrap: on
line diff
--- a/learner.py Thu May 08 00:54:14 2008 -0400 +++ b/learner.py Fri May 09 13:38:54 2008 -0400 @@ -47,7 +47,7 @@ and return the learned function. """ self.forget() - return self.update(learning_task,train_stats_collector) + return self.update(training_set,train_stats_collector) def use(self,input_dataset,output_fieldnames=None, test_stats_collector=None,copy_inputs=True, @@ -254,7 +254,7 @@ Private helper function that maps a list of attribute names to a list of corresponding Op Results (with the same name but with a '_' prefix). """ - return [self.__getattribute__('_'+name).data for name in names] + return [self.__getattribute__('_'+name) for name in names] class MinibatchUpdatesTLearner(TLearner): @@ -311,7 +311,8 @@ def parameterAttributes(self): raise AbstractFunction() - def updateStart(self): pass + def updateStart(self,training_set): + pass def updateEnd(self): self.setAttributes(self.updateEndOutputAttributes(), @@ -343,12 +344,15 @@ """ self.updateStart(training_set) stop=False + if hasattr(self,'_minibatch_size') and self._minibatch_size: + minibatch_size=self._minibatch_size + else: + minibatch_size=min(100,len(training_set)) while not stop: if train_stats_collector: train_stats_collector.forget() # restart stats collectin at the beginning of each epoch - for minibatch in training_set.minibatches(self.training_set_input_fields, - minibatch_size=self.minibatch_size): - self.update_minibatch(minibatch) + for minibatch in training_set.minibatches(minibatch_size=minibatch_size): + self.updateMinibatch(minibatch) if train_stats_collector: minibatch_set = minibatch.examples() minibatch_set.setAttributes(self.attributeNames(),self.attributes()) @@ -390,7 +394,7 @@ return self.parameterAttributes() def updateMinibatchOutputAttributes(self): - return ["_new"+name for name in self.parameterAttributes()] + return ["new_"+name for name in self.parameterAttributes()] def updateEndInputAttributes(self): return self.parameterAttributes()