1258
|
1 """
|
|
2 BASIC IDEA
|
|
3 ==========
|
|
4
|
|
5 The main ideas of the proposal are as follows:
|
|
6 1. Basic components are defined in the traditional way (classes, methods, etc.)
|
|
7 2. Components are then "wrapped" into **plugins**
|
|
8 3. Plugins are registered through a **scheduler** to act on certain **events**
|
|
9 4. We have defined two basic types of events:
|
|
10 4.1 scheduler events: these are emitted by the scheduler itself. For now we have defined:
|
|
11 scheduler.begin() : sent at the very beginning of scheduler execution
|
|
12 scheduler.end() : sent when the scheduler ends
|
|
13 4.2 plugin generate events: plugins act as iterators and generate 2 basic events,
|
|
14 plugin.next() : sent on every "iteration". Each plugin is free to define what an
|
|
15 iteration actually means.
|
|
16 plugin.end() : sent when the plugin is done iterating.
|
|
17 5. Using Olivier Breuleux's scheduler, plugins can decide to act on every event, every n-th
|
|
18 occurrence of an event, etc.
|
|
19
|
|
20 OVERVIEW
|
|
21 ========
|
1256
|
22
|
1258
|
23 The code below attempts to show how one would code a training experiment, where we pre-train +
|
|
24 fine-tune a DBN, while using cross-validation. We have omitted data preprocessing for now, but
|
|
25 do not think it would have much of an impact on the finished product.
|
|
26 """
|
|
27
|
1256
|
class Counter(Plugin):
    """
    Example of a plugin which takes care of counting "stuff".

    It is generic enough that it would probably be included in the
    library somewhere.  On every execution it increments an internal
    counter and re-emits "next"/"end" events once the configured
    thresholds are crossed.
    """

    def __init__(self, name, next_count, end_count):
        """
        :param name: name of the event we are counting (could be useful
            for debugging)
        :param next_count: number of iterations before triggering a
            "next" event
        :param end_count: number of iterations before triggering an
            "end" event
        """
        # BUG FIX: super() takes (cls, instance) in that order; the
        # original called super(self, Counter) which raises TypeError.
        super(Counter, self).__init__()
        # BUG FIX: `name` was accepted but never stored on the instance.
        self.name = name
        self.n = 0
        self.next_count = next_count
        self.end_count = end_count

    def execute(self, msg):
        """
        Executed each time the plugin is "acted" upon.

        Sends "next" and "end" events to the scheduler, which other
        plugins can listen to.  "end" takes precedence over "next" once
        both thresholds are exceeded.
        """
        self.n += 1
        if self.n > self.end_count:
            self.fire(Event('end', value=self.n))
        elif self.n > self.next_count:
            self.fire(Event('next', value=self.n))
|
1256
|
57
|
|
def fixed_epoch_trainer(model, save, n_epochs):
    """
    Example of a meta-plugin: a plugin which contains other plugins and
    its own scheduler.  A meta-plugin replaces code-blocks which are
    usually found inside for-loops or if-else statements.

    This one trains ``model`` for a fixed number of epochs and then
    saves the resulting model.

    :param model: the component to train (wrapped into a plugin below)
    :param save: the component which saves the model (wrapped below)
    :param n_epochs: number of epochs to train for
    :returns: the meta-plugin's private scheduler
    """
    # we start by defining our own private scheduler
    sched = Scheduler()

    # convert "components" to plugins. Maybe this could be automated or
    # done in another way... syntax is not really important here.
    # BUG FIX: the original also wrapped `validate`, which is not a
    # parameter of this function (copy-paste from early_stop_trainer).
    [model, save] = map(pluggin_wrapper, [model, save])

    # instantiate the counter plugin to "end" after n_epochs
    counter = Counter('epoch', 1, n_epochs)

    ####
    # Registering actions: overview of syntax
    # plugin1.act(sched, on=[event1, event2]) means that plugin1 will
    # perform one "iteration" on every occurrence of event1 and event2.
    ####

    # In our example, we assume that 1 iteration of "model" means 1
    # minibatch update.  Model performs an iteration:
    # * when the program first starts
    # * after it is done with the previous iteration (auto-loop)
    # * after each epoch, as defined by the counter. The counter counts
    #   epochs by "trapping" the model's end event.
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
    # counter is incremented every time the model finishes an epoch
    counter.act(sched, on=model.end())
    # the save plugin then takes care of saving everything once the
    # counter expires
    # BUG FIX: the original referenced an undefined `save_model`; the
    # wrapped plugin is named `save`.
    save.act(sched, on=counter.end())

    # Meta-plugins also generate events: they map "internal" events to
    # "external" events. fixed_epoch_trainer.end() will thus correspond
    # to counter.end()
    sched.end(on=counter.end())
    # in the same way, you could have:
    # sched.next(on=counter.next()) but this is not useful here

    return sched
|
|
102
|
|
def early_stop_trainer(model, validate, save, **kw):
    """
    Another meta-plugin example: trains ``model`` using early stopping.

    :param model: the component to train (wrapped into a plugin below)
    :param validate: the component which measures validation error
    :param save: the component which saves the model
    :param kw: extra keyword arguments forwarded to the Stopper plugin
    :returns: the meta-plugin's private scheduler
    """
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(pluggin_wrapper, [model, validate, save])

    early_stop = Stopper(**kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
    # we measure validation error after every epoch (model.end)
    validate.act(sched, on=model.end())
    # early-stopper gets triggered every time we have a new validation
    # error; the error itself is passed within the "step" message
    early_stop.act(sched, on=validate.step())
    # model is saved whenever we find a new best model (early_stop.step)
    # or when we have found THE best model (early_stop.end)
    # BUG FIX: the original referenced an undefined `save_model`; the
    # wrapped plugin is named `save`.
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.end(on=early_stop.end())

    return sched
|
|
129
|
|
def dbn_trainer(rbm1, rbm2, n_pretrain_epochs=10):
    """
    Pre-trains a two-layer DBN for a fixed number of epochs, then
    performs fine-tuning on the resulting MLP.

    :param rbm1: first-layer RBM component
    :param rbm2: second-layer RBM component
    :param n_pretrain_epochs: number of pre-training epochs per layer
        (BUG FIX: ``fixed_epoch_trainer`` requires ``n_epochs``, which
        the original calls omitted; added here with a default to stay
        backward-compatible)
    :returns: the meta-plugin's private scheduler
    """
    # NOTE(review): `save`, `function`, `validate_mlp` and `save_mlp`
    # are assumed to be defined elsewhere in the proposal — confirm.
    sched = Scheduler()

    pretrain_layer1 = fixed_epoch_trainer(rbm1, save, n_pretrain_epochs)
    pretrain_layer1.act(sched, on=sched.begin())

    pretrain_layer2 = fixed_epoch_trainer(rbm2, save, n_pretrain_epochs)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    return sched
|
|
150
|
|
def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    """
    For a fixed set of hyper-parameters, evaluates the generalization
    error using KFold cross-validation.

    :param trainer: meta-plugin which trains the model on one split
    :param kfold_plugin: plugin which switches the dataset split
    :param kfold_measure: plugin which aggregates per-fold results
    :returns: the meta-plugin's private scheduler

    BUG FIX: the original ``def`` line was missing its trailing colon.
    """
    sched = Scheduler()

    # k-fold plugin will call rbm.change_dataset using various splits
    # of the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer ends on early_stop.end(). This means that trainer.end()
    # will forward the early-stopping message which contains the best
    # validation error.
    # BUG FIX: `kill=` was inside the list literal and the call was
    # missing its closing parenthesis — both syntax errors.
    kfold_measure.act(sched, on=trainer.end(), kill=kfold_plugin.end())

    # this best validation error is then forwarded by
    # single_crossval_run
    sched.end(on=kfold_measure.end())

    return sched
|
|
171
|
|
172
|
|
#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
# BUG FIX: `....` (four dots) is a syntax error; Ellipsis is `...`.
dataset = ...
# BUG FIX: renamed from `dbn_trainer` — the original rebinding shadowed
# the factory function of the same name.
trainer = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

## In our view, the meta-plugins defined above would live in the library
## somewhere. Hooks can be added without modifying the library code. The
## meta-plugin's scheduler contains a dictionary of "registered" plugins
## along with their events. We can thus register "user-plugins" based on
## any of these events.
# define a logger plugin of some sort
print_stat = ...
# act on each iteration of the early-stopping plugin
# NB: the message is forwarded as is. It is up to the print_stat plugin
# to parse it properly.
# NOTE(review): `pretrain_layer1` is local to dbn_trainer and not
# visible here — presumably the proposal intends the sub-scheduler to be
# reachable through the returned scheduler; confirm the lookup path.
print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
# this is the final outer-loop which tests various configurations of
# hyperparameters

sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(trainer, kfold_plugin, kfold_measure)

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

# BUG FIX: every other sched.end() call in this file passes the event
# via the `on=` keyword; the original omitted it here.
sched.end(on=hyperparam_change.end())

##### RUN THE WHOLE DAMN THING #####
sched.run()
|