Mercurial > pylearn
view doc/v2_planning/plugin_RP_GD.py @ 1256:bf41991692ea
new plugin approach
author | gdesjardins |
---|---|
date | Fri, 24 Sep 2010 12:53:53 -0400 |
parents | |
children | c88db30f4e08 |
line wrap: on
line source
####
# H1: everything works in terms of iterators.
# Everything has a step() and end() method.
####

# NOTE(review): this is a design sketch -- Plugin, Scheduler, Event, Stopper,
# plugin_wrapper, KFold, print_stat, validate, save, etc. are not defined in
# this file; confirm them against the plugin framework this proposal targets.


# Construct counter plugin that keeps track of the number of epochs.
class Counter(Plugin):
    """Plugin that counts execute() calls and fires 'terminate' once the
    count exceeds a fixed threshold.

    Used to bound training to a fixed number of epochs: each execute()
    increments the counter; past `threshold` a 'terminate' event carrying
    the final count is fired.
    """

    def __init__(self, sch, name, threshold):
        # BUG FIX: original read `super(self, Counter)` -- the argument
        # order is (class, instance), not (instance, class).
        super(Counter, self).__init__(sch, name)
        self.n = 0
        self.threshold = threshold

    def execute(self, msg):
        self.n += 1
        if self.n > self.threshold:
            self.fire(Event('terminate', value=self.n))


def fixed_epoch_trainer(model, save, n_epochs):
    """Build a Scheduler that trains `model` for exactly `n_epochs` epochs,
    saving the model when the epoch counter terminates."""
    sched = Scheduler()

    # define plugins
    # FIX: original spelled `pluggin_wrapper` and also wrapped an undefined
    # `validate` that is not a parameter of this trainer.
    [model, save] = map(plugin_wrapper, [model, save])
    counter = Counter(sched, 'epoch', n_epochs)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
    counter.act(sched, on=model.end())
    # FIX: original called the undefined name `save_model`; the wrapped
    # parameter is `save`.
    save.act(sched, on=counter.end())

    sched.terminate(on=counter.end())
    return sched


def early_stop_trainer(model, validate, save, **kw):
    """Build a Scheduler that trains `model` under early stopping driven by
    `validate`, saving on every early-stop step and at the end."""
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(plugin_wrapper, [model, validate, save])
    early_stop = Stopper(**kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
    validate.act(sched, on=model.end())
    early_stop.act(sched, on=validate.step())
    # FIX: original called the undefined name `save_model`; use `save`.
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.terminate(on=early_stop.end())
    return sched


def dbn_trainer(rbm1, rbm2):
    """Build a Scheduler that greedily pretrains two RBM layers, then
    fine-tunes the resulting MLP with early stopping."""
    sched = Scheduler()

    pretrain_layer1 = fixed_epoch_trainer(rbm1, save)
    pretrain_layer1.act(sched, on=sched.begin())

    pretrain_layer2 = fixed_epoch_trainer(rbm2, save)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    return sched


def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    # FIX: original `def` line was missing the trailing colon.
    """Build a Scheduler that runs `trainer` once per k-fold split and
    terminates with the aggregated cross-validation measure."""
    sched = Scheduler()

    # k-fold plugin will call rbm.change_dataset using various splits of
    # the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer terminates on early_stop.end(). This means that trainer.end()
    # will forward the early-stopping message which contains the best
    # validation error.
    # FIX: original had `on=[trainer.end(), kill=kfold_plugin.end()]` -- a
    # keyword argument inside a list literal (syntax error) and an unclosed
    # call. `kill` is presumably its own keyword parameter of act().
    kfold_measure.act(sched, on=trainer.end(), kill=kfold_plugin.end())

    # this best validation error is then forwarded by single_crossval_run
    sched.terminate(on=kfold_measure.end())
    return sched


#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
# FIX: original read `dataset = ....` (four dots), which is a syntax error.
dataset = ...
# FIX: original rebound the function name (`dbn_trainer = dbn_trainer(...)`),
# shadowing it; use a distinct variable instead.
trainer = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

# manually add "hook" to monitor early stopping statistics
# NB: advantage of plugins is that this code can go anywhere ...
# NOTE(review): `pretrain_layer1` is local to dbn_trainer and not visible
# here -- presumably this sketch intends some handle into the sub-scheduler;
# confirm against the intended plugin API.
print_stat.act(pretrain_layer1,
               on=pretrain_layer1.plugins['early_stop'].step())

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(trainer, kfold_plugin, kfold_measure)

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

sched.terminate(hyperparam_change.end())

##### RUN THE WHOLE DAMN THING #####
sched.run()