Mercurial > pylearn
diff doc/v2_planning/plugin_RP_GD.py @ 1256:bf41991692ea
new plugin approach
author | gdesjardins |
---|---|
date | Fri, 24 Sep 2010 12:53:53 -0400 |
parents | |
children | c88db30f4e08 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/doc/v2_planning/plugin_RP_GD.py Fri Sep 24 12:53:53 2010 -0400 @@ -0,0 +1,114 @@ +#### +# H1: everything works in term of iterator +# everything has a step() and end() method +#### + +# Construct counter plugin that keeps track of number of epochs +class Counter(Plugin): + + def __init__(self, sch, name, threshold): + super(self, Counter).__init__(sch, name) + self.n = 0 + self.threshold = threshold + + def execute(self, msg): + self.n += 1 + if self.n > self.threshold: + self.fire(Event('terminate', value = self.n)) + +def fixed_epoch_trainer(model, save, n_epochs): + sched = Scheduler() + + # define plugins + [model, validate, save] = map(pluggin_wrapper, [model, validate, save]) + + counter = Counter(sched, 'epoch', n_epochs) + + # register actions + model.act(sched, on=[sched.begin(), model.step(), counter.step()]) + counter.act(sched, on=model.end()) + save_model.act(sched, on=counter.end()) + + sched.terminate(on=counter.end()) + + return sched + +def early_stop_trainer(model, validate, save, **kw): + sched = Scheduler() + + # define plugins + [model, validate, save] = map(pluggin_wrapper, [model, validate, save]) + + early_stop = Stopper(**kw) + + # register actions + model.act(sched, on=[sched.begin(), model.step(), validate.step()]) + validate.act(sched, on=model.end()) + early_stop.act(sched, on=validate.step()) + save_model.act(sched, on=[early_stop.step(), early_stop.end()]) + + sched.terminate(on=early_stop.end()) + + return sched + +def dbn_trainer(rbm1, rbm2): + sched = Scheduler() + + pretrain_layer1 = fixed_epoch_trainer(rbm1, save) + pretrain_layer1.act(sched, on=sched.begin()) + + pretrain_layer2 = fixed_epoch_trainer(rbm2, save) + pretrain_layer2.act(sched, on=pretrain_layer1.end()) + + ## TBD: by the layer committee + mlp = function(rbm1, rbm2) + + fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp) + fine_tuning.act(sched, on=pretrain_layer2.end()) + + return sched + +def single_crossval_run(trainer, kfold_plugin, kfold_measure) + + sched = Scheduler() + + # k-fold plugin will call rbm.change_dataset using various splits of the data + kfold_plugin.act(sched, on=[sched.begin(), trainer.end()]) + trainer.act(sched, on=[kfold_plugin.step()]) + + # trainer terminates on early_stop.end(). This means that trainer.end() will forward + # the early-stopping message which contains the best validation error. + kfold_measure.act(sched, on=[trainer.end(), kill=kfold_plugin.end()] + + # this best validation error is then forwarded by single_crossval_run + sched.terminate(on=kfold_measure.end()) + + return sched + + +#### MAIN LOOP #### +rbm1 = ... +rbm2 = ... +dataset = .... +dbn_trainer = dbn_trainer(rbm1, rbm2) +kfold_plugin = KFold([rbm1, rbm2], dataset) +kfold_measure = ... + +# manually add "hook" to monitor early stopping statistics +# NB: advantage of plugins is that this code can go anywhere ... +print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step()) + +#### THIS SHOULD CORRESPOND TO THE OUTER LOOP #### +sched = Scheduler() + +hyperparam_change = DBN_HyperParam([rbm1, rbm2]) +hyperparam_test = single_crossval_run(dbn_trainer, kfold_plugin, kfold_measure) + +hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()]) +hyperparam_test.act(sched, on=hyperparam_change.step()) + +sched.terminate(hyperparam_change.end()) + + +##### RUN THE WHOLE DAMN THING ##### +sched.run()