view doc/v2_planning/plugin_RP_GD.py @ 1363:18b2ebec6bca

Reply to a comment of OD
author Razvan Pascanu <r.pascanu@gmail.com>
date Fri, 12 Nov 2010 11:11:49 -0500
parents 6f76ecef869e

"""
BASIC IDEA
==========

The main ideas of the proposal are as follows:
1. Basic components are defined in the traditional way (classes, methods, etc.)
2. Components are then "wrapped" into **plugins**
3. Plugins are registered through a **scheduler** to act on certain **events**
4. We have defined two basic types of events:
    4.1 scheduler events: these are emitted by the scheduler itself. For now we have defined:
        scheduler.begin() : sent at the very beginning of scheduler execution
        scheduler.end()   : sent when the scheduler ends (this makes sense since there is,
                            at least from the user's perspective, a hierarchy of schedulers)
    4.2 plugin-generated events: plugins act as iterators and generate 2 basic events,
        plugin.next() : sent on every "iteration". Each plugin is free to define what an
                        iteration actually means.
        plugin.end()  : sent when the plugin is done iterating.
5. Using Olivier Breuleux's scheduler, plugins can decide to act on every event, every n-th
   occurrence of an event, etc.

OVERVIEW
========

The code below attempts to show how one would code a training experiment where we pre-train and
fine-tune a DBN while using cross-validation. We have omitted data preprocessing for now, but we
do not think it would have much of an impact on the finished product.
"""

class Counter(Plugin):
    """
    This is an example of a plugin which takes care of counting "stuff". It is generic enough
    that it would probably be included in the library somewhere.
    """

    def __init__(self, name, end_count, next_count = 1):
        """
        :param name: name of the event we are counting (could be useful for debugging)
        :param end_count: number of iterations before triggering an "end" event
        :param next_count: a "next" event is triggered every next_count iterations
        """
        super(Counter, self).__init__()
        self.name = name
        self.n = 0
        self.next_count = next_count
        self.end_count = end_count

    def execute(self, msg):
        """
        The execute function is the one which gets executed each time the plugin is "acted"
        upon. This will send "next" and "end" events to the scheduler, which other plugins can
        listen to. We show an example of this later.
        """
        self.n += 1
        if self.n >= self.end_count:
            # the counter has expired: signal the end, carrying the final count
            self.fire(Event('end', value=self.n))
        elif self.n % self.next_count == 0:
            # every next_count-th iteration (every iteration by default)
            self.fire(Event('next', value=self.n))
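
# Sketch (hypothetical, stand-alone): one possible reading of the Counter semantics, with
# "next" fired every next_count iterations and "end" fired once end_count iterations have
# been counted. ToyCounter and demo_toy_counter are illustrative names; ToyCounter records
# fired events in a list instead of sending them to a scheduler, so it runs without the
# (unspecified) Plugin and Event classes.
class ToyCounter(object):
    def __init__(self, name, end_count, next_count=1):
        self.name = name
        self.end_count = end_count
        self.next_count = next_count
        self.n = 0
        self.fired = []  # (kind, value) pairs, in the order they were fired

    def execute(self, msg=None):
        self.n += 1
        if self.n >= self.end_count:
            self.fired.append(('end', self.n))
        elif self.n % self.next_count == 0:
            self.fired.append(('next', self.n))


def demo_toy_counter(end_count=5, next_count=2):
    counter = ToyCounter('epoch', end_count, next_count)
    for _ in range(end_count):
        counter.execute()
    return counter.fired

# With end_count=5 and next_count=2, "next" fires on iterations 2 and 4 and "end" on 5.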


def fixed_epoch_trainer(model, save, n_epochs):
    """
    This is an example of a meta-plugin. Meta-plugins are just like any other plugins, except
    that they themselves contain other plugins and their own schedulers. A meta-plugin would
    replace code blocks which are usually found inside for-loops or if-else statements.

    The fixed_epoch_trainer meta-plugin takes care of training a given model for a fixed
    number of epochs and then saving the resulting model.

    Other arguments for having meta-plugins:
      * they can define a semantically separable block of code
      * they would make up 99% of the code the library provides (so you can define a certain
        template of connecting plugins as a meta-plugin and ship it without worrying about
        the details)
      * they can be broken apart by the main scheduler (so you would not have "different"
        schedulers running at the same time; there is just one scheduler, and this way of
        constructing things is only meant to improve understanding and intuitions about the
        system)
      * they help push all the complexity onto the backbone of the library (i.e. the
        scheduler)
      * all plugins registered inside a meta-plugin are active only when the meta-plugin is
        active; in this sense they can help define scopes (as for variables) - optional
    """
    # we start by defining our own private scheduler
    sched = Scheduler()

    # convert "components" to plugins. Maybe this could be automated or done in another way...
    # syntax is not really important here.
    [model, save] = map(plugin_wrapper, [model, save])

    # instantiate the counter plugin to "end" after n_epochs
    counter = Counter('epoch', end_count=n_epochs)

    ####
    # Registering actions: overview of syntax
    # plugin1.act(sched, on=[event1, event2]) means that plugin1 will perform one
    # "iteration" on every occurrence of event1 and event2.
    ####

    # In our example, we assume that 1 iteration of "model" means 1 minibatch update.
    # Model performs an iteration:
    # * when program first starts
    # * after it is done with the previous iteration (auto-loop)
    # * after each epoch, as defined by the counter. The counter counts epochs by "trapping"
    #   the model's end event.
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
    # the counter is incremented every time the model completes an epoch (model.end)
    counter.act(sched, on=model.end())
    # the save plugin then takes care of saving everything once the counter expires
    save.act(sched, on=counter.end())

    # Meta-plugins also generate events: they map "internal" events to "external" events
    # fixed_epoch_trainer.end() will thus correspond to counter.end() 
    sched.end(on=counter.end())
    # in the same way, you could have:
    # sched.next(on=counter.next()) but this is not useful here

    return sched
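
# Sketch (hypothetical): the event wiring of fixed_epoch_trainer unrolled into the plain
# nested loop that a meta-plugin is meant to replace. fixed_epoch_loop, model_update,
# n_minibatches and save are stand-in names, not part of the proposed API.
def fixed_epoch_loop(model_update, n_minibatches, save, n_epochs):
    for epoch in range(n_epochs):            # the counter traps model.end() once per epoch
        for batch in range(n_minibatches):   # the model auto-loops on its own step event
            model_update(batch)              # one model "iteration" = 1 minibatch update
    save()                                   # the save plugin acts on counter.end()


def demo_fixed_epoch_loop():
    calls = []
    fixed_epoch_loop(calls.append, 4, lambda: calls.append('saved'), 3)
    return calls

# demo_fixed_epoch_loop() records 3 epochs x 4 minibatch updates, followed by one save.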

def early_stop_trainer(model, validate, save, **kw):
    """
    This is another meta-plugin example, which takes care of training the model using early
    stopping.
    """
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(plugin_wrapper, [model, validate, save])

    early_stop = Stopper(**kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
    # we measure validation error after every epoch (model.end)
    validate.act(sched, on=model.end())
    # early-stopper gets triggered every time we have a new validation error
    # the error itself is passed within the "step" message
    early_stop.act(sched, on=validate.step())
    # model is saved whenever we find a new best model (early_stop.step) or when we have found
    # THE best model (early_stop.end)
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.end(on=early_stop.end())

    return sched
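
# Sketch (hypothetical): the proposal leaves Stopper(**kw) unspecified; a common choice is a
# patience rule, shown here as an assumption. PatienceStopper and its `patience` argument
# are illustrative names. execute() receives a validation error and returns 'step' for a new
# best error (the cue to save the model), 'end' once `patience` consecutive epochs pass
# without improvement, and None otherwise.
class PatienceStopper(object):
    def __init__(self, patience):
        self.patience = patience    # assumed keyword argument, not from the proposal
        self.best = float('inf')
        self.bad_epochs = 0

    def execute(self, valid_error):
        if valid_error < self.best:
            self.best = valid_error
            self.bad_epochs = 0
            return 'step'
        self.bad_epochs += 1
        if self.bad_epochs >= self.patience:
            return 'end'
        return None


def demo_patience_stopper(errors, patience=3):
    stopper = PatienceStopper(patience)
    signals = []
    for err in errors:
        signals.append(stopper.execute(err))
        if signals[-1] == 'end':
            break
    return signals

# demo_patience_stopper([0.5, 0.4, 0.45, 0.42, 0.41, 0.3]) stops after three epochs without
# improvement, so the final 0.3 is never seen.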

def dbn_trainer(rbm1, rbm2):
    """
    This meta-plugin pre-trains a two-layer DBN for a fixed number of epochs, and then performs
    fine-tuning on the resulting MLP. This should hopefully be self-explanatory.
    """
    sched = Scheduler()

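    # NB: save and n_epochs (like validate_mlp and save_mlp below) are assumed to be
    # defined elsewhere
    pretrain_layer1 = fixed_epoch_trainer(rbm1, save, n_epochs)
    pretrain_layer1.act(sched, on=sched.begin())

    pretrain_layer2 = fixed_epoch_trainer(rbm2, save, n_epochs)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())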

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
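    fine_tuning.act(sched, on=pretrain_layer2.end())

    # map fine_tuning.end() (i.e. early_stop.end()) onto dbn_trainer.end()
    sched.end(on=fine_tuning.end())
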
    return sched

def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    """
    For a fixed set of hyper-parameters, this evaluates the generalization error using KFold
    cross-validation.
    """

    sched = Scheduler()

    # k-fold plugin will call rbm.change_dataset using various splits of the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer ends on early_stop.end(). This means that trainer.end() will forward
    # the early-stopping message which contains the best validation error.
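    # kfold_measure collects the error of each fold; kill= is meant to stop it (and thus
    # fire kfold_measure.end()) once kfold_plugin has gone through all the folds
    kfold_measure.act(sched, on=trainer.end(), kill=kfold_plugin.end())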

    # this best validation error is then forwarded by single_crossval_run
    sched.end(on=kfold_measure.end())
    
    return sched
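
# Sketch (hypothetical): what kfold_plugin and kfold_measure compute, written as plain
# functions. kfold_indices and crossval_error are illustrative names; contiguous folds and
# averaging the per-fold best validation error are assumptions, since the KFold plugin's
# splitting strategy is not specified here.
def kfold_indices(n_examples, k):
    # the first n_examples % k folds get one extra example
    fold_sizes = [n_examples // k + (1 if i < n_examples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        valid = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_examples))
        yield train, valid
        start += size


def crossval_error(train_and_validate, n_examples, k):
    # train_and_validate(train_idx, valid_idx) returns the best validation error of one fold
    errors = [train_and_validate(train, valid) for train, valid in kfold_indices(n_examples, k)]
    return sum(errors) / float(len(errors))

# For 5 examples and k=2, the folds validate on [0, 1, 2] and then on [3, 4].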


#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
dataset = ...
dbn_trainer = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

## In our view, the meta-plugins defined above would live in the library somewhere. Hooks can be
## added without modifying the library code. The meta-plugin's scheduler contains a dictionary
## of "registered" plugins along with their events. We can thus register "user-plugins" based on
## any of these events.
# define a logger plugin of some sort
print_stat = ...
# act on each iteration of the early-stopping plugin, which lives inside the fine-tuning
# meta-plugin of the DBN trainer
# NB: the message is forwarded as is. It is up to the print_stat plugin to parse it properly.
fine_tuning = dbn_trainer.plugins['fine_tuning']
print_stat.act(fine_tuning, on=fine_tuning.plugins['early_stop'].step())
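
# Sketch (hypothetical): how a name -> plugin dictionary lets user code hook into an inner
# plugin's events without editing the library function that built the meta-plugin.
# ToyMetaPlugin, ToyEventSource, register, on_step and demo_hook are illustrative names,
# not the proposed API.
class ToyMetaPlugin(object):
    def __init__(self):
        self.plugins = {}  # registered plugins, looked up by name

    def register(self, name, plugin):
        self.plugins[name] = plugin
        return plugin


class ToyEventSource(object):
    def __init__(self):
        self.step_hooks = []

    def on_step(self, hook):
        self.step_hooks.append(hook)

    def step(self, msg):
        for hook in self.step_hooks:
            hook(msg)  # the message is forwarded as is; each hook parses it


def demo_hook():
    trainer = ToyMetaPlugin()                          # built by "library" code
    trainer.register('early_stop', ToyEventSource())
    seen = []                                          # user hook, added afterwards
    trainer.plugins['early_stop'].on_step(seen.append)
    trainer.plugins['early_stop'].step({'valid_error': 0.3})
    return seen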

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
# this is the final outer loop, which tests various configurations of hyperparameters

sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(dbn_trainer, kfold_plugin, kfold_measure) 

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

sched.end(on=hyperparam_change.end())

##### RUN THE WHOLE DAMN THING #####
sched.run()
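
# Sketch (hypothetical): the outer scheduler above written as a plain loop. DBN_HyperParam is
# unspecified in this proposal; exhaustive grid search over a dict of candidate values is an
# assumption, and grid_search/evaluate are illustrative names. evaluate(config) plays the
# role of single_crossval_run and returns a cross-validation error.
import itertools


def grid_search(grid, evaluate):
    best_config, best_error = None, float('inf')
    names = sorted(grid)
    for values in itertools.product(*[grid[name] for name in names]):
        config = dict(zip(names, values))   # one hyperparam_change step
        error = evaluate(config)            # one run of hyperparam_test
        if error < best_error:
            best_config, best_error = config, error
    return best_config, best_error          # once hyperparam_change runs out of configs

# grid_search({'a': [3, 1], 'b': [2, 5]}, lambda c: c['a'] + c['b']) picks a=1, b=2.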