1258
|
1 """
|
|
2 BASIC IDEA
|
|
3 ==========
|
|
4
|
|
5 The main ideas of the proposal are as follows:
|
|
6 1. Basic components are defined in the traditional way (classes, methods, etc.)
|
|
7 2. Components are then "wrapped" into **plugins**
|
|
8 3. Plugins are registered through a **scheduler** to act on certain **events**
|
|
9 4. We have defined two basic types of events:
|
|
10 4.1 scheduler events: these are emitted by the scheduler itself. For now we have defined:
|
|
11 scheduler.begin() : sent at the very beginning of scheduler execution
|
1259
|
12 scheduler.end() : sent when the scheduler ends (this makes sense since, at least
|

13 from the user's perspective, there is a hierarchy of schedulers)
|
1258
|
14 4.2 plugin-generated events: plugins act as iterators and generate 2 basic events:
|

15 plugin.next() : sent on every "iteration". Each plugin is free to define what an
|

16 iteration actually means.
|
|
17 plugin.end() : sent when the plugin is done iterating.
|
|
18 5. Using Olivier Breuleux's scheduler, plugins can decide to act on every event, every n-th
|

19 occurrence of an event, etc.
|
|
20
|
|
21 OVERVIEW
|
|
22 ========
|
1256
|
23
|
1258
|
24 The code below attempts to show how one would code a training experiment, where we pre-train +
|

25 fine-tune a DBN, while using cross-validation. We have omitted data preprocessing for now, but
|

26 we do not think it would have much of an impact on the finished product.
|
|
27 """
|
|
28
|
1256
|
class Counter(Plugin):
    """
    This is an example of a plugin which takes care of counting "stuff". It is generic enough
    that it would probably be included in the library somewhere.
    """

    def __init__(self, name, end_count, next_count=1):
        """
        :param name: name of the event we are counting (could be useful for debugging)
        :param end_count: number of iterations after which an "end" event is fired
        :param next_count: a "next" event is fired on every `next_count`-th iteration
        """
        # super() takes the class first, then the instance
        super(Counter, self).__init__()
        self.name = name
        self.n = 0
        self.next_count = next_count
        self.end_count = end_count

    def execute(self, msg):
        """
        The execute function is the one which gets executed each time the plugin is "acted"
        upon. This will send "next" and "end" events to the scheduler, which other plugins can
        listen to. We show an example of this later.

        Once `end_count` iterations have been counted an "end" event is fired; before that, a
        "next" event is fired on every `next_count`-th iteration. Both carry the current count.
        """
        self.n += 1
        # `>=` so that "end" fires exactly after end_count iterations, not one later
        if self.n >= self.end_count:
            self.fire(Event('end', value=self.n))
        # every next_count-th iteration (with the default of 1: every iteration, including
        # the first one -- otherwise anything looping on counter.next() would stall)
        elif self.n % self.next_count == 0:
            self.fire(Event('next', value=self.n))
|
|
57
|
1256
|
58
|
|
def fixed_epoch_trainer(model, save, n_epochs):
    """
    This is an example of a meta-plugin. Meta-plugins are just like any other plugins, except
    they themselves contain other plugins and their own schedulers. A meta-plugin would replace
    code-blocks which are usually found inside for-loops or if-else statements.

    The fixed_epoch_trainer meta-plugin takes care of training a given model for a fixed
    number of epochs and then saving the resulting model.

    Other arguments for having meta-plugins:
    * they define a semantically separable block of code
    * they are what the library provides for 99% of the code (so you can define a certain
      template of connected plugins as a meta-plugin and ship it without worrying about
      the details)
    * they can be broken apart by the main scheduler (so you would not have "different"
      schedulers running at the same time; there is just one scheduler, and this way of
      constructing things is just to improve understanding and intuition of the system)
    * they help push all the complexity into the backbone of the library (i.e. the
      scheduler)
    * all plugins registered inside a meta-plugin are active only when the meta-plugin is
      active; in this sense they can help define scopes (as for variables) - optional

    :param model: component to train; one of its iterations is one minibatch update
    :param save: component which saves the trained model
    :param n_epochs: number of epochs to train for before saving
    :return: the meta-plugin's scheduler
    """
    # we start by defining our own private scheduler
    sched = Scheduler()

    # convert "components" to plugins. Maybe this could be automated or done in another way...
    # syntax is not really important here.
    [model, save] = map(pluggin_wrapper, [model, save])

    # instantiate the counter plugin to "end" after n_epochs
    # (next_count keeps its default of 1: a "next" event after every epoch)
    counter = Counter('epoch', end_count=n_epochs)

    ####
    # Registering actions: overview of syntax
    # plugin1.act(sched, on=[event1, event2]) means that plugin1 will perform one
    # "iteration" on every occurrence of event1 and event2.
    # Plugin events follow the module docstring: plugin.next() and plugin.end().
    ####

    # In our example, we assume that 1 iteration of "model" means 1 minibatch update.
    # Model performs an iteration:
    # * when the program first starts
    # * after it is done with the previous iteration (auto-loop)
    # * after each epoch, as defined by the counter. The counter counts epochs by "trapping"
    #   the model's end event.
    model.act(sched, on=[sched.begin(), model.next(), counter.next()])
    # counter is incremented every time the model finishes an epoch (model.end)
    counter.act(sched, on=model.end())
    # the save plugin then takes care of saving everything once the counter expires
    save.act(sched, on=counter.end())

    # Meta-plugins also generate events: they map "internal" events to "external" events
    # fixed_epoch_trainer.end() will thus correspond to counter.end()
    sched.end(on=counter.end())
    # in the same way, you could have:
    # sched.next(on=counter.next()) but this is not useful here

    return sched
|
|
115
|
|
def early_stop_trainer(model, validate, save, **kw):
    """
    This is another meta-plugin example, which takes care of training the model using early
    stopping.

    :param model: component to train; one of its iterations is one minibatch update
    :param validate: component which measures the validation error
    :param save: component which saves the model
    :param kw: keyword arguments forwarded unchanged to the Stopper plugin
    :return: the meta-plugin's scheduler
    """
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(pluggin_wrapper, [model, validate, save])

    early_stop = Stopper(**kw)

    # register actions (events follow the module docstring: plugin.next() / plugin.end())
    # the model iterates at start-up, after each of its own iterations, and after each validation
    model.act(sched, on=[sched.begin(), model.next(), validate.next()])
    # we measure validation error after every epoch (model.end)
    validate.act(sched, on=model.end())
    # early-stopper gets triggered every time we have a new validation error;
    # the error itself is passed within the "next" message
    early_stop.act(sched, on=validate.next())
    # model is saved whenever we find a new best model (early_stop.next) or when we have found
    # THE best model (early_stop.end)
    save.act(sched, on=[early_stop.next(), early_stop.end()])

    # early_stop_trainer.end() corresponds to early_stop.end()
    sched.end(on=early_stop.end())

    return sched
|
|
142
|
|
def dbn_trainer(rbm1, rbm2, save=None, validate_mlp=None, save_mlp=None, n_epochs=None):
    """
    This meta-plugin pre-trains a two-layer DBN for a fixed number of epochs, and then performs
    fine-tuning on the resulting MLP. This should hopefully be self-explanatory.

    :param rbm1: first-layer RBM component
    :param rbm2: second-layer RBM component
    :param save: component which saves each pre-trained RBM
    :param validate_mlp: component which measures the MLP's validation error
    :param save_mlp: component which saves the fine-tuned MLP
    :param n_epochs: number of pre-training epochs for each layer
    :return: the meta-plugin's scheduler

    The extra components are optional keyword arguments only so that the two-argument call
    keeps its shape; they are placeholders in this sketch and must be supplied for a real run.
    """
    sched = Scheduler()

    # layer 1 is pre-trained as soon as this meta-plugin starts...
    pretrain_layer1 = fixed_epoch_trainer(rbm1, save, n_epochs)
    pretrain_layer1.act(sched, on=sched.begin())

    # ...and layer 2 once layer 1 is done
    pretrain_layer2 = fixed_epoch_trainer(rbm2, save, n_epochs)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee -- build the MLP out of the pre-trained RBMs
    mlp = function(rbm1, rbm2)

    # fine-tuning (with early stopping) starts once pre-training is over
    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    # map the internal end event to dbn_trainer.end(), so that callers listening on
    # trainer.end() receive the early-stopping message
    sched.end(on=fine_tuning.end())

    return sched
|
|
163
|
|
def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    """
    For a fixed set of hyper-parameters, this evaluates the generalization error using KFold
    cross-validation.

    :param trainer: meta-plugin which trains the model on the current split of the data
    :param kfold_plugin: plugin which iterates over the splits of the data
    :param kfold_measure: plugin which receives the error produced for each split
    :return: the meta-plugin's scheduler
    """
    sched = Scheduler()

    # k-fold plugin will call rbm.change_dataset using various splits of the data; it moves to
    # the next split at start-up and every time the trainer is done with the current one
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=kfold_plugin.next())

    # trainer ends on early_stop.end(). This means that trainer.end() will forward
    # the early-stopping message which contains the best validation error.
    # `kill` is a keyword of act(), not an element of the event list; presumably it
    # deactivates kfold_measure once kfold_plugin has gone through every split.
    kfold_measure.act(sched, on=trainer.end(), kill=kfold_plugin.end())

    # this best validation error is then forwarded by single_crossval_run
    sched.end(on=kfold_measure.end())

    return sched
|
|
184
|
|
185
|
|
#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
dataset = ...
# bind the result to a new name instead of shadowing the dbn_trainer() factory
trainer = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

## In our view, the meta-plugins defined above would live in the library somewhere. Hooks can be
## added without modifying the library code. The meta-plugin's scheduler contains a dictionary
## of "registered" plugins along with their events. We can thus register "user-plugins" based on
## any of these events.
# define a logger plugin of some sort
print_stat = ...
# act on each iteration of the early-stopping plugin, which lives in the fine-tuning stage of
# the DBN trainer (the pre-training stages have no early stopping).
# NB: the message is forwarded as is. It is up to the print_stat plugin to parse it properly.
# NOTE(review): assumes registered plugins are keyed by the local names used inside each
# meta-plugin ('fine_tuning', 'early_stop') -- confirm once the registry is specified.
fine_tuning = trainer.plugins['fine_tuning']
print_stat.act(fine_tuning, on=fine_tuning.plugins['early_stop'].next())

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
# this is the final outer-loop which tests various configurations of hyperparameters

sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(trainer, kfold_plugin, kfold_measure)

# try a new configuration at start-up and after each cross-validation run
hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.next())

sched.end(on=hyperparam_change.end())

##### RUN THE WHOLE DAMN THING #####
sched.run()
|