pylearn: doc/v2_planning/plugin_RP_GD.py @ 1256:bf41991692ea
new plugin approach

author:   gdesjardins
date:     Fri, 24 Sep 2010 12:53:53 -0400
parent:   1253:826d78f0135f
children: c88db30f4e08
####
# H1: everything works in terms of iterators
# everything has a step() and end() method
####
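
# ---------------------------------------------------------------------------
# The Plugin / Scheduler / Event names used throughout this file are not
# defined anywhere yet.  The block below is a minimal, hypothetical sketch of
# what that API could look like, written only so the pseudocode further down
# has something concrete to read against; every class, method and attribute
# here is an assumption, not an agreed-upon interface.
# (plugin_wrapper, which would wrap an arbitrary model/validate/save object
# into a Plugin with step()/end() semantics, is deliberately left out: its
# calling convention below is too underspecified to sketch honestly.)
# ---------------------------------------------------------------------------

class Event(object):
    """A message fired by a plugin; `type` is a string such as 'step' or
    'terminate', and extra keywords (e.g. value=...) ride along as payload."""

    def __init__(self, type, **kwargs):
        self.type = type
        self.kwargs = kwargs


class Plugin(object):
    """Base class: a plugin can fire events, and can register its execute()
    method to run when events fired by other plugins (or the scheduler) occur."""

    def __init__(self, sch, name):
        self.sch = sch
        self.name = name
        sch.plugins[name] = self

    # step()/end() return hashable event descriptors that other plugins can
    # register on through act(on=...).
    def step(self):
        return (self, 'step')

    def end(self):
        return (self, 'terminate')

    def fire(self, event):
        # Hand the event to the scheduler, tagged with its source plugin, so
        # that anything registered on the matching descriptor gets executed.
        self.sch.dispatch((self, event.type), event)

    def act(self, sched, on, kill=None):
        # Subscribe this plugin's execute() to the given descriptor(s).  The
        # kill= argument (deactivate when some event fires) is accepted but
        # not implemented in this sketch.
        sched.subscribe(self, on if isinstance(on, list) else [on])

    def execute(self, msg):
        raise NotImplementedError


class Scheduler(object):

    def __init__(self):
        self.plugins = {}          # name -> plugin
        self.listeners = {}        # event descriptor -> plugins to execute
        self.stop_events = set()   # descriptors that end the loop
        self.running = False
        self.last_msg = None       # the message that caused the loop to stop

    def begin(self):
        return (self, 'begin')

    def subscribe(self, plugin, descriptors):
        for d in descriptors:
            self.listeners.setdefault(d, []).append(plugin)

    def terminate(self, on):
        self.stop_events.add(on)

    def dispatch(self, descriptor, event):
        if descriptor in self.stop_events:
            self.running = False
            self.last_msg = event
        for plugin in self.listeners.get(descriptor, []):
            plugin.execute(event)

    def run(self):
        # A real scheduler would keep pumping queued events until one of its
        # stop_events fires; this sketch only delivers events synchronously,
        # as they are fired, starting from the 'begin' event.
        self.running = True
        self.dispatch(self.begin(), Event('begin'))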

# Construct a counter plugin that keeps track of the number of epochs
class Counter(Plugin):

    def __init__(self, sch, name, threshold):
        super(Counter, self).__init__(sch, name)
        self.n = 0
        self.threshold = threshold

    def execute(self, msg):
        self.n += 1
        if self.n > self.threshold:
            self.fire(Event('terminate', value=self.n))

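# A tiny usage sketch of the hypothetical classes above.  PrintStat is a guess
# at what the print_stat plugin used near the bottom of this file might look
# like; the demo function is illustrative only and is never called.  With a
# threshold of 2, the third message pushes the counter over the edge and its
# 'terminate' event (carrying the final count) reaches the printer.
class PrintStat(Plugin):

    def execute(self, msg):
        print('%s got %s %s' % (self.name, msg.type, msg.kwargs))


def _counter_demo():
    sched = Scheduler()
    counter = Counter(sched, 'epoch', threshold=2)
    printer = PrintStat(sched, 'print_stat')
    printer.act(sched, on=counter.end())
    sched.terminate(on=counter.end())
    for _ in range(3):
        counter.execute(Event('step'))
    # -> prints: print_stat got terminate {'value': 3}
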
def fixed_epoch_trainer(model, save, n_epochs):
    sched = Scheduler()

    # define plugins
    [model, save] = map(plugin_wrapper, [model, save])

    counter = Counter(sched, 'epoch', n_epochs)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), counter.step()])
    counter.act(sched, on=model.end())
    save.act(sched, on=counter.end())

    sched.terminate(on=counter.end())

    return sched
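
# Usage note: because fixed_epoch_trainer returns a Scheduler rather than
# running anything itself, the same object can either be run standalone or be
# registered inside a larger scheduler as if it were a plugin (see dbn_trainer
# and the nesting note below).  Standalone use would look roughly like this
# (names hypothetical):
#
#   pretrain = fixed_epoch_trainer(rbm1, save_rbm1, n_epochs=10)
#   pretrain.run()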

def early_stop_trainer(model, validate, save, **kw):
    sched = Scheduler()

    # define plugins
    [model, validate, save] = map(plugin_wrapper, [model, validate, save])

    early_stop = Stopper(sched, 'early_stop', **kw)

    # register actions
    model.act(sched, on=[sched.begin(), model.step(), validate.step()])
    validate.act(sched, on=model.end())
    early_stop.act(sched, on=validate.step())
    save.act(sched, on=[early_stop.step(), early_stop.end()])

    sched.terminate(on=early_stop.end())

    return sched
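
# Stopper is not defined in this file.  Below is one hypothetical reading of
# it, consistent with the wiring above (save runs on early_stop.step() and on
# early_stop.end(), and the scheduler terminates on early_stop.end()): a
# patience-based stopper that listens to validation results, fires its 'step'
# event on every new best error (so the model gets checkpointed), and fires
# its 'terminate' event, carrying the best error, once it has waited
# `patience` validations without improvement.  The msg.kwargs['error'] key is
# an assumption about what the validate plugin sends.
class Stopper(Plugin):

    def __init__(self, sch, name, patience=5):
        super(Stopper, self).__init__(sch, name)
        self.patience = patience
        self.best = float('inf')
        self.waiting = 0

    def execute(self, msg):
        error = msg.kwargs['error']
        if error < self.best:
            self.best, self.waiting = error, 0
            self.fire(Event('step', error=error))
        else:
            self.waiting += 1
            if self.waiting > self.patience:
                self.fire(Event('terminate', value=self.best))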

def dbn_trainer(rbm1, rbm2):
    sched = Scheduler()

    pretrain_layer1 = fixed_epoch_trainer(rbm1, save)
    pretrain_layer1.act(sched, on=sched.begin())

    pretrain_layer2 = fixed_epoch_trainer(rbm2, save)
    pretrain_layer2.act(sched, on=pretrain_layer1.end())

    ## TBD: by the layer committee
    mlp = function(rbm1, rbm2)

    fine_tuning = early_stop_trainer(mlp, validate_mlp, save_mlp)
    fine_tuning.act(sched, on=pretrain_layer2.end())

    return sched
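
# For the nesting above (pretrain_layer1.act(sched, on=...)) to work, a
# scheduler has to look like a plugin from the outside: it needs act(), end()
# and execute().  One hypothetical way to get that, on top of the Plugin /
# Scheduler sketch near the top of this file, is a class that is both at once
# and that re-fires the message that stopped its inner loop as its own
# 'terminate' event; this re-firing is the "forwarding" that the comments in
# single_crossval_run below rely on.
class NestedScheduler(Scheduler, Plugin):

    def __init__(self, outer, name):
        Scheduler.__init__(self)
        Plugin.__init__(self, outer, name)

    def execute(self, msg):
        # Being poked by the outer scheduler means: run the whole inner loop.
        self.run()
        # Forward whatever message stopped the inner loop (e.g. the early-
        # stopping message carrying the best validation error) as this
        # scheduler's own end() event.
        self.fire(Event('terminate', inner=self.last_msg))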

def single_crossval_run(trainer, kfold_plugin, kfold_measure):

    sched = Scheduler()

    # the k-fold plugin will call rbm.change_dataset using various splits of the data
    kfold_plugin.act(sched, on=[sched.begin(), trainer.end()])
    trainer.act(sched, on=[kfold_plugin.step()])

    # trainer terminates on early_stop.end(). This means that trainer.end() will forward
    # the early-stopping message, which contains the best validation error.
    kfold_measure.act(sched, on=[trainer.end()], kill=kfold_plugin.end())

    # this best validation error is then forwarded by single_crossval_run
    sched.terminate(on=kfold_measure.end())

    return sched
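
# KFold, constructed in the main loop below as KFold([rbm1, rbm2], dataset),
# is also undefined here.  A hypothetical sketch consistent with the wiring
# above: each time it is poked (at sched.begin() or when the inner trainer
# ends) it either moves the models onto the next train/valid split and fires
# 'step' (which relaunches the trainer), or fires 'terminate' once every fold
# has been used.  The sch/name arguments, the k parameter, the way splits are
# built and the change_dataset call are all assumptions; the constructor call
# in the main loop is abbreviated pseudocode.
class KFold(Plugin):

    def __init__(self, sch, name, models, dataset, k=5):
        super(KFold, self).__init__(sch, name)
        self.models = models
        # Treat the dataset as a plain list of examples; the i-th contiguous
        # block is the validation fold, the rest is the training fold.
        fold = len(dataset) // k
        self.splits = iter([(dataset[:i * fold] + dataset[(i + 1) * fold:],
                             dataset[i * fold:(i + 1) * fold])
                            for i in range(k)])

    def execute(self, msg):
        try:
            split = next(self.splits)
        except StopIteration:
            self.fire(Event('terminate'))
            return
        for model in self.models:
            model.change_dataset(split)   # per the comment in single_crossval_run
        self.fire(Event('step'))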


#### MAIN LOOP ####
rbm1 = ...
rbm2 = ...
dataset = ...
dbn_sched = dbn_trainer(rbm1, rbm2)
kfold_plugin = KFold([rbm1, rbm2], dataset)
kfold_measure = ...

# manually add a "hook" to monitor early stopping statistics
# NB: the advantage of plugins is that this code can go anywhere ...
print_stat.act(pretrain_layer1, on=pretrain_layer1.plugins['early_stop'].step())

#### THIS SHOULD CORRESPOND TO THE OUTER LOOP ####
sched = Scheduler()

hyperparam_change = DBN_HyperParam([rbm1, rbm2])
hyperparam_test = single_crossval_run(dbn_sched, kfold_plugin, kfold_measure)

hyperparam_change.act(sched, on=[sched.begin(), hyperparam_test.end()])
hyperparam_test.act(sched, on=hyperparam_change.step())

sched.terminate(on=hyperparam_change.end())
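
# DBN_HyperParam is undefined as well.  Structurally it plays the same role
# as the KFold plugin above, only iterating over hyperparameter configurations
# instead of data splits: each time it is poked it either applies the next
# configuration to the models and fires 'step' (which relaunches the
# cross-validation run), or fires 'terminate' once the search is exhausted.
# The configs argument and the set_hyperparams method are assumptions, and
# the constructor call above is abbreviated pseudocode.
class DBN_HyperParam(Plugin):

    def __init__(self, sch, name, models, configs):
        super(DBN_HyperParam, self).__init__(sch, name)
        self.models = models
        self.configs = iter(configs)

    def execute(self, msg):
        try:
            config = next(self.configs)
        except StopIteration:
            self.fire(Event('terminate'))
            return
        for model in self.models:
            model.set_hyperparams(config)   # assumed model method
        self.fire(Event('step'))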


##### RUN THE WHOLE DAMN THING #####
sched.run()