diff doc/v2_planning/plugin_RP_GD.py @ 1256:bf41991692ea

new plugin approach
author gdesjardins
date Fri, 24 Sep 2010 12:53:53 -0400
parents
children c88db30f4e08
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/doc/v2_planning/plugin_RP_GD.py	Fri Sep 24 12:53:53 2010 -0400
@@ -0,0 +1,114 @@
+####
+# H1: everything works in term of iterator
+#     everything has a step() and end() method
+####
+
+# Construct counter plugin that keeps track of number of epochs
+class Counter(Plugin):
+
+    def __init__(self, sch, name, threshold):
+        super(self, Counter).__init__(sch, name)
+        self.n = 0
+        self.threshold = threshold
+
+    def execute(self, msg):
+        self.n += 1
+        if self.n > self.threshold:
+            self.fire(Event('terminate', value = self.n))
+
+def fixed_epoch_trainer(model, save, n_epochs):
+    sched = Scheduler()
+
+    # define plugins 
+    [model, validate, save] = map(pluggin_wrapper, [model, validate, save])
+
+    counter = Counter(sched, 'epoch', n_epochs)
+
+    # register actions
+    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
+    counter.act(sched, on=model.end())
+    save_model.act(sched, on=counter.end())
+
+    sched.terminate(on=counter.end())
+
+    return sched
+
+def early_stop_trainer(model, validate, save, **kw):
+    sched = Scheduler()
+
+    # define plugins 
+    [model, validate, save] = map(pluggin_wrapper, [model, validate, save])
+
+    early_stop = Stopper(**kw)
+
+    # register actions
+    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
+    validate.act(sched, on=model.end())
+    early_stop.act(sched, on=validate.step())
+    save_model.act(sched, on=[early_stop.step(), early_stop.end()])
+
+    sched.terminate(on=early_stop.end())
+
+    return sched
+
+def dbn_trainer(rbm1, rbm2):
+    sched = Scheduler()
+
+    pretrain_layer1 = fixed_epoch_trainer(rbm1, save)
+    pretrain_layer1.act(sched, on=sched.begin())
+
+    pretrain_layer2 = fixed_epoch_trainer(rbm2, save)
+    pretrain_layer2.act(sched, on=pretrain_layer1.end())
+
+    ## TBD: by the layer committee
+    mlp = function(rbm1, rbm2)
+
+    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
+    fine_tuning.act(sched, on=pretrain_layer2.end())
+
+    return sched
+
+def single_crossval_run(trainer, kfold_plugin, kfold_measure)
+
+    sched = Scheduler()
+
+    # k-fold plugin will call rbm.change_dataset using various splits of the data
+    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
+    trainer.act(sched, on=[kfold_plugin.step()])
+
+    # trainer terminates on early_stop.end(). This means that trainer.end() will forward
+    # the early-stopping message which contains the best validation error.
+    kfold_measure.act(sched, on=[trainer.end(), kill=kfold_plugin.end()]
+
+    # this best validation error is then forwarded by single_crossval_run
+    sched.terminate(on=kfold_measure.end())
+    
+    return sched
+
+
+#### MAIN LOOP ####
+rbm1 = ...
+rbm2 = ...
+dataset = ....
+dbn_trainer = dbn_trainer(rbm1, rbm2)
+kfold_plugin = KFold([rbm1, rbm2], dataset)
+kfold_measure = ...
+
+# manually add "hook" to monitor early stopping statistics
+# NB: advantage of plugins is that this code can go anywhere ...
+print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())
+
+#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
+sched = Scheduler()
+
+hyperparam_change = DBN_HyperParam([rbm1, rbm2])
+hyperparam_test = single_crossval_run(dbn_trainer, kfold_plugin, kfold_measure) 
+
+hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
+hyperparam_test.act(sched, on=hyperparam_change.step())
+
+sched.terminate(hyperparam_change.end())
+
+
+##### RUN THE WHOLE DAMN THING #####
+sched.run()