diff learner.py @ 133:b4657441dd65

Corrected typos
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Fri, 09 May 2008 13:38:54 -0400
parents f6505ec32dc3
children 3f4e5c9bdc5e
line wrap: on
line diff
--- a/learner.py	Thu May 08 00:54:14 2008 -0400
+++ b/learner.py	Fri May 09 13:38:54 2008 -0400
@@ -47,7 +47,7 @@
         and return the learned function.
         """
         self.forget()
-        return self.update(learning_task,train_stats_collector)
+        return self.update(training_set,train_stats_collector)
 
     def use(self,input_dataset,output_fieldnames=None,
             test_stats_collector=None,copy_inputs=True,
@@ -254,7 +254,7 @@
         Private helper function that maps a list of attribute names to a list
         of corresponding Op Results (with the same name but with a '_' prefix).
         """
-        return [self.__getattribute__('_'+name).data for name in names]
+        return [self.__getattribute__('_'+name) for name in names]
 
 
 class MinibatchUpdatesTLearner(TLearner):
@@ -311,7 +311,8 @@
     def parameterAttributes(self):
         raise AbstractFunction()
 
-        def updateStart(self): pass
+    def updateStart(self,training_set):
+        pass
 
     def updateEnd(self):
         self.setAttributes(self.updateEndOutputAttributes(),
@@ -343,12 +344,15 @@
         """
         self.updateStart(training_set)
         stop=False
+        if hasattr(self,'_minibatch_size') and self._minibatch_size:
+            minibatch_size=self._minibatch_size
+        else:
+            minibatch_size=min(100,len(training_set))
         while not stop:
             if train_stats_collector:
                train_stats_collector.forget() # restart stats collecting at the beginning of each epoch
-            for minibatch in training_set.minibatches(self.training_set_input_fields,
-                                                      minibatch_size=self.minibatch_size):
-                self.update_minibatch(minibatch)
+            for minibatch in training_set.minibatches(minibatch_size=minibatch_size):
+                self.updateMinibatch(minibatch)
                 if train_stats_collector:
                     minibatch_set = minibatch.examples()
                     minibatch_set.setAttributes(self.attributeNames(),self.attributes())
@@ -390,7 +394,7 @@
         return self.parameterAttributes()
     
     def updateMinibatchOutputAttributes(self):
-        return ["_new"+name for name in self.parameterAttributes()]
+        return ["new_"+name for name in self.parameterAttributes()]
     
     def updateEndInputAttributes(self):
         return self.parameterAttributes()