####
# H1: everything works in terms of iterators
# everything has a step() and end() method
####

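####
# The Plugin / Scheduler / Event machinery is used below but never
# defined. What follows is one minimal pub/sub reading of it, just
# detailed enough to make the rest of the sketch hang together: every
# name, signature and behaviour in this block is an assumption, not
# the actual library.
####

class Event(object):
    """A named message, optionally carrying a value."""
    def __init__(self, name, value=None):
        self.name = name
        self.value = value


class Scheduler(object):
    def __init__(self):
        self._queue = []      # pending (handle, Event) pairs
        self._actions = {}    # event handle -> plugins to execute
        self._stop_on = None

    def begin(self):
        return (self, 'begin')    # handle: "scheduler started"

    def register(self, handle, plugin):
        self._actions.setdefault(handle, []).append(plugin)

    def post(self, handle, event):
        self._queue.append((handle, event))

    def terminate(self, on):
        self._stop_on = on

    def run(self):
        # drive the event loop until the terminating handle fires
        self.post(self.begin(), Event('begin'))
        while self._queue:
            handle, event = self._queue.pop(0)
            for plugin in self._actions.get(handle, []):
                plugin.execute(event)
            if handle == self._stop_on:
                return


class Plugin(object):
    def __init__(self, sch, name):
        self.sch = sch
        self.name = name

    def step(self):
        return (self, 'step')    # handle: one unit of work done

    def end(self):
        return (self, 'end')     # handle: underlying iterator exhausted

    def act(self, sch, on, kill=None):
        # run self.execute whenever any handle in `on` fires; `kill`
        # (assumed) would detach the plugin when that handle fires
        for handle in (on if isinstance(on, list) else [on]):
            sch.register(handle, self)

    def fire(self, event):
        # publish; a 'terminate' event is read here as "this plugin is
        # done", i.e. it fires the plugin's end() handle (an assumption)
        handle = self.end() if event.name == 'terminate' else (self, event.name)
        self.sch.post(handle, event)

    def execute(self, msg):
        raise NotImplementedError


def plugin_wrapper(obj):
    # assumed adapter: wraps a bare model/validation/save object into a
    # Plugin whose execute() advances the underlying iterator one step,
    # firing step() per step and end() on exhaustion; details omitted
    ...
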
# Construct counter plugin that keeps track of number of epochs
class Counter(Plugin):

    def __init__(self, sch, name, threshold):
        super(Counter, self).__init__(sch, name)
        self.n = 0
        self.threshold = threshold

    def execute(self, msg):
        self.n += 1
        if self.n > self.threshold:
            self.fire(Event('terminate', value=self.n))

def fixed_epoch_trainer(model, save, n_epochs):
    sched = Scheduler()

    # define plugins
    [model, save] = map(plugin_wrapper, [model, save])

    counter = Counter(sched, 'epoch', n_epochs)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
    counter.act(sched, on=model.end())
    save.act(sched, on=counter.end())

    sched.terminate(on=counter.end())

    return sched
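
# NB: fixed_epoch_trainer() returns its Scheduler, and below the returned
# schedulers are themselves used as plugins (act(), step() and end() are
# called on them). The assumption is that a Scheduler also exposes the
# Plugin interface, so whole trainers nest inside outer schedulers.
# A hypothetical standalone use (save_rbm is an assumed callback):
#
#   fixed_epoch_trainer(rbm1, save_rbm, n_epochs=10).run()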

def early_stop_trainer(model, validate, save, **kw):
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(plugin_wrapper, [model, validate, save])

    early_stop = Stopper(sched, **kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
    validate.act(sched, on=model.end())
    early_stop.act(sched, on=validate.step())
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.terminate(on=early_stop.end())

    return sched
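
# Stopper is used above but never defined. A sketch consistent with
# Counter (names, signature and message format are all assumptions):
# fire step() while validation error improves, and 'terminate' (hence
# end()) once patience runs out.
class Stopper(Plugin):

    def __init__(self, sch, name='early_stop', patience=5):
        super(Stopper, self).__init__(sch, name)
        self.patience = patience
        self.best = float('inf')
        self.stale = 0

    def execute(self, msg):
        # msg is assumed to carry the latest validation error in .value
        if msg.value < self.best:
            self.best, self.stale = msg.value, 0
            self.fire(Event('step', value=self.best))
        else:
            self.stale += 1
        if self.stale >= self.patience:
            self.fire(Event('terminate', value=self.best))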

def dbn_trainer(rbm1, rbm2, save, n_epochs):
    sched = Scheduler()

    pretrain_layer1 = fixed_epoch_trainer(rbm1, save, n_epochs)
    pretrain_layer1.act(sched, on=sched.begin())

    pretrain_layer2 = fixed_epoch_trainer(rbm2, save, n_epochs)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    return sched

def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    sched = Scheduler()

    # k-fold plugin will call rbm.change_dataset using various splits of the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer terminates on early_stop.end(), so trainer.end() will forward
    # the early-stopping message, which contains the best validation error.
    # `kill` presumably detaches kfold_measure (firing its end()) once the
    # k-fold plugin has exhausted all splits.
    kfold_measure.act(sched, on=[trainer.end()], kill=kfold_plugin.end())

    # this best validation error is then forwarded by single_crossval_run
    sched.terminate(on=kfold_measure.end())

    return sched


#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
dataset = ...
save = ...
n_epochs = ...
trainer = dbn_trainer(rbm1, rbm2, save, n_epochs)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

# manually add "hook" to monitor early stopping statistics
# NB: advantage of plugins is that this code can go anywhere ...
print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(trainer, kfold_plugin, kfold_measure)

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

sched.terminate(on=hyperparam_change.end())


##### RUN THE WHOLE DAMN THING #####
sched.run()