changeset 14:5ede27026e05

Working on gradient_based_learner
author bengioy@bengiomac.local
date Wed, 26 Mar 2008 22:56:13 -0400
parents 633453635d51
children 60b164a0d84a
files gradient_learner.py learner.py
diffstat 2 files changed, 19 insertions(+), 7 deletions(-)
--- a/gradient_learner.py	Wed Mar 26 21:38:08 2008 -0400
+++ b/gradient_learner.py	Wed Mar 26 22:56:13 2008 -0400
@@ -7,14 +7,15 @@
 
 class GradientLearner(Learner):
     """
-    Generic Learner for gradient-based optimization of a training criterion
+    Base class for gradient-based optimization of a training criterion
     that can consist in two parts, an additive part over examples, and
     an example-independent part (usually called the regularizer).
     The user provides a Theano formula that maps the fields of a training example
     and parameters to output fields (for the use function), one of which must be a cost
-    that is the training criterion to be minimized. The user also provides
-    a GradientBasedOptimizer that implements the optimization strategy.
-    The inputs, parameters, outputs and lists of Theano tensors,
+    that is the training criterion to be minimized. Subclasses implement
+    a training strategy that, in the update method, uses this function to
+    compute the outputs and the gradients with respect to the parameters.
+    The inputs, parameters, and outputs are lists of Theano tensors,
     while the example_wise_cost and regularization_term are Theano tensors.
     The user can specify a regularization coefficient that multiplies the regularization term.
     The training algorithm looks for parameters that minimize
@@ -23,6 +24,8 @@
     i.e. the regularization_term should not depend on the inputs, only on the parameters.
     The learned function can map a subset of inputs to a subset of outputs (as long as the inputs subset
     includes all the inputs required in the Theano expression for the selected outputs).
+    All the inputs are assumed to be provided in the training set, but they
+    are not all required when applying the learned function.
     """
     def __init__(self, inputs, parameters, outputs, example_wise_cost, regularization_term,
                  gradient_based_optimizer=StochasticGradientDescent(), regularization_coefficient = astensor(1.0)):
@@ -35,6 +38,13 @@
         self.regularization_coefficient = regularization_coefficient
         self.parameters_example_wise_gradient = gradient.grad(example_wise_cost, parameters)
         self.parameters_regularization_gradient = gradient.grad(self.regularization_coefficient * regularization_term, parameters)
+        if example_wise_cost not in outputs:
+            outputs.append(example_wise_cost)
+        if regularization_term not in outputs:
+            outputs.append(regularization_term)
+        self.example_wise_gradient_fn = Function(inputs + parameters, 
+                                       [self.parameters_example_wise_gradient + self.parameters_regularization_gradient])
+        self.use_functions = {frozenset([input.name for input in inputs]) : Function(inputs, outputs)}
 
-#    def update(self,training_set):
-        
+    def update(self,training_set,train_stats_collector=None):
+        raise NotImplementedError # TODO: gradient-based training loop not yet implemented
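
For context, a minimal usage sketch of the GradientLearner constructor introduced
above (not part of this changeset). The theano.tensor helpers and variable names
below are assumptions made for illustration; only the constructor signature and
argument names come from the diff itself, and the training strategy is left to
subclasses.

    # Hypothetical usage sketch (not part of this changeset). The theano.tensor
    # helpers are assumed to be available; only the GradientLearner constructor
    # signature comes from the code above.
    from theano import tensor as t

    x = t.matrix('x')                   # input field
    y = t.matrix('y')                   # target field
    w = t.matrix('w')                   # parameter to be learned
    prediction = t.dot(x, w)            # an output field (for the use function)
    example_wise_cost = t.sum((prediction - y) ** 2)   # additive part of the criterion
    regularization_term = t.sum(w ** 2)                 # example-independent part

    learner = GradientLearner(inputs=[x, y], parameters=[w], outputs=[prediction],
                              example_wise_cost=example_wise_cost,
                              regularization_term=regularization_term)

Note that __init__ also builds a use_functions dictionary keyed by the frozenset
of input names, presumably so that a compiled Function can later be cached for
each subset of inputs a caller provides.
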
--- a/learner.py	Wed Mar 26 21:38:08 2008 -0400
+++ b/learner.py	Wed Mar 26 22:56:13 2008 -0400
@@ -23,13 +23,15 @@
         """
         raise NotImplementedError
 
-    def update(self,training_set):
+    def update(self,training_set,train_stats_collector=None):
         """
         Continue training a learner, with the evidence provided by the given training set.
         Hence update can be called multiple times. This is particularly useful in the
         on-line setting or the sequential (Bayesian or not) settings.
         The result is a function that can be applied to data, with the same
         semantics as the Learner.use method.
+        The user can optionally provide a training StatsCollector that will record
+        some statistics of the outputs computed during training.
         """
         return self.use # default behavior is 'non-adaptive', i.e. update does not do anything
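
To illustrate the new optional train_stats_collector argument, here is a hedged
sketch of how a subclass's update might feed such a collector. The StatsCollector
interface assumed below (an update method receiving the computed outputs) and the
helper method are hypothetical; neither is defined in this changeset.

    # Hypothetical subclass sketch; the collector protocol and the helper
    # compute_outputs_and_apply_update are assumptions for illustration only.
    class MySubLearner(Learner):
        def update(self, training_set, train_stats_collector=None):
            for example in training_set:
                # hypothetical helper: compute outputs and take one gradient step
                outputs = self.compute_outputs_and_apply_update(example)
                if train_stats_collector is not None:
                    train_stats_collector.update(outputs)  # record training statistics
            return self.use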