annotate doc/v2_planning/plugin_RP_GD.py @ 1256:bf41991692ea

new plugin approach
author gdesjardins
date Fri, 24 Sep 2010 12:53:53 -0400
parents
children c88db30f4e08
rev   line source
1256
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
1 ####
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
2 # H1: everything works in terms of iterators
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
3 # everything has step() and end() methods
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
4 ####
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
5
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
class Counter(Plugin):
    """Counter plugin that keeps track of a number of epochs (executions).

    Each call to ``execute`` increments ``self.n``. Once ``threshold``
    executions have been counted, a ``'terminate'`` event carrying the
    current count is fired.
    """

    def __init__(self, sch, name, threshold):
        # super() takes the class first and the instance second; the previous
        # ``super(self, Counter)`` had them reversed and raises TypeError.
        super(Counter, self).__init__(sch, name)
        self.n = 0                  # number of executions seen so far
        self.threshold = threshold  # count at which 'terminate' is fired

    def execute(self, msg):
        """Count one more execution and fire 'terminate' once the threshold is reached.

        ``msg`` is the triggering message; it is not read here.
        """
        self.n += 1
        # ``>=`` so that exactly ``threshold`` executions happen before
        # terminating; ``>`` allowed one extra epoch.
        if self.n >= self.threshold:
            self.fire(Event('terminate', value=self.n))
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
18
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
def fixed_epoch_trainer(model, save, n_epochs):
    """Build a scheduler that trains ``model`` for ``n_epochs`` epochs, then saves it.

    ``model`` and ``save`` are wrapped into plugins with ``pluggin_wrapper``.
    A ``Counter`` named 'epoch' is advanced each time ``model`` ends. When the
    counter ends, ``save`` runs and the scheduler terminates.

    Returns the configured ``Scheduler``.
    """
    sched = Scheduler()

    # define plugins
    # This trainer has no validation step, so only ``model`` and ``save`` are
    # wrapped. The previous version also wrapped an undefined ``validate``.
    model, save = map(pluggin_wrapper, [model, save])

    counter = Counter(sched, 'epoch', n_epochs)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
    counter.act(sched, on=model.end())
    # The wrapped save plugin is ``save``; ``save_model`` was never defined.
    save.act(sched, on=counter.end())

    sched.terminate(on=counter.end())

    return sched
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
35
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
def early_stop_trainer(model, validate, save, **kw):
    """Build a scheduler that trains ``model`` under early stopping.

    ``model``, ``validate`` and ``save`` are wrapped into plugins with
    ``pluggin_wrapper``. Actions are registered as follows:
    - ``model`` starts at ``sched.begin()`` and runs again on its own step
      and on ``validate`` steps.
    - ``validate`` runs whenever ``model`` ends.
    - A ``Stopper`` (built from ``**kw``) runs on each ``validate`` step.
    - ``save`` runs on every stopper step and when the stopper ends.

    The scheduler terminates when the stopper ends. Returns the configured
    ``Scheduler``.
    """
    sched = Scheduler()

    # define plugins
    model, validate, save = map(pluggin_wrapper, [model, validate, save])

    early_stop = Stopper(**kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
    validate.act(sched, on=model.end())
    early_stop.act(sched, on=validate.step())
    # The wrapped save plugin is ``save``; ``save_model`` was never defined.
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.terminate(on=early_stop.end())

    return sched
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
53
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
def dbn_trainer(rbm1, rbm2):
    """Build a scheduler that pretrains two RBM layers, then fine-tunes an MLP.

    Each stage is a sub-scheduler that is registered on ``sched`` like a plugin:
    1. ``pretrain_layer1`` (fixed_epoch_trainer on ``rbm1``) starts at ``sched.begin()``.
    2. ``pretrain_layer2`` (fixed_epoch_trainer on ``rbm2``) starts when layer 1 ends.
    3. ``fine_tuning`` (early_stop_trainer on ``mlp``) starts when layer 2 ends.

    Returns the outer ``Scheduler``.

    NOTE(review): as written this sketch cannot run:
    - ``save`` is neither a parameter nor a local here, and is not defined
      anywhere in the visible file.
    - fixed_epoch_trainer(model, save, n_epochs) requires ``n_epochs``, which
      is not passed.
    - ``function``, ``validate_mlp`` and ``save_mlp`` are placeholders (MLP
      construction is marked TBD below).
    - Unlike the other builders, no ``sched.terminate(...)`` is registered;
      confirm whether the outer scheduler is expected to stop on its own.
    TODO: decide where these values come from (likely extra parameters).
    """
    sched = Scheduler()

    pretrain_layer1 = fixed_epoch_trainer(rbm1, save)
    pretrain_layer1.act(sched, on=sched.begin())

    # Layer 2 pretraining is chained after layer 1 finishes.
    pretrain_layer2 = fixed_epoch_trainer(rbm2, save)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    # Fine-tuning is chained after layer 2 pretraining finishes.
    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    return sched
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
70
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
def single_crossval_run(trainer, kfold_plugin, kfold_measure):
    """Build a scheduler that runs ``trainer`` once per k-fold split.

    ``kfold_plugin`` runs at start-up and again each time ``trainer`` ends.
    ``trainer`` runs on each ``kfold_plugin`` step. ``kfold_measure`` receives
    each ``trainer.end()`` message. The scheduler terminates when
    ``kfold_measure`` ends.

    Returns the configured ``Scheduler``.
    """
    # The original header lacked the trailing ':' (SyntaxError).
    sched = Scheduler()

    # k-fold plugin will call rbm.change_dataset using various splits of the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer terminates on early_stop.end(). This means that trainer.end()
    # will forward the early-stopping message, which contains the best
    # validation error.
    # ``kill=`` was previously written inside the ``on=[...]`` list, and the
    # call's closing paren was missing (both SyntaxErrors). The intent is
    # assumed to be that kfold_measure stops once the k-fold plugin ends;
    # confirm against the Plugin.act API.
    kfold_measure.act(sched, on=[trainer.end()], kill=kfold_plugin.end())

    # this best validation error is then forwarded by single_crossval_run
    sched.terminate(on=kfold_measure.end())

    return sched
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
87
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
88
bf41991692ea new plugin approach
gdesjardins
parents:
diff changeset
#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
# Was ``....`` (four dots), which is a SyntaxError; ``...`` is the placeholder.
dataset = ...
# Bind the scheduler to its own name so the ``dbn_trainer`` factory is not
# shadowed by its own result.
dbn_sched = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

# Manually add a "hook" to monitor early-stopping statistics.
# NB: an advantage of plugins is that this code can go anywhere...
# NOTE(review): ``pretrain_layer1`` is a local variable of dbn_trainer() and
# is not defined at module level. The pretraining schedulers are built by
# fixed_epoch_trainer(), which registers no 'early_stop' plugin; early
# stopping only happens in the fine-tuning stage. ``print_stat`` is also not
# defined in this file.
# TODO: decide how this hook reaches the fine-tuning scheduler.
print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())

#### THIS SHOULD CORRESPOND TO THE OUTER (HYPER-PARAMETER) LOOP ####
sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(dbn_sched, kfold_plugin, kfold_measure)

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

# Pass the event via ``on=``, like every other terminate() call.
sched.terminate(on=hyperparam_change.end())


##### RUN EVERYTHING #####
sched.run()