comparison doc/v2_planning/plugin_RP.py @ 1214:681b5e7e3b81

a few comments on James version
author Razvan Pascanu <r.pascanu@gmail.com>
date Wed, 22 Sep 2010 10:39:39 -0400
parents 7fff3d5c7694
children
comparison
equal deleted inserted replaced
1213:33513a46c41b 1214:681b5e7e3b81
52 52
53 Step 2 53 Step 2
54 ====== 54 ======
55 55
56 I will start with step 2 ( because I think that is more of a hot subject 56 I will start with step 2 ( because I think that is more of a hot subject
57 right now). I will assume you have the write plugins at had. 57 right now). I will assume you have the right plugins at hand.
58 This is a DBN with early stopping and .. 58 This is a DBN with early stopping and ..
59 59
60 .. code-block:: python 60 .. code-block:: python
61 ''' 61 '''
62 data = load_mnist() 62 data = load_mnist()
74 ## Layer 1: 74 ## Layer 1:
75 h1 = sigmoid(dotW_b(x0,units = 200), constraint = L1( coeff = 0.1)) 75 h1 = sigmoid(dotW_b(x0,units = 200), constraint = L1( coeff = 0.1))
76 x1 = recurrent_layer() 76 x1 = recurrent_layer()
77 x1.t0 = x0 77 x1.t0 = x0
78 x1.value = binomial_sample(sigmoid( reconstruct( binomial_sample(h1), x0))) 78 x1.value = binomial_sample(sigmoid( reconstruct( binomial_sample(h1), x0)))
79 cost = free_energy(train_x) - free_energy(x1.tp(5)) 79 cost = free_energy(train_x) - free_energy(x1.t(5))
80 grads = [ (g.var, T.grad(cost.var, g.var)) for g in cost.params ] 80 grads = [ (g.var, T.grad(cost.var, g.var)) for g in cost.params ]
81 pseudo_cost = sum([ pl.sum(pl.abs(g)) for g in cost.params]) 81 pseudo_cost = sum([ pl.sum(pl.abs(g)) for g in cost.params])
82 rbm1 = SGD( cost = pseudo_cost, grads = grads) 82 rbm1 = SGD( cost = pseudo_cost, grads = grads)
83 83
84 # Layer 2: 84 # Layer 2:
94 94
95 ca = Schedular() 95 ca = Schedular()
96 96
97 97
98 ### Constructing Modes ### 98 ### Constructing Modes ###
99 pretrain_layer1 = ca.mode('pretrain0') 99 class pretrain_layer1 ()
100
101 def register()
102 {
103 }
100 pretrain_layer2 = ca.mode('pretrain1') 104 pretrain_layer2 = ca.mode('pretrain1')
101 early_stopping = ca.mode('early') 105 early_stopping = ca.mode('early')
102 valid1 = ca.mode('stuff') 106 code_block = ca.mode('code_block')
103 kfolds = ca.mode('kfolds') 107 kfolds = ca.mode('kfolds')
104 108
105 # Construct modes dependency graph 109 # Construct modes dependency graph
106 valid0.include([ pretrian_layer1, pretrain_layer2, early_stopper]) 110 code_block.include([ pretrian_layer1, pretrain_layer2, early_stopper])
107 kfolds.include( valid0 ) 111 kfolds.include( code_block )
108 112
109 pretrain_layer1.act( on = valid1.begin(), when = always()) 113 pretrain_layer1.act( on = code_block.begin(), when = always())
110 pretrain_layer2.act( on = pretrain_layer1.end(), when = always()) 114 pretrain_layer2.act( on = pretrain_layer1.end(), when = always())
111 early_stopping.act ( on = pretrain_layer2.end(), when = always()) 115 early_stopping.act ( on = pretrain_layer2.end(), when = always())
112 116
113 117
114 # Construct counter plugin that keeps track of number of epochs 118 # Construct counter plugin that keeps track of number of epochs
126 else: 130 else:
127 self.fire(Message('terminate')) 131 self.fire(Message('terminate'))
128 132
129 133
130 # Construct pre-training plugins 134 # Construct pre-training plugins
131 rbm1_plugin = plugin_wrapper(rbm1, sched = pretrain_layer1) 135 rbm1_plugin = pretrain_layer1.include(plugin_wrapper(rbm1))
136 rbm2_plugin = pretrain_layer2.include(plugin_wrapper(rbm2))
137 rbm1_counter = pretrain_layer1.include(counter)
138 rbm2_counter = pretrain_layer2.include(counter)
139
132 rbm1_plugin.listen(Message('init'), update_hyperparameters) 140 rbm1_plugin.listen(Message('init'), update_hyperparameters)
133 rbm2_plugin = plugin_wrapper(rbm2, sched = pretrain_layer2) 141 rbm1_plugin.listen(Message('continue'), dataset_restart)
134 rbm2_plugin.listen(Message('init'), update_hyperparameters) 142 rbm2_plugin.listen(Message('init'), update_hyperparameters)
135 rbm1_counter = pretrain_layer1.register(counter) 143 rbm2_plugin.listen(Message('continue'), dataset_restart)
136 rbm2_counter = pretrain_layer2.register(counter)
137 144
138 145
139 # Dependency graph for pre-training layer 0 146 # Dependency graph for pre-training layer 0
140 rbm1_plugin.act( on = [ pretrain_layer1.begin() 147 rbm1_plugin.act( on = [ pretrain_layer1.begin() ,
141 Message('continue') ], 148 rbm1_plugin.value() ] ,
142 when = always()) 149 when = always())
143 rbm1_counter.act( on = rbm1_plugin.eod(), when = always() ) 150 rbm1_counter.act( on = rbm1_plugin.eod(), when = always() )
144 151
145 152
146 # Dependency graph for pre-training layer 1 153 # Dependency graph for pre-training layer 1
147 rbm2_plugin.act( on = pretrain_layer2.begin(), when = always()) 154 rbm2_plugin.act( on = [ pretrain_layer2.begin() ,
155 rbm2_plugin.value() ] ,
156 when = always())
148 pretrain_layer2.stop( on = rbm2_plugin.eod(), when = always()) 157 pretrain_layer2.stop( on = rbm2_plugin.eod(), when = always())
149 158
150 159
151 # Constructing fine-tuning plugins 160 # Constructing fine-tuning plugins
152 learner = early_stopper.register(plugin_wrapper(logreg)) 161 learner = early_stopper.include(plugin_wrapper(logreg))
162 validation = early_stopper.include( plugin_wrapper(valid_err)))
163 clock = early_stopper.include( ca.generate_clock())
164 early_stopper_plugin = early_stopper.include( early_stopper_plugin)
165
166
167 def save_model(plugin):
168 cPickle.dump(plugin.object, 'just_the_model.pkl')
169
153 learner.listen(Message('init'), update_hyperparameters) 170 learner.listen(Message('init'), update_hyperparameters)
154 validation = early_stopper.register( plugin_wrapper(valid_err)))
155 validation.listen(Message('init'), update_hyperparameters) 171 validation.listen(Message('init'), update_hyperparameters)
156 clock = early_stopper.register( ca.generate_clock()) 172 validation.listen(early_stopper_plugin.new_best_score(), save_model)
157 early_stopper_plugin = early_stopper.register( early_stopper_plugin)
158
159 @FnPlugin
160 def save_weights(self, message):
161 cPickle.dump(logreg, open('model.pkl'))
162
163 173
164 learner.act( on = early_stopper.begin(), when = always()) 174 learner.act( on = early_stopper.begin(), when = always())
165 learner.act( on = learner.value(), when = always()) 175 learner.act( on = learner.value(), when = always())
166 validation.act( on = clock.hour(), when = every(n = 1)) 176 validation.act( on = clock.hour(), when = every(n = 1))
167 early_stopper.act( on = validation.value(), when = always()) 177 early_stopper.act( on = validation.value(), when = always())
168 save_model.act( on = early_stopper.new_best_error(), when =always())
169 178
170 @FnPlugin 179 @FnPlugin
171 def kfolds_plugin(self,event): 180 def kfolds_plugin(self,event):
172 if not hasattr(self, 'n'): 181 if not hasattr(self, 'n'):
173 self.n = -1 182 self.n = -1
181 self.fire(msg) 190 self.fire(msg)
182 else: 191 else:
183 self.fire(Message('terminate')) 192 self.fire(Message('terminate'))
184 193
185 194
186 kfolds.register(kfolds_plugin) 195 kfolds.include(kfolds_plugin)
187 kfolds_plugin.act(kfolds.begin(), when = always()) 196 kfolds_plugin.act([kfolds.begin(), Message('new split')], when = always())
188 kfolds_plugin.act(valid0.end(), always() ) 197 kfolds_plugin.act(code_block.end(), always() )
189 valid0.act(Message('new split'), always() ) 198 code_block.act(Message('new split'), always() )
190 199
191 sched.include(kfolds) 200 sched.include(kfolds)
192 201
193 sched.run() 202 sched.run()
194 203
195 ''' 204 '''
205
206
196 207
197 Notes: 208 Notes:
198 when a mode is registered to begin with a certain message, it will 209 when a mode is registered to begin with a certain message, it will
199 rebroadcast that message when it starts, with only switching the 210 rebroadcast that message when it starts, with only switching the
200 type from whatever it was to 'init'. It will also send all 'init' messages 211 type from whatever it was to 'init'. It will also send all 'init' messages
273 be pretty useful : 284 be pretty useful :
274 285
275 * .replace(dict) -> method; replaces the subgraphs given as keys with 286 * .replace(dict) -> method; replaces the subgraphs given as keys with
276 the ones given as values; throws an exception if it 287 the ones given as values; throws an exception if it
277 is impossible 288 is impossible
289
290 * replace(nodes, dict) -> function; call replace on all nodes given that dictionary
278 291
279 * reconstruct(dict) -> transform; tries to reconstruct the nodes given as 292 * reconstruct(dict) -> transform; tries to reconstruct the nodes given as
280 keys starting from the nodes given as values by 293 keys starting from the nodes given as values by
281 going through the inverse of all transforms that 294 going through the inverse of all transforms that
282 are in between 295 are in between
296 machines 309 machines
297 310
298 * switch(hyperparam, dict) -> transform; a lazy switch that allows you 311 * switch(hyperparam, dict) -> transform; a lazy switch that allows you
299 to construct by hyper-parameters 312 to construct by hyper-parameters
300 313
301 * get_hyperparameter -> method; given a name it will return the first node 314 * get_hyperparameter(name) -> method; given a name it will return the first node
302 starting from top that is a hyper parameter and has 315 starting from top that is a hyper parameter and has
303 that name 316 that name
304 * get_parameter -> method; given a name it will return the first node 317 * get_parameter(name) -> method; given a name it will return the first node
305 starting from top that is a parameter and has that 318 starting from top that is a parameter and has that
306 name 319 name
320 * get_hyperparameters()
321 * get_parameters()
322
307 323
308 324
309 325
310 Because every node provides the dataset API it means you can iterate over 326 Because every node provides the dataset API it means you can iterate over
311 any of the nodes. They will produce the original dataset transformed up 327 any of the nodes. They will produce the original dataset transformed up
390 406
391 ''' 407 '''
392 # sketch of writing a RNN 408 # sketch of writing a RNN
393 x = load_mnist() 409 x = load_mnist()
394 y = recurrent_layer() 410 y = recurrent_layer()
395 y.value = tanh(dotW(x, n=50) + dotW(y.tm(1),50)) 411 y.value = tanh(dotW(x, n=50).t(0) + dotW(y.t(-1),50))
396 y.t0 = zeros( (50,)) 412 y.t0 = zeros( (50,))
397 out = dotW(y,10) 413 out = dotW(y,10)
398 414
399 415
400 # sketch of writing CDk starting from x 416 # sketch of writing CDk starting from x