Mercurial > pylearn
comparison doc/v2_planning/plugin_RP.py @ 1214:681b5e7e3b81
a few comments on James version
author | Razvan Pascanu <r.pascanu@gmail.com> |
---|---|
date | Wed, 22 Sep 2010 10:39:39 -0400 |
parents | 7fff3d5c7694 |
children |
comparison
equal
deleted
inserted
replaced
1213:33513a46c41b | 1214:681b5e7e3b81 |
---|---|
52 | 52 |
53 Step 2 | 53 Step 2 |
54 ====== | 54 ====== |
55 | 55 |
56 I will start with step 2 ( because I think that is more of a hot subject | 56 I will start with step 2 ( because I think that is more of a hot subject |
57 right now). I will assume you have the write plugins at had. | 57 right now). I will assume you have the right plugins at hand. |
58 This is a DBN with early stopping and .. | 58 This is a DBN with early stopping and .. |
59 | 59 |
60 .. code-block:: python | 60 .. code-block:: python |
61 ''' | 61 ''' |
62 data = load_mnist() | 62 data = load_mnist() |
74 ## Layer 1: | 74 ## Layer 1: |
75 h1 = sigmoid(dotW_b(x0,units = 200), constraint = L1( coeff = 0.1)) | 75 h1 = sigmoid(dotW_b(x0,units = 200), constraint = L1( coeff = 0.1)) |
76 x1 = recurrent_layer() | 76 x1 = recurrent_layer() |
77 x1.t0 = x0 | 77 x1.t0 = x0 |
78 x1.value = binomial_sample(sigmoid( reconstruct( binomial_sample(h1), x0))) | 78 x1.value = binomial_sample(sigmoid( reconstruct( binomial_sample(h1), x0))) |
79 cost = free_energy(train_x) - free_energy(x1.tp(5)) | 79 cost = free_energy(train_x) - free_energy(x1.t(5)) |
80 grads = [ (g.var, T.grad(cost.var, g.var)) for g in cost.params ] | 80 grads = [ (g.var, T.grad(cost.var, g.var)) for g in cost.params ] |
81 pseudo_cost = sum([ pl.sum(pl.abs(g)) for g in cost.params]) | 81 pseudo_cost = sum([ pl.sum(pl.abs(g)) for g in cost.params]) |
82 rbm1 = SGD( cost = pseudo_cost, grads = grads) | 82 rbm1 = SGD( cost = pseudo_cost, grads = grads) |
83 | 83 |
84 # Layer 2: | 84 # Layer 2: |
94 | 94 |
95 ca = Schedular() | 95 ca = Schedular() |
96 | 96 |
97 | 97 |
98 ### Constructing Modes ### | 98 ### Constructing Modes ### |
99 pretrain_layer1 = ca.mode('pretrain0') | 99 class pretrain_layer1 () |
100 | |
101 def register() | |
102 { | |
103 } | |
100 pretrain_layer2 = ca.mode('pretrain1') | 104 pretrain_layer2 = ca.mode('pretrain1') |
101 early_stopping = ca.mode('early') | 105 early_stopping = ca.mode('early') |
102 valid1 = ca.mode('stuff') | 106 code_block = ca.mode('code_block') |
103 kfolds = ca.mode('kfolds') | 107 kfolds = ca.mode('kfolds') |
104 | 108 |
105 # Construct modes dependency graph | 109 # Construct modes dependency graph |
106 valid0.include([ pretrian_layer1, pretrain_layer2, early_stopper]) | 110 code_block.include([ pretrian_layer1, pretrain_layer2, early_stopper]) |
107 kfolds.include( valid0 ) | 111 kfolds.include( code_block ) |
108 | 112 |
109 pretrain_layer1.act( on = valid1.begin(), when = always()) | 113 pretrain_layer1.act( on = code_block.begin(), when = always()) |
110 pretrain_layer2.act( on = pretrain_layer1.end(), when = always()) | 114 pretrain_layer2.act( on = pretrain_layer1.end(), when = always()) |
111 early_stopping.act ( on = pretrain_layer2.end(), when = always()) | 115 early_stopping.act ( on = pretrain_layer2.end(), when = always()) |
112 | 116 |
113 | 117 |
114 # Construct counter plugin that keeps track of number of epochs | 118 # Construct counter plugin that keeps track of number of epochs |
126 else: | 130 else: |
127 self.fire(Message('terminate')) | 131 self.fire(Message('terminate')) |
128 | 132 |
129 | 133 |
130 # Construct pre-training plugins | 134 # Construct pre-training plugins |
131 rbm1_plugin = plugin_wrapper(rbm1, sched = pretrain_layer1) | 135 rbm1_plugin = pretrain_layer1.include(plugin_wrapper(rbm1)) |
136 rbm2_plugin = pretrain_layer2.include(plugin_wrapper(rbm2)) | |
137 rbm1_counter = pretrain_layer1.include(counter) | |
138 rbm2_counter = pretrain_layer2.include(counter) | |
139 | |
132 rbm1_plugin.listen(Message('init'), update_hyperparameters) | 140 rbm1_plugin.listen(Message('init'), update_hyperparameters) |
133 rbm2_plugin = plugin_wrapper(rbm2, sched = pretrain_layer2) | 141 rbm1_plugin.listen(Message('continue'), dataset_restart) |
134 rbm2_plugin.listen(Message('init'), update_hyperparameters) | 142 rbm2_plugin.listen(Message('init'), update_hyperparameters) |
135 rbm1_counter = pretrain_layer1.register(counter) | 143 rbm2_plugin.listen(Message('continue'), dataset_restart) |
136 rbm2_counter = pretrain_layer2.register(counter) | |
137 | 144 |
138 | 145 |
139 # Dependency graph for pre-training layer 0 | 146 # Dependency graph for pre-training layer 0 |
140 rbm1_plugin.act( on = [ pretrain_layer1.begin() | 147 rbm1_plugin.act( on = [ pretrain_layer1.begin() , |
141 Message('continue') ], | 148 rbm1_plugin.value() ] , |
142 when = always()) | 149 when = always()) |
143 rbm1_counter.act( on = rbm1_plugin.eod(), when = always() ) | 150 rbm1_counter.act( on = rbm1_plugin.eod(), when = always() ) |
144 | 151 |
145 | 152 |
146 # Dependency graph for pre-training layer 1 | 153 # Dependency graph for pre-training layer 1 |
147 rbm2_plugin.act( on = pretrain_layer2.begin(), when = always()) | 154 rbm2_plugin.act( on = [ pretrain_layer2.begin() , |
155 rbm2_plugin.value() ] , | |
156 when = always()) | |
148 pretrain_layer2.stop( on = rbm2_plugin.eod(), when = always()) | 157 pretrain_layer2.stop( on = rbm2_plugin.eod(), when = always()) |
149 | 158 |
150 | 159 |
151 # Constructing fine-tuning plugins | 160 # Constructing fine-tuning plugins |
152 learner = early_stopper.register(plugin_wrapper(logreg)) | 161 learner = early_stopper.include(plugin_wrapper(logreg)) |
162 validation = early_stopper.include( plugin_wrapper(valid_err))) | |
163 clock = early_stopper.include( ca.generate_clock()) | |
164 early_stopper_plugin = early_stopper.include( early_stopper_plugin) | |
165 | |
166 | |
167 def save_model(plugin): | |
168 cPickle.dump(plugin.object, 'just_the_model.pkl') | |
169 | |
153 learner.listen(Message('init'), update_hyperparameters) | 170 learner.listen(Message('init'), update_hyperparameters) |
154 validation = early_stopper.register( plugin_wrapper(valid_err))) | |
155 validation.listen(Message('init'), update_hyperparameters) | 171 validation.listen(Message('init'), update_hyperparameters) |
156 clock = early_stopper.register( ca.generate_clock()) | 172 validation.listen(early_stopper_plugin.new_best_score(), save_model) |
157 early_stopper_plugin = early_stopper.register( early_stopper_plugin) | |
158 | |
159 @FnPlugin | |
160 def save_weights(self, message): | |
161 cPickle.dump(logreg, open('model.pkl')) | |
162 | |
163 | 173 |
164 learner.act( on = early_stopper.begin(), when = always()) | 174 learner.act( on = early_stopper.begin(), when = always()) |
165 learner.act( on = learner.value(), when = always()) | 175 learner.act( on = learner.value(), when = always()) |
166 validation.act( on = clock.hour(), when = every(n = 1)) | 176 validation.act( on = clock.hour(), when = every(n = 1)) |
167 early_stopper.act( on = validation.value(), when = always()) | 177 early_stopper.act( on = validation.value(), when = always()) |
168 save_model.act( on = early_stopper.new_best_error(), when =always()) | |
169 | 178 |
170 @FnPlugin | 179 @FnPlugin |
171 def kfolds_plugin(self,event): | 180 def kfolds_plugin(self,event): |
172 if not hasattr(self, 'n'): | 181 if not hasattr(self, 'n'): |
173 self.n = -1 | 182 self.n = -1 |
181 self.fire(msg) | 190 self.fire(msg) |
182 else: | 191 else: |
183 self.fire(Message('terminate')) | 192 self.fire(Message('terminate')) |
184 | 193 |
185 | 194 |
186 kfolds.register(kfolds_plugin) | 195 kfolds.include(kfolds_plugin) |
187 kfolds_plugin.act(kfolds.begin(), when = always()) | 196 kfolds_plugin.act([kfolds.begin(), Message('new split')], when = always()) |
188 kfolds_plugin.act(valid0.end(), always() ) | 197 kfolds_plugin.act(code_block.end(), always() ) |
189 valid0.act(Message('new split'), always() ) | 198 code_block.act(Message('new split'), always() ) |
190 | 199 |
191 sched.include(kfolds) | 200 sched.include(kfolds) |
192 | 201 |
193 sched.run() | 202 sched.run() |
194 | 203 |
195 ''' | 204 ''' |
205 | |
206 | |
196 | 207 |
197 Notes: | 208 Notes: |
198 when a mode is registered to begin with a certain message, it will | 209 when a mode is registered to begin with a certain message, it will |
199 rebroadcast that message when it starts, with only switching the | 210 rebroadcast that message when it starts, with only switching the |
200 type from whatever it was to 'init'. It will also send all 'init' messages | 211 type from whatever it was to 'init'. It will also send all 'init' messages |
273 be pretty useful : | 284 be pretty useful : |
274 | 285 |
275 * .replace(dict) -> method; replaces the subgraphs given as keys with | 286 * .replace(dict) -> method; replaces the subgraphs given as keys with |
276 the ones given as values; throws an exception if it | 287 the ones given as values; throws an exception if it |
277 is impossible | 288 is impossible |
289 | |
290 * replace(nodes, dict) -> function; call replace on all nodes given that dictionary | |
278 | 291 |
279 * reconstruct(dict) -> transform; tries to reconstruct the nodes given as | 292 * reconstruct(dict) -> transform; tries to reconstruct the nodes given as |
280 keys starting from the nodes given as values by | 293 keys starting from the nodes given as values by |
281 going through the inverse of all transforms that | 294 going through the inverse of all transforms that |
282 are in between | 295 are in between |
296 machines | 309 machines |
297 | 310 |
298 * switch(hyperparam, dict) -> transform; a lazy switch that allows you | 311 * switch(hyperparam, dict) -> transform; a lazy switch that allows you |
299 to construct by hyper-parameters | 312 to construct by hyper-parameters |
300 | 313 |
301 * get_hyperparameter -> method; given a name it will return the first node | 314 * get_hyperparameter(name) -> method; given a name it will return the first node |
302 starting from top that is a hyper parameter and has | 315 starting from top that is a hyper parameter and has |
303 that name | 316 that name |
304 * get_parameter -> method; given a name it will return the first node | 317 * get_parameter(name) -> method; given a name it will return the first node |
305 starting from top that is a parameter and has that | 318 starting from top that is a parameter and has that |
306 name | 319 name |
320 * get_hyperparameters() | |
321 * get_parameters() | |
322 | |
307 | 323 |
308 | 324 |
309 | 325 |
310 Because every node provides the dataset API it means you can iterate over | 326 Because every node provides the dataset API it means you can iterate over |
311 any of the nodes. They will produce the original dataset transformed up | 327 any of the nodes. They will produce the original dataset transformed up |
390 | 406 |
391 ''' | 407 ''' |
392 # sketch of writing a RNN | 408 # sketch of writing a RNN |
393 x = load_mnist() | 409 x = load_mnist() |
394 y = recurrent_layer() | 410 y = recurrent_layer() |
395 y.value = tanh(dotW(x, n=50) + dotW(y.tm(1),50)) | 411 y.value = tanh(dotW(x, n=50).t(0) + dotW(y.t(-1),50)) |
396 y.t0 = zeros( (50,)) | 412 y.t0 = zeros( (50,)) |
397 out = dotW(y,10) | 413 out = dotW(y,10) |
398 | 414 |
399 | 415 |
400 # sketch of writing CDk starting from x | 416 # sketch of writing CDk starting from x |