Mercurial > ift6266
comparison deep/stacked_dae/sgd_optimization.py @ 284:8a3af19ae272
Enlevé mécanique pour limiter le nombre d'exemples utilisés (remplacé par paramètre dans l'appel au code de dataset), et ajouté option pour sauvegarde des poids à la fin de l'entraînement
author | fsavard |
---|---|
date | Wed, 24 Mar 2010 15:13:48 -0400 |
parents | 7b4507295eba |
children |
comparison
equal
deleted
inserted
replaced
279:206374eed2fb | 284:8a3af19ae272 |
---|---|
1 #!/usr/bin/python | 1 #!/usr/bin/python |
2 # coding: utf-8 | 2 # coding: utf-8 |
3 | 3 |
4 # Generic SdA optimization loop, adapted from the deeplearning.net tutorial | 4 # Generic SdA optimization loop, adapted from the deeplearning.net tutorial |
5 | |
6 from __future__ import with_statement | |
5 | 7 |
6 import numpy | 8 import numpy |
7 import theano | 9 import theano |
8 import time | 10 import time |
9 import datetime | 11 import datetime |
23 'validation_error' : DummySeries(), | 25 'validation_error' : DummySeries(), |
24 'test_error' : DummySeries(), | 26 'test_error' : DummySeries(), |
25 'params' : DummySeries() | 27 'params' : DummySeries() |
26 } | 28 } |
27 | 29 |
def itermax(iter, max):
    """Yield at most `max` items from the iterable `iter`.

    NOTE(review): the parameter names shadow the builtins ``iter`` and
    ``max``; kept unchanged for interface compatibility with callers.
    Like the original, this pulls one extra element from the underlying
    iterator before stopping (the truncation check happens after the
    next item has already been fetched).
    """
    count = 0
    for item in iter:
        if count >= max:
            return
        count += 1
        yield item
33 | |
34 class SdaSgdOptimizer: | 30 class SdaSgdOptimizer: |
35 def __init__(self, dataset, hyperparameters, n_ins, n_outs, | 31 def __init__(self, dataset, hyperparameters, n_ins, n_outs, |
36 examples_per_epoch, series=default_series, max_minibatches=None): | 32 examples_per_epoch, series=default_series, |
33 save_params=False): | |
37 self.dataset = dataset | 34 self.dataset = dataset |
38 self.hp = hyperparameters | 35 self.hp = hyperparameters |
39 self.n_ins = n_ins | 36 self.n_ins = n_ins |
40 self.n_outs = n_outs | 37 self.n_outs = n_outs |
41 | 38 |
42 self.max_minibatches = max_minibatches | 39 self.save_params = save_params |
43 print "SdaSgdOptimizer, max_minibatches =", max_minibatches | |
44 | 40 |
45 self.ex_per_epoch = examples_per_epoch | 41 self.ex_per_epoch = examples_per_epoch |
46 self.mb_per_epoch = examples_per_epoch / self.hp.minibatch_size | 42 self.mb_per_epoch = examples_per_epoch / self.hp.minibatch_size |
47 | 43 |
48 self.series = series | 44 self.series = series |
99 batch_index+=1 | 95 batch_index+=1 |
100 | 96 |
101 #if batch_index % 100 == 0: | 97 #if batch_index % 100 == 0: |
102 # print "100 batches" | 98 # print "100 batches" |
103 | 99 |
104 # useful when doing tests | |
105 if self.max_minibatches and batch_index >= self.max_minibatches: | |
106 break | |
107 | |
108 print 'Pre-training layer %i, epoch %d, cost '%(i,epoch),c | 100 print 'Pre-training layer %i, epoch %d, cost '%(i,epoch),c |
109 sys.stdout.flush() | 101 sys.stdout.flush() |
110 | 102 |
111 self.series['params'].append((epoch,), self.classifier.all_params) | 103 self.series['params'].append((epoch,), self.classifier.all_params) |
112 | 104 |
148 validation_frequency = min(self.mb_per_epoch, patience/2) | 140 validation_frequency = min(self.mb_per_epoch, patience/2) |
149 # go through this many | 141 # go through this many |
150 # minibatche before checking the network | 142 # minibatche before checking the network |
151 # on the validation set; in this case we | 143 # on the validation set; in this case we |
152 # check every epoch | 144 # check every epoch |
153 if self.max_minibatches and validation_frequency > self.max_minibatches: | |
154 validation_frequency = self.max_minibatches / 2 | |
155 | 145 |
156 best_params = None | 146 best_params = None |
157 best_validation_loss = float('inf') | 147 best_validation_loss = float('inf') |
158 test_score = 0. | 148 test_score = 0. |
159 start_time = time.clock() | 149 start_time = time.clock() |
174 self.series["training_error"].append((epoch, minibatch_index), cost_ij) | 164 self.series["training_error"].append((epoch, minibatch_index), cost_ij) |
175 | 165 |
176 if (total_mb_index+1) % validation_frequency == 0: | 166 if (total_mb_index+1) % validation_frequency == 0: |
177 | 167 |
178 iter = dataset.valid(minibatch_size) | 168 iter = dataset.valid(minibatch_size) |
179 if self.max_minibatches: | |
180 iter = itermax(iter, self.max_minibatches) | |
181 validation_losses = [validate_model(x,y) for x,y in iter] | 169 validation_losses = [validate_model(x,y) for x,y in iter] |
182 this_validation_loss = numpy.mean(validation_losses) | 170 this_validation_loss = numpy.mean(validation_losses) |
183 | 171 |
184 self.series["validation_error"].\ | 172 self.series["validation_error"].\ |
185 append((epoch, minibatch_index), this_validation_loss*100.) | 173 append((epoch, minibatch_index), this_validation_loss*100.) |
201 best_validation_loss = this_validation_loss | 189 best_validation_loss = this_validation_loss |
202 best_iter = total_mb_index | 190 best_iter = total_mb_index |
203 | 191 |
204 # test it on the test set | 192 # test it on the test set |
205 iter = dataset.test(minibatch_size) | 193 iter = dataset.test(minibatch_size) |
206 if self.max_minibatches: | |
207 iter = itermax(iter, self.max_minibatches) | |
208 test_losses = [test_model(x,y) for x,y in iter] | 194 test_losses = [test_model(x,y) for x,y in iter] |
209 test_score = numpy.mean(test_losses) | 195 test_score = numpy.mean(test_losses) |
210 | 196 |
211 self.series["test_error"].\ | 197 self.series["test_error"].\ |
212 append((epoch, minibatch_index), test_score*100.) | 198 append((epoch, minibatch_index), test_score*100.) |
215 'model %f %%') % | 201 'model %f %%') % |
216 (epoch, minibatch_index+1, self.mb_per_epoch, | 202 (epoch, minibatch_index+1, self.mb_per_epoch, |
217 test_score*100.)) | 203 test_score*100.)) |
218 | 204 |
219 sys.stdout.flush() | 205 sys.stdout.flush() |
220 | |
221 # useful when doing tests | |
222 if self.max_minibatches and minibatch_index >= self.max_minibatches: | |
223 break | |
224 | 206 |
225 self.series['params'].append((epoch,), self.classifier.all_params) | 207 self.series['params'].append((epoch,), self.classifier.all_params) |
226 | 208 |
227 if patience <= total_mb_index: | 209 if patience <= total_mb_index: |
228 done_looping = True | 210 done_looping = True |
232 self.hp.update({'finetuning_time':end_time-start_time,\ | 214 self.hp.update({'finetuning_time':end_time-start_time,\ |
233 'best_validation_error':best_validation_loss,\ | 215 'best_validation_error':best_validation_loss,\ |
234 'test_score':test_score, | 216 'test_score':test_score, |
235 'num_finetuning_epochs':epoch}) | 217 'num_finetuning_epochs':epoch}) |
236 | 218 |
219 if self.save_params: | |
220 save_params(self.classifier.all_params, "weights.dat") | |
221 | |
237 print(('Optimization complete with best validation score of %f %%,' | 222 print(('Optimization complete with best validation score of %f %%,' |
238 'with test performance %f %%') % | 223 'with test performance %f %%') % |
239 (best_validation_loss * 100., test_score*100.)) | 224 (best_validation_loss * 100., test_score*100.)) |
240 print ('The finetuning ran for %f minutes' % ((end_time-start_time)/60.)) | 225 print ('The finetuning ran for %f minutes' % ((end_time-start_time)/60.)) |
241 | 226 |
242 | 227 |
243 | 228 |
def save_params(all_params, filename):
    """Pickle the current values of the given parameters to `filename`.

    all_params -- iterable of parameter objects exposing a ``.value``
                  attribute (presumably Theano shared-variable wrappers;
                  only the raw values are serialized, not the wrappers)
    filename   -- path of the file to create or overwrite
    """
    import pickle

    # Snapshot the values before opening the file so the file handle is
    # held no longer than necessary.
    values = [p.value for p in all_params]

    with open(filename, 'wb') as f:
        # Named constant instead of the magic -1 for the highest protocol.
        pickle.dump(values, f, pickle.HIGHEST_PROTOCOL)
236 |