Mercurial > pylearn
diff doc/v2_planning/plugin_RP_GD.py @ 1258:c88db30f4e08
added comments to our plugin proposal
author | gdesjardins |
---|---|
date | Fri, 24 Sep 2010 13:59:47 -0400 |
parents | bf41991692ea |
children | 6f76ecef869e |
line wrap: on
line diff
--- a/doc/v2_planning/plugin_RP_GD.py Fri Sep 24 12:54:27 2010 -0400 +++ b/doc/v2_planning/plugin_RP_GD.py Fri Sep 24 13:59:47 2010 -0400 @@ -1,39 +1,110 @@ -#### -# H1: everything works in term of iterator -# everything has a step() and end() method -#### +""" +BASIC IDEA +========== + +The main ideas of the proposal are as follows: +1. Basic components are defined in the traditional way (classes, methods, etc.) +2. Components are then "wrapped" into **plugins** +3. Plugins are registered through a **scheduler** to act on certain **events** +4. We have defined two basic types of events: + 4.1 scheduler events: these are emitted by the scheduler itself. For now we have defined: + scheduler.begin() : sent at the very beginning of scheduler execution + scheduler.end() : sent when the scheduler ends + 4.2 plugin generate events: plugins act as iterators and generate 2 basic events, + plugin.next() : sent of every "iteration". Each plugin is free to define what an + iteration actually means. + plugin.end() : sent when the plugin is done iterating. +5. Using Olivier Breuleux's schedular, plugins can decide to act on every event, every n-th + occurence of an event, etc. + +OVERVIEW +======== -# Construct counter plugin that keeps track of number of epochs +The code below attempts to show how one would code a training experiment, where we pre-train + +fine-tune a DBN, while using cross-validation. We have ommitted data preprocessing for now, but +do not think it would have much of an impact on the finished product. +""" + class Counter(Plugin): + """ + This is an example of a plugin which takes care of counting "stuff". It is generic enough + that it would probably be included in the library somewhere. + """ - def __init__(self, sch, name, threshold): - super(self, Counter).__init__(sch, name) + def __init__(self, name, next_count, end_count): + """ + :param name: name of the event we are counting (could be useful for debugging) + :param next_count: number of iterations before triggering a "next" event + :param end_count: number of iterations before triggering an "end" event + """ + super(self, Counter).__init__() self.n = 0 - self.threshold = threshold + self.next_count = next_count + self.end_count = end_count def execute(self, msg): + """ + The execute function is the one which gets executed each time the plugin is "acted" + upon. This will send "next" and "end" events to the scheduler, which other plugins can + listen to. We show an example of this later. + """ self.n += 1 - if self.n > self.threshold: - self.fire(Event('terminate', value = self.n)) + if self.n > self.end_count: + self.fire(Event('end', value = self.n)) + elif self.n > self.next_count: + self.fire(Event('next', value = self.n)) + def fixed_epoch_trainer(model, save, n_epochs): + """ + This is an example of a meta-plugin. Meta-plugin are just like any other plugins, except + they themselves contain other plugins and their own schedulers. A meta-plugin would replace + code-blocks which are usually found inside for-loop or if-else statements. + + The fixed_epoch_trainer meta-plugin takes care of training a given model, for a fixed + number of epochs and then saving the resulting model. + """ + # we start by defining our own private scheduler sched = Scheduler() - # define plugins + # convert "components" to plugins. Maybe this could be automated or done in another way... + # syntax is not really important here. [model, validate, save] = map(pluggin_wrapper, [model, validate, save]) - counter = Counter(sched, 'epoch', n_epochs) + # instantiate the counter plugin to "end" after n_epochs + counter = Counter('epoch', 1, n_epochs) + + #### + # Registering actions: overview of syntax + # plugin1.act(sched, on=[event1, event2]) means that that plugin1 will perform one + # "iteration" on every occurence of event1 and event2. + #### - # register actions + # In our example, we assume that 1 iteration of "model" means 1 minibatch update. + # Model performs an iteration: + # * when program first starts + # * after is it done with previous iteration (auto-loop) + # * after each epoch, as defined by the counter. The counter counts epochs by "trapping" + # the model's end event. model.act(sched, on=[sched.begin(), model.step(), counter.step()]) + # counter is incremented every time the counter.act(sched, on=model.end()) + # the save_model plugin then takes care of saving everything once the counter expires save_model.act(sched, on=counter.end()) - sched.terminate(on=counter.end()) + # Meta-plugins also generate events: they map "internal" events to "external" events + # fixed_epoch_trainer.end() will thus correspond to counter.end() + sched.end(on=counter.end()) + # in the same way, you could have: + # sched.next(on=counter.next()) but this is not useful here return sched def early_stop_trainer(model, validate, save, **kw): + """ + This is another plugin example, which takes care of training the model but using early + stopping. + """ sched = Scheduler() # define plugins @@ -43,15 +114,24 @@ # register actions model.act(sched, on=[sched.begin(), model.step(), validate.step()]) + # we measure validation error after every epoch (model.end) validate.act(sched, on=model.end()) + # early-stopper gets triggered every time we have a new validation error + # the error itself is passed within the "step" message early_stop.act(sched, on=validate.step()) + # model is saved whenever we find a new best model (early_stop.step) or when we have found + # THE best model (early_stop.end) save_model.act(sched, on=[early_stop.step(), early_stop.end()]) - sched.terminate(on=early_stop.end()) + sched.end(on=early_stop.end()) return sched def dbn_trainer(rbm1, rbm2): + """ + This meta-plugin pre-trains a two-layer DBN for a fixed number of epochs, and then performs + fine-tuning on the resulting MLP. This should hopefully be self-explanatory. + """ sched = Scheduler() pretrain_layer1 = fixed_epoch_trainer(rbm1, save) @@ -69,6 +149,10 @@ return sched def single_crossval_run(trainer, kfold_plugin, kfold_measure) + """ + For a fixed set of hyper-parameters, this evaluates the generalization error using KFold + cross-validation. + """ sched = Scheduler() @@ -76,12 +160,12 @@ kfold_plugin.act(sched, on=[sched.begin(), trainer.end()]) trainer.act(sched, on=[kfold_plugin.step()]) - # trainer terminates on early_stop.end(). This means that trainer.end() will forward + # trainer ends on early_stop.end(). This means that trainer.end() will forward # the early-stopping message which contains the best validation error. kfold_measure.act(sched, on=[trainer.end(), kill=kfold_plugin.end()] # this best validation error is then forwarded by single_crossval_run - sched.terminate(on=kfold_measure.end()) + sched.end(on=kfold_measure.end()) return sched @@ -94,11 +178,19 @@ kfold_plugin = KFold([rbm1, rbm2], dataset) kfold_measure = ... -# manually add "hook" to monitor early stopping statistics -# NB: advantage of plugins is that this code can go anywhere ... +## In our view, the meta-plugins defined above would live in the library somewhere. Hooks can be +## added without modifying the library code. The meta-plugin's scheduler contains a dictionary +## of "registered" plugins along with their events. We can thus register "user-plugins" based on +## any of these events. +# define a logger plugin of some sort +print_stat = .... +# act on each iteration of the early-stopping plugin +# NB: the message is forwarded as is. It is up to the print_stat plugin to parse it properly. print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step()) #### THIS SHOULD CORRESPOND TO THE OUTER LOOP #### +# this is the final outer-loop which tests various configurations of hyperparameters + sched = Scheduler() hyperparam_change = DBN_HyperParam([rbm1, rbm2]) @@ -107,8 +199,7 @@ hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()]) hyperparam_test.act(sched, on=hyperparam_change.step()) -sched.terminate(hyperparam_change.end()) - +sched.end(hyperparam_change.end()) ##### RUN THE WHOLE DAMN THING ##### sched.run()