diff learner.py @ 119:7ffecde9dadc

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Wed, 07 May 2008 15:08:18 -0400
parents d0a1bd0378c6
children 4efe6d36c061
line wrap: on
line diff
--- a/learner.py	Wed May 07 13:07:33 2008 -0400
+++ b/learner.py	Wed May 07 15:08:18 2008 -0400
@@ -1,6 +1,6 @@
 
-from dataset import *
-from compile import Function
+from dataset import AttributesHolder
+import compile
     
 class Learner(AttributesHolder):
     """Base class for learning algorithms, provides an interface
@@ -173,8 +173,8 @@
         if key not in self.use_functions_dictionary:
             use_input_attributes = self.useInputAttributes()
             use_output_attributes = self.useOutputAttributes()
-            complete_f = Function(self.names2OpResults(input_fields+use_input_attributes),
-                                  self.names2OpResults(output_fields+use_output_attributes))
+            complete_f = compile.function(self.names2OpResults(input_fields+use_input_attributes),
+                                          self.names2OpResults(output_fields+use_output_attributes))
             def f(*input_field_values):
                 input_attribute_values = self.names2attributes(use_input_attributes)
                 results = complete_f(*(input_field_values + input_attribute_values))
@@ -273,12 +273,13 @@
 
     def __init__(self):
         TLearner.__init__(self)
-        self.update_minibatch_function =
-        Function(self.names2OpResults(self.updateMinibatchOutputAttributes()+
-                                      self.updateMinibatchInputFields()),
+        self.update_minibatch_function = compile.function
+        (self.names2OpResults(self.updateMinibatchOutputAttributes()+
+                              self.updateMinibatchInputFields()),
                  self.names2OpResults(self.updateMinibatchOutputAttributes()))
-        self.update_end_function = Function(self.names2OpResults(self.updateEndInputAttributes()),
-                                            self.names2OpResults(self.updateEndOutputAttributes()))
+        self.update_end_function = compile.function
+        (self.names2OpResults(self.updateEndInputAttributes()),
+         self.names2OpResults(self.updateEndOutputAttributes()))
 
     def updateMinibatchInputFields(self):
         raise AbstractFunction()
@@ -310,7 +311,9 @@
         # make sure all required fields are allocated and initialized
         self.allocate(minibatch)
         self.setAttributes(self.updateMinibatchOutputAttributes(),
-                           self.update_minibatch_function(*(self.names2attributes(self.updateMinibatchInputAttributes()))
+                           # concatenate the attribute values and field values and then apply update fn
+                           self.update_minibatch_function(*(self.names2attributes
+                                                            (self.updateMinibatchInputAttributes()))
                                                           + minibatch(self.updateMinibatchInputFields())))
         
     def isLastEpoch(self):
@@ -347,17 +350,40 @@
     Specialization of MinibatchUpdatesTLearner in which the minibatch updates
     are obtained by performing an online (minibatch-based) gradient step.
 
-    Sub-classes must define the following methods:
-    
+    Sub-classes must define the following:
+
+      self._learning_rate (may be changed by the sub-class between epochs or minibatches)
+     
+      self.lossAttribute()  = name of the loss field 
+      
     """
     def __init__(self,truly_online=False):
         """
         If truly_online then only one pass is made through the training set passed to update().
-        
+
+        SUBCLASSES SHOULD CALL THIS CONSTRUCTOR ONLY AFTER HAVING DEFINED ALL THEIR THEANO FORMULAS
         """
         self.truly_online=truly_online
 
+        # create the formulas for the gradient update
+        old_params = [self.__getattr__("_"+name) for name in self.parameterAttributes()]
+        new_params_names = ["_new_"+name for name in self.parameterAttributes()]
+        loss = self.__getattr__(self.lossAttribute())
+        self.setAttributes(new_params_names,
+                           [t.add_inplace(self.param,
+                                          self._learning_rate*t.grad(loss,param))
+                            for param in old_params])
+
     def isLastEpoch(self):
         return self.truly_online
 
+    def updateMinibatchInputAttributes(self):
+        return self.parameterAttributes()
+    
+    def updateMinibatchOutputAttributes(self):
+        return ["_new"+name for name in self.parameterAttributes()]
+    
+    def updateEndInputAttributes(self):
+        return self.parameterAttributes()
 
+