changeset 1258:c88db30f4e08

added comments to our plugin proposal
author gdesjardins
date Fri, 24 Sep 2010 13:59:47 -0400
parents d79070c60546
children 6f76ecef869e
files doc/v2_planning/plugin_RP_GD.py
diffstat 1 files changed, 112 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/doc/v2_planning/plugin_RP_GD.py	Fri Sep 24 12:54:27 2010 -0400
+++ b/doc/v2_planning/plugin_RP_GD.py	Fri Sep 24 13:59:47 2010 -0400
@@ -1,39 +1,110 @@
-####
-# H1: everything works in term of iterator
-#     everything has a step() and end() method
-####
+"""
+BASIC IDEA
+==========
+
+The main ideas of the proposal are as follows:
+1. Basic components are defined in the traditional way (classes, methods, etc.)
+2. Components are then "wrapped" into **plugins**
+3. Plugins are registered through a **scheduler** to act on certain **events**
+4. We have defined two basic types of events:
+    4.1 scheduler events: these are emitted by the scheduler itself. For now we have defined:
+        scheduler.begin() : sent at the very beginning of scheduler execution
+        scheduler.end()   : sent when the scheduler ends
+    4.2 plugin-generated events: plugins act as iterators and generate 2 basic events,
+        plugin.next() : sent on every "iteration". Each plugin is free to define what an
+                        iteration actually means.
+        plugin.end()  : sent when the plugin is done iterating.
+5. Using Olivier Breuleux's scheduler, plugins can decide to act on every event, every n-th
+   occurrence of an event, etc.
+
+OVERVIEW
+========
 
-# Construct counter plugin that keeps track of number of epochs
+The code below attempts to show how one would code a training experiment, where we pre-train +
+fine-tune a DBN, while using cross-validation. We have omitted data preprocessing for now, but
+do not think it would have much of an impact on the finished product. (An illustrative sketch
+of the Event / Plugin API assumed throughout is given right after this docstring.)
+"""
+
 class Counter(Plugin):
+    """
+    This is an example of a plugin which takes care of counting "stuff". It is generic enough
+    that it would probably be included in the library somewhere.
+    """
 
-    def __init__(self, sch, name, threshold):
-        super(self, Counter).__init__(sch, name)
+    def __init__(self, name, next_count, end_count):
+        """
+        :param name: name of the event we are counting (could be useful for debugging)
+        :param next_count: number of iterations before triggering a "next" event
+        :param end_count: number of iterations before triggering an "end" event
+        """
+        super(Counter, self).__init__()
+        self.name = name
         self.n = 0
-        self.threshold = threshold
+        self.next_count = next_count
+        self.end_count = end_count
 
     def execute(self, msg):
+        """
+        The execute function is the one which gets executed each time the plugin is "acted"
+        upon. This will send "next" and "end" events to the scheduler, which other plugins can
+        listen to. We show an example of this later.
+        """
         self.n += 1
-        if self.n > self.threshold:
-            self.fire(Event('terminate', value = self.n))
+        if self.n > self.end_count:
+            self.fire(Event('end', value = self.n))
+        elif self.n > self.next_count:
+            self.fire(Event('next', value = self.n))
+
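+# Minimal sanity check (editorial sketch): drive a Counter by hand to see which
+# events it fires.  The RecordingScheduler stand-in is an assumption; it only
+# records queued events instead of dispatching them.
+def _demo_counter():
+    class RecordingScheduler(object):
+        def __init__(self):
+            self.seen = []
+        def register(self, plugin, on):
+            pass
+        def queue(self, event):
+            self.seen.append((event.type, event.value))
+
+    counter = Counter('epoch', next_count=1, end_count=3)
+    counter.scheduler = RecordingScheduler()
+    for _ in range(4):
+        counter.execute(msg=None)
+    # with the thresholds above this records:
+    #   [('next', 2), ('next', 3), ('end', 4)]
+    return counter.scheduler.seen
+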
 
 def fixed_epoch_trainer(model, save, n_epochs):
+    """
+    This is an example of a meta-plugin. Meta-plugins are just like any other plugins, except
+    they themselves contain other plugins and their own schedulers. A meta-plugin would replace
+    code blocks which are usually found inside for-loops or if-else statements.
+
+    The fixed_epoch_trainer meta-plugin takes care of training a given model, for a fixed
+    number of epochs and then saving the resulting model.
+    """
+    # we start by defining our own private scheduler
     sched = Scheduler()
 
-    # define plugins 
+    # convert "components" to plugins. Maybe this could be automated or done in another way...
+    # syntax is not really important here.
     [model, save_model] = map(plugin_wrapper, [model, save])
 
-    counter = Counter(sched, 'epoch', n_epochs)
+    # instantiate the counter plugin to "end" after n_epochs
+    counter = Counter('epoch', 1, n_epochs)
+
+    ####
+    # Registering actions: overview of syntax
+    # plugin1.act(sched, on=[event1, event2]) means that plugin1 will perform one
+    # "iteration" on every occurrence of event1 or event2.
+    ####
 
-    # register actions
+    # In our example, we assume that 1 iteration of "model" means 1 minibatch update.
+    # Model performs an iteration:
+    # * when program first starts
+    # * after it is done with the previous iteration (auto-loop)
+    # * after each epoch, as defined by the counter. The counter counts epochs by "trapping"
+    #   the model's end event.
     model.act(sched, on=[sched.begin(), model.step(), counter.step()])
+    # the counter is incremented every time the model finishes an epoch (model.end)
     counter.act(sched, on=model.end())
+    # the save_model plugin then takes care of saving everything once the counter expires
     save_model.act(sched, on=counter.end())
 
-    sched.terminate(on=counter.end())
+    # Meta-plugins also generate events: they map "internal" events to "external" events
+    # fixed_epoch_trainer.end() will thus correspond to counter.end() 
+    sched.end(on=counter.end())
+    # in the same way, you could have:
+    # sched.next(on=counter.next()) but this is not useful here
 
     return sched
 
 def early_stop_trainer(model, validate, save, **kw):
+    """
+    This is another plugin example, which takes care of training the model but using early
+    stopping.
+    """
     sched = Scheduler()
 
     # define plugins 
@@ -43,15 +114,24 @@
 
     # register actions
     model.act(sched, on=[sched.begin(), model.step(), validate.step()])
+    # we measure validation error after every epoch (model.end)
     validate.act(sched, on=model.end())
+    # early-stopper gets triggered every time we have a new validation error
+    # the error itself is passed within the "step" message
     early_stop.act(sched, on=validate.step())
+    # model is saved whenever we find a new best model (early_stop.step) or when we have found
+    # THE best model (early_stop.end)
     save_model.act(sched, on=[early_stop.step(), early_stop.end()])
 
-    sched.terminate(on=early_stop.end())
+    sched.end(on=early_stop.end())
 
     return sched
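+
+# ---------------------------------------------------------------------------
+# Illustrative sketch (editorial addition): the early_stop plugin used above is
+# constructed in a part of the file not shown in this diff.  One plausible shape
+# for it, assuming the Event / Plugin sketch near the top of the file, is given
+# below; the class name and the `patience` parameter are hypothetical.
+# ---------------------------------------------------------------------------
+class EarlyStopper(Plugin):
+    def __init__(self, patience):
+        super(EarlyStopper, self).__init__()
+        self.patience = patience          # validations to wait without improvement
+        self.best_error = float('inf')
+        self.stale = 0
+
+    def execute(self, msg):
+        # msg is assumed to carry the latest validation error (from validate.step)
+        if msg.value < self.best_error:
+            self.best_error = msg.value
+            self.stale = 0
+            self.fire(Event('step', value=self.best_error))    # found a new best model
+        else:
+            self.stale += 1
+            if self.stale >= self.patience:
+                self.fire(Event('end', value=self.best_error)) # give up: best model found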
 
 def dbn_trainer(rbm1, rbm2):
+    """
+    This meta-plugin pre-trains a two-layer DBN for a fixed number of epochs, and then performs
+    fine-tuning on the resulting MLP. This should hopefully be self-explanatory.
+    """
     sched = Scheduler()
 
     pretrain_layer1 = fixed_epoch_trainer(rbm1, save)
@@ -69,6 +149,10 @@
     return sched
 
 def single_crossval_run(trainer, kfold_plugin, kfold_measure):
+    """
+    For a fixed set of hyper-parameters, this evaluates the generalization error using KFold
+    cross-validation.
+    """
 
     sched = Scheduler()
 
@@ -76,12 +160,12 @@
     kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
     trainer.act(sched, on=[kfold_plugin.step()])
 
-    # trainer terminates on early_stop.end(). This means that trainer.end() will forward
+    # trainer ends on early_stop.end(). This means that trainer.end() will forward
     # the early-stopping message which contains the best validation error.
     kfold_measure.act(sched, on=[trainer.end()], kill=kfold_plugin.end())
 
     # this best validation error is then forwarded by single_crossval_run
-    sched.terminate(on=kfold_measure.end())
+    sched.end(on=kfold_measure.end())
     
     return sched
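+
+# ---------------------------------------------------------------------------
+# Illustrative sketch (editorial addition): one plausible shape for the KFold
+# plugin instantiated below.  Every execution advances to the next fold and
+# fires step(); once all folds are exhausted it fires end().  The `k` parameter
+# and attribute names are assumptions, not part of the proposal.
+# ---------------------------------------------------------------------------
+class KFold(Plugin):
+    def __init__(self, models, dataset, k=5):
+        super(KFold, self).__init__()
+        self.models = models
+        self.dataset = dataset
+        self.k = k
+        self.fold = 0
+
+    def execute(self, msg):
+        if self.fold == self.k:
+            # every fold has been trained: report that we are done
+            self.fire(Event('end', value=self.fold))
+        else:
+            self.fold += 1
+            # how the train/valid split for this fold is handed to the models
+            # is left open in the proposal; we only advertise the new fold here
+            self.fire(Event('step', value=self.fold))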
 
@@ -94,11 +178,19 @@
 kfold_plugin = KFold([rbm1, rbm2], dataset)
 kfold_measure = ...
 
-# manually add "hook" to monitor early stopping statistics
-# NB: advantage of plugins is that this code can go anywhere ...
+## In our view, the meta-plugins defined above would live in the library somewhere. Hooks can be
+## added without modifying the library code. The meta-plugin's scheduler contains a dictionary
+## of "registered" plugins along with their events. We can thus register "user-plugins" based on
+## any of these events.
+# define a logger plugin of some sort
+print_stat = ....   
+# act on each iteration of the early-stopping plugin
+# NB: the message is forwarded as is. It is up to the print_stat plugin to parse it properly.
 print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())
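+
+# ---------------------------------------------------------------------------
+# Illustrative sketch (editorial addition): the print_stat plugin is left
+# unspecified above ("....").  A trivial logger plugin along these lines would
+# do; the class name and message formatting are assumptions.
+# ---------------------------------------------------------------------------
+class PrintStat(Plugin):
+    def __init__(self, prefix):
+        super(PrintStat, self).__init__()
+        self.prefix = prefix
+
+    def execute(self, msg):
+        # log the event type and whatever value it carries
+        print('%s: %s = %s' % (self.prefix, msg.type, getattr(msg, 'value', None)))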
 
 #### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
+# this is the final outer loop, which tests various configurations of hyperparameters
+
 sched = Scheduler()
 
 hyperparam_change = DBN_HyperParam([rbm1, rbm2])
@@ -107,8 +199,7 @@
 hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
 hyperparam_test.act(sched, on=hyperparam_change.step())
 
-sched.terminate(hyperparam_change.end())
-
+sched.end(hyperparam_change.end())
 
 ##### RUN THE WHOLE DAMN THING #####
 sched.run()