pylearn: doc/v2_planning/plugin_RP_GD.py @ 1451:8110ca3cec3f
description: merge
author: James Bergstra <bergstrj@iro.umontreal.ca>
date: Thu, 31 Mar 2011 18:29:11 -0400
parents: 6f76ecef869e
"""
BASIC IDEA
==========

The main ideas of the proposal are as follows:

1. Basic components are defined in the traditional way (classes, methods,
   etc.).
2. Components are then "wrapped" into **plugins**.
3. Plugins are registered through a **scheduler** to act on certain
   **events**.
4. We have defined two basic types of events:

   4.1 Scheduler events: these are emitted by the scheduler itself. For now we
       have defined:

       scheduler.begin() : sent at the very beginning of scheduler execution
       scheduler.end()   : sent when the scheduler ends (this makes sense
                           since, at least from the user's perspective, there
                           is a hierarchy of schedulers)

   4.2 Plugin-generated events: plugins act as iterators and generate two
       basic events:

       plugin.next() : sent on every "iteration" (the code below calls this
                       event plugin.step()). Each plugin is free to define
                       what an iteration actually means.
       plugin.end()  : sent when the plugin is done iterating.

5. Using Olivier Breuleux's scheduler, plugins can decide to act on every
   event, every n-th occurrence of an event, etc.

OVERVIEW
========

The code below attempts to show how one would code a training experiment in
which we pre-train and then fine-tune a DBN, while using cross-validation. We
have omitted data preprocessing for now, but do not think it would have much
of an impact on the finished product.
"""


class Counter(Plugin):
    """
    This is an example of a plugin which takes care of counting "stuff". It is
    generic enough that it would probably be included in the library
    somewhere.
    """

    def __init__(self, name, end_count, next_count=1):
        """
        :param name: name of the event we are counting (could be useful for
            debugging)
        :param end_count: number of iterations before triggering an "end"
            event
        :param next_count: number of iterations between two "next" events
        """
        super(Counter, self).__init__()
        self.name = name
        self.n = 0
        self.next_count = next_count
        self.end_count = end_count

    def execute(self, msg):
        """
        The execute function is the one which gets executed each time the
        plugin is "acted" upon.
        This will send "next" and "end" events to the scheduler, which other
        plugins can listen to. We show an example of this later.
        """
        self.n += 1
        if self.n >= self.end_count:
            # done counting: tell the listeners that this counter has expired
            self.fire(Event('end', value=self.n))
        elif self.n % self.next_count == 0:
            # one "next" event every next_count iterations
            self.fire(Event('next', value=self.n))


def fixed_epoch_trainer(model, save, n_epochs):
    """
    This is an example of a meta-plugin. Meta-plugins are just like any other
    plugins, except that they themselves contain other plugins and their own
    schedulers. A meta-plugin would replace code blocks which are usually found
    inside for-loops or if-else statements.

    The fixed_epoch_trainer meta-plugin takes care of training a given model
    for a fixed number of epochs and then saving the resulting model.

    Other arguments for having meta-plugins:
     * they can define a semantically separable block of code;
     * they are what the library provides, i.e. 99% of its code (so you can
       define a certain template for connecting plugins as a meta-plugin and
       ship it without worrying about the details);
     * they can be broken apart by the main scheduler (so you would not have
       "different" schedulers running at the same time; there is just one
       scheduler, and this way of constructing things only serves to improve
       understanding and intuition of the system);
     * they help push all the complexity onto the backbone of the library
       (i.e. the scheduler);
     * all plugins registered inside a meta-plugin are active only when the
       meta-plugin is active; in this sense they can help define scopes (as
       for variables) - optional.
    """
    # we start by defining our own private scheduler
    sched = Scheduler()

    # convert "components" to plugins. Maybe this could be automated or done
    # in another way... the syntax is not really important here.
    [model, save] = map(pluggin_wrapper, [model, save])

    # instantiate the counter plugin to "end" after n_epochs
    counter = Counter('epoch', end_count=n_epochs)

    ####
    # Registering actions: overview of syntax
    # plugin1.act(sched, on=[event1, event2]) means that plugin1 will perform
    # one "iteration" on every occurrence of event1 and event2.
    ####

    # In our example, we assume that 1 iteration of "model" means 1 minibatch
    # update. Model performs an iteration:
    # * when the program first starts
    # * after it is done with its previous iteration (auto-loop)
    # * after each epoch, as defined by the counter. The counter counts epochs
    #   by "trapping" the model's end event.
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])

    # the counter is incremented every time the model finishes an epoch
    counter.act(sched, on=model.end())

    # the save plugin then takes care of saving everything once the counter
    # expires
    save.act(sched, on=counter.end())

    # Meta-plugins also generate events: they map "internal" events to
    # "external" events. fixed_epoch_trainer.end() will thus correspond to
    # counter.end().
    sched.end(on=counter.end())
    # in the same way, you could have:
    #   sched.step(on=counter.step())
    # but this is not useful here

    return sched


def early_stop_trainer(model, validate, save, **kw):
    """
    This is another meta-plugin example, which takes care of training the
    model using early stopping.
    """
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(pluggin_wrapper, [model, validate, save])
    early_stop = Stopper(**kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])

    # we measure the validation error after every epoch (model.end)
    validate.act(sched, on=model.end())

    # the early-stopper gets triggered every time we have a new validation
    # error; the error itself is passed within the "step" message
    early_stop.act(sched, on=validate.step())

    # the model is saved whenever we find a new best model (early_stop.step)
    # or when we have found THE best model (early_stop.end)
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.end(on=early_stop.end())

    return sched


def dbn_trainer(rbm1, rbm2):
    """
    This meta-plugin pre-trains a two-layer DBN for a fixed number of epochs,
    and then performs fine-tuning on the resulting MLP. This should hopefully
    be self-explanatory.
    """
    sched = Scheduler()

    # the number of pre-training epochs is left unspecified, like the other
    # "..." placeholders in this file
    pretrain_layer1 = fixed_epoch_trainer(rbm1, save, n_epochs=...)
    pretrain_layer1.act(sched, on=sched.begin())

    pretrain_layer2 = fixed_epoch_trainer(rbm2, save, n_epochs=...)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    return sched


def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    """
    For a fixed set of hyper-parameters, this evaluates the generalization
    error using k-fold cross-validation.
    """
    sched = Scheduler()

    # the k-fold plugin will call rbm.change_dataset using various splits of
    # the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer ends on early_stop.end(). This means that trainer.end() will
    # forward the early-stopping message, which contains the best validation
    # error.
    # kfold_measure acts on every trainer.end() and is killed once
    # kfold_plugin ends (i.e. after the last fold)
    kfold_measure.act(sched, on=[trainer.end()], kill=kfold_plugin.end())

    # this best validation error is then forwarded by single_crossval_run
    sched.end(on=kfold_measure.end())

    return sched


#### MAIN LOOP ####

rbm1 = ...
rbm2 = ...
dataset = ...

dbn_trainer = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

## In our view, the meta-plugins defined above would live in the library
## somewhere. Hooks can be added without modifying the library code. The
## meta-plugin's scheduler contains a dictionary of "registered" plugins along
## with their events. We can thus register "user-plugins" based on any of
## these events.

# define a logger plugin of some sort
print_stat = ...

# act on each iteration of the early-stopping plugin
# NB: the message is forwarded as is. It is up to the print_stat plugin to
# parse it properly.
print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####

# this is the final outer loop, which tests various configurations of
# hyperparameters
sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(dbn_trainer, kfold_plugin, kfold_measure)

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

sched.end(on=hyperparam_change.end())

##### RUN THE WHOLE DAMN THING #####
sched.run()
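The event flow wired up in fixed_epoch_trainer can be simulated without any library. The following is a minimal, hypothetical, single-threaded stand-in, not pylearn's or Olivier Breuleux's scheduler API: Scheduler, Plugin, ToyModel and Recorder are illustrative names invented here, events are plain strings, and the "next" event is emitted as "step" to match the model.step() calls in the file. It runs two epochs of three minibatches and shows the counter's end event triggering the save plugin exactly once.

```python
from collections import defaultdict, deque


class Scheduler:
    """Hypothetical scheduler: a FIFO queue of (event, value) pairs."""

    def __init__(self):
        self.listeners = defaultdict(list)  # event name -> plugins acting on it
        self.queue = deque()                # events waiting to be dispatched

    def begin(self):
        return 'begin'

    def emit(self, event, value=None):
        self.queue.append((event, value))

    def run(self):
        # Emit our own begin event, then dispatch events in FIFO order until
        # no plugin emits anything new.
        self.emit(self.begin())
        while self.queue:
            event, value = self.queue.popleft()
            for plugin in self.listeners[event]:
                plugin.execute(value)


class Plugin:
    """A named component that performs one iteration per execute() call."""

    def __init__(self, name):
        self.name = name
        self.sched = None

    def step(self):
        return self.name + '.step'

    def end(self):
        return self.name + '.end'

    def act(self, sched, on):
        # plugin.act(sched, on=[e1, e2]): iterate once on every e1 or e2
        self.sched = sched
        for event in on:
            sched.listeners[event].append(self)

    def fire(self, kind, value=None):
        self.sched.emit(self.name + '.' + kind, value)


class Counter(Plugin):
    def __init__(self, name, end_count, next_count=1):
        super().__init__(name)
        self.n = 0
        self.end_count = end_count
        self.next_count = next_count

    def execute(self, value):
        self.n += 1
        if self.n >= self.end_count:
            self.fire('end', self.n)    # counter expired
        elif self.n % self.next_count == 0:
            self.fire('step', self.n)   # every next_count-th iteration


class ToyModel(Plugin):
    """Stand-in for a model: one execute() is one minibatch update."""

    def __init__(self, name, batches_per_epoch):
        super().__init__(name)
        self.batches_per_epoch = batches_per_epoch
        self.updates = 0

    def execute(self, value):
        self.updates += 1
        if self.updates % self.batches_per_epoch == 0:
            self.fire('end', self.updates)   # an epoch has finished
        else:
            self.fire('step', self.updates)  # auto-loop to the next minibatch


class Recorder(Plugin):
    """Stand-in for the save plugin: remembers the values it received."""

    def __init__(self, name):
        super().__init__(name)
        self.calls = []

    def execute(self, value):
        self.calls.append(value)


# Same wiring as fixed_epoch_trainer: 2 epochs of 3 minibatches each.
sched = Scheduler()
model = ToyModel('model', batches_per_epoch=3)
counter = Counter('epoch', end_count=2)
save = Recorder('save')

model.act(sched, on=[sched.begin(), model.step(), counter.step()])
counter.act(sched, on=[model.end()])
save.act(sched, on=[counter.end()])
sched.run()

print(model.updates, counter.n, save.calls)  # 6 2 [2]
```

Because dispatch is FIFO, the model's auto-loop and the counter's events interleave deterministically. A real implementation would also need nested schedulers so that a meta-plugin's internal counter.end() can be re-emitted as its own end(); this sketch leaves that out.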
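The Stopper used by early_stop_trainer is left unspecified (Stopper(**kw)). One common rule, assumed here rather than taken from the proposal, is patience-based early stopping: report "step" whenever a validation error beats the best seen so far (so the save plugin runs) and "end" after a fixed number of validations without improvement. For brevity this hypothetical Stopper returns the event name instead of firing it through a scheduler, and the patience parameter is an illustrative choice.

```python
class Stopper:
    """Hypothetical patience-based early stopper.

    execute() is called once per validation error. It returns 'step' when the
    error is a new best, 'end' once `patience` consecutive errors have failed
    to improve on the best one, and None otherwise.
    """

    def __init__(self, patience):
        self.patience = patience
        self.best = float('inf')
        self.bad_checks = 0  # validations since the last improvement

    def execute(self, error):
        if error < self.best:
            self.best = error
            self.bad_checks = 0
            return 'step'    # new best model: the save plugin should act
        self.bad_checks += 1
        if self.bad_checks >= self.patience:
            return 'end'     # stop training; the best error is in self.best
        return None


stopper = Stopper(patience=2)
events = []
for error in [0.30, 0.25, 0.27, 0.24, 0.26, 0.29]:
    event = stopper.execute(error)
    events.append(event)
    if event == 'end':
        break

print(events, stopper.best)  # ['step', 'step', None, 'step', None, 'end'] 0.24
```

After "end", stopper.best holds the error that, per the comments in single_crossval_run, trainer.end() would forward to kfold_measure.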
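In single_crossval_run, the KFold plugin and kfold_measure are also left open. Outside of any scheduler, what they compute together can be sketched in plain Python; kfold_splits and crossval_error are hypothetical helpers, and contiguous (unshuffled) folds with a plain mean over folds are assumptions, not part of the proposal.

```python
def kfold_splits(n_examples, k):
    """Yield (train_indices, valid_indices) for k contiguous folds.

    The first n_examples % k folds get one extra example, so every index
    appears in exactly one validation fold.
    """
    base, extra = divmod(n_examples, k)
    start = 0
    for fold in range(k):
        size = base + (1 if fold < extra else 0)
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_examples))
        yield train, valid
        start += size


def crossval_error(n_examples, k, train_and_validate):
    """Mean validation error over the k folds (the role of kfold_measure).

    train_and_validate(train, valid) stands in for one full trainer run and
    must return that run's best validation error.
    """
    errors = [train_and_validate(train, valid)
              for train, valid in kfold_splits(n_examples, k)]
    return sum(errors) / len(errors)


# 10 examples split into 3 folds of sizes 4, 3 and 3.
for train, valid in kfold_splits(10, 3):
    print(valid)
```

Each yielded pair plays the part of one kfold_plugin.step(), and the value returned by train_and_validate plays the part of the best validation error forwarded by trainer.end().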