ift6266: comparison of baseline/mlp/mlp_nist.py @ 323:7a7615f940e8
finished code clean up and testing
author | xaviermuller |
---|---|
date | Thu, 08 Apr 2010 11:01:55 -0400 |
parents | 743907366476 |
children | 1763c64030d1 |
322:743907366476 | 323:7a7615f940e8 |
---|---|
180 | 180 |
181 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr] | 181 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr] |
182 | 182 |
183 #save initial learning rate if classical adaptive lr is used | 183 #save initial learning rate if classical adaptive lr is used |
184 initial_lr=learning_rate | 184 initial_lr=learning_rate |
185 max_div_count=3 | |
186 | |
185 | 187 |
186 total_validation_error_list = [] | 188 total_validation_error_list = [] |
187 total_train_error_list = [] | 189 total_train_error_list = [] |
188 learning_rate_list=[] | 190 learning_rate_list=[] |
189 best_training_error=float('inf'); | 191 best_training_error=float('inf'); |
192 divergence_flag_list=[] | |
190 | 193 |
191 if data_set==0: | 194 if data_set==0: |
192 dataset=datasets.nist_all() | 195 dataset=datasets.nist_all() |
196 elif data_set==1: | |
197 dataset=datasets.nist_P07() | |
193 | 198 |
194 | 199 |
195 | 200 |
196 | 201 |
197 ishape = (32,32) # this is the size of NIST images | 202 ishape = (32,32) # this is the size of NIST images |
248 | 253 |
249 | 254 |
250 | 255 |
251 | 256 |
252 #conditions for stopping the adaptation: | 257 #conditions for stopping the adaptation: |
253 #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size) | 258 #1) we have reached nb_max_exemples (this is rounded up to be a multiple of the train size so we always do at least 1 epoch) |
254 #2) validation error is going up twice in a row(probable overfitting) | 259 #2) validation error is going up twice in a row(probable overfitting) |
255 | 260 |
256 # This means we no longer stop on slow convergence as low learning rates stopped | 261 # This means we no longer stop on slow convergence as low learning rates stopped |
257 # too fast. | 262 # too fast; instead we wait for the validation error to go up 3 times in a row |
258 | 263 # We save the curve of the validation error so we can always go back to check on it |
259 #approximate number of samples in the training set | 264 # and we save the absolute best model anyway, so we might as well explore |
265 # a bit when diverging | |
266 | |
267 #approximate number of samples in the nist training set | |
260 #this is just to have a validation frequency | 268 #this is just to have a validation frequency |
261 #roughly proportional to the training set | 269 #roughly proportional to the original nist training set |
262 n_minibatches = 650000/batch_size | 270 n_minibatches = 650000/batch_size |
263 | 271 |
264 | 272 |
265 patience =nb_max_exemples/batch_size #in units of minibatch | 273 patience =2*nb_max_exemples/batch_size #in units of minibatch |
266 patience_increase = 2 # wait this much longer when a new best is | |
267 # found | |
268 improvement_threshold = 0.995 # a relative improvement of this much is | |
269 # considered significant | |
270 validation_frequency = n_minibatches/4 | 274 validation_frequency = n_minibatches/4 |
271 | 275 |
272 | 276 |
273 | 277 |
274 | 278 |
279 start_time = time.clock() | 283 start_time = time.clock() |
280 time_n=0 #in unit of exemples | 284 time_n=0 #in unit of exemples |
281 minibatch_index=0 | 285 minibatch_index=0 |
282 epoch=0 | 286 epoch=0 |
283 temp=0 | 287 temp=0 |
288 divergence_flag=0 | |
284 | 289 |
285 | 290 |
286 | 291 |
287 if verbose == 1: | 292 if verbose == 1: |
288 print 'looking at most at %i exemples' %nb_max_exemples | 293 print 'starting training' |
289 while(minibatch_index*batch_size<nb_max_exemples): | 294 while(minibatch_index*batch_size<nb_max_exemples): |
290 | 295 |
291 for x, y in dataset.train(batch_size): | 296 for x, y in dataset.train(batch_size): |
292 | 297 |
293 | 298 #if we are using the classic learning rate decay, adjust it before training on the current mini-batch |
294 minibatch_index = minibatch_index + 1 | |
295 if adaptive_lr==2: | 299 if adaptive_lr==2: |
296 classifier.lr.value = tau*initial_lr/(tau+time_n) | 300 classifier.lr.value = tau*initial_lr/(tau+time_n) |
297 | 301 |
298 | 302 |
299 #train model | 303 #train model |
300 cost_ij = train_model(x,y) | 304 cost_ij = train_model(x,y) |
301 | 305 |
302 if (minibatch_index+1) % validation_frequency == 0: | 306 if (minibatch_index+1) % validation_frequency == 0: |
303 | |
304 #save the current learning rate | 307 #save the current learning rate |
305 learning_rate_list.append(classifier.lr.value) | 308 learning_rate_list.append(classifier.lr.value) |
309 divergence_flag_list.append(divergence_flag) | |
306 | 310 |
307 # compute the validation error | 311 # compute the validation error |
308 this_validation_loss = 0. | 312 this_validation_loss = 0. |
309 temp=0 | 313 temp=0 |
310 for xv,yv in dataset.valid(1): | 314 for xv,yv in dataset.valid(1): |
311 # sum up the errors for each minibatch | 315 # sum up the errors for each minibatch |
312 axxa=test_model(xv,yv) | 316 this_validation_loss += test_model(xv,yv) |
313 this_validation_loss += axxa | |
314 temp=temp+1 | 317 temp=temp+1 |
315 # get the average by dividing with the number of minibatches | 318 # get the average by dividing with the number of minibatches |
316 this_validation_loss /= temp | 319 this_validation_loss /= temp |
317 #save the validation loss | 320 #save the validation loss |
318 total_validation_error_list.append(this_validation_loss) | 321 total_validation_error_list.append(this_validation_loss) |
324 # if we got the best validation score until now | 327 # if we got the best validation score until now |
325 if this_validation_loss < best_validation_loss: | 328 if this_validation_loss < best_validation_loss: |
326 # save best validation score and iteration number | 329 # save best validation score and iteration number |
327 best_validation_loss = this_validation_loss | 330 best_validation_loss = this_validation_loss |
328 best_iter = minibatch_index | 331 best_iter = minibatch_index |
329 # reset patience if we are going down again | 332 #reset divergence flag |
330 # so we continue exploring | 333 divergence_flag=0 |
331 patience=nb_max_exemples/batch_size | 334 |
335 #save the best model. Overwrite the current saved best model so | |
336 #we only keep the best | |
337 numpy.savez('best_model.npy', config=configuration, W1=classifier.W1.value, W2=classifier.W2.value, b1=classifier.b1.value,\ | |
338 b2=classifier.b2.value, minibatch_index=minibatch_index) | |
339 | |
332 # test it on the test set | 340 # test it on the test set |
333 test_score = 0. | 341 test_score = 0. |
334 temp =0 | 342 temp =0 |
335 for xt,yt in dataset.test(batch_size): | 343 for xt,yt in dataset.test(batch_size): |
336 test_score += test_model(xt,yt) | 344 test_score += test_model(xt,yt) |
341 'model %f %%') % | 349 'model %f %%') % |
342 (epoch, minibatch_index+1, | 350 (epoch, minibatch_index+1, |
343 test_score*100.)) | 351 test_score*100.)) |
344 | 352 |
345 # if the validation error is going up, we are overfitting (or oscillating) | 353 # if the validation error is going up, we are overfitting (or oscillating) |
346 # stop converging but run at least to next validation | 354 # check if we are allowed to continue and if we will adjust the learning rate |
347 # to check overfitting or oscillation | |
348 # the saved weights of the model will be a bit off in that case | |
349 elif this_validation_loss >= best_validation_loss: | 355 elif this_validation_loss >= best_validation_loss: |
356 | |
357 | |
358 # In non-classic learning rate decay, we modify the weight only when | |
359 # validation error is going up | |
360 if adaptive_lr==1: | |
361 classifier.lr.value=classifier.lr.value*lr_t2_factor | |
362 | |
363 | |
364 #count how many validation checks in a row the error has gone up | |
365 #if it goes up max_div_count times in a row, we stop immediately | |
366 divergence_flag = divergence_flag +1 | |
367 | |
368 | |
350 #calculate the test error at this point and exit | 369 #calculate the test error at this point and exit |
351 # test it on the test set | 370 # test it on the test set |
352 # however, if adaptive_lr is true, try reducing the lr to | |
353 # get us out of an oscillation | |
354 if adaptive_lr==1: | |
355 classifier.lr.value=classifier.lr.value*lr_t2_factor | |
356 | |
357 test_score = 0. | 371 test_score = 0. |
358 #cap the patience so we are allowed one more validation error | |
359 #calculation before aborting | |
360 patience = minibatch_index+validation_frequency+1 | |
361 temp=0 | 372 temp=0 |
362 for xt,yt in dataset.test(batch_size): | 373 for xt,yt in dataset.test(batch_size): |
363 test_score += test_model(xt,yt) | 374 test_score += test_model(xt,yt) |
364 temp=temp+1 | 375 temp=temp+1 |
365 test_score /= temp | 376 test_score /= temp |
370 (epoch, minibatch_index+1, | 381 (epoch, minibatch_index+1, |
371 test_score*100.)) | 382 test_score*100.)) |
372 | 383 |
373 | 384 |
374 | 385 |
375 | 386 # check early stop condition |
376 if minibatch_index>patience: | 387 if divergence_flag==max_div_count: |
377 print 'we have diverged' | 388 minibatch_index=nb_max_exemples |
389 print 'we have diverged, early stopping kicks in' | |
378 break | 390 break |
391 | |
392 #check if we have seen enough exemples | |
393 #force one epoch at least | |
394 if epoch>0 and minibatch_index*batch_size>nb_max_exemples: | |
395 break | |
379 | 396 |
380 | 397 |
381 time_n= time_n + batch_size | 398 time_n= time_n + batch_size |
399 minibatch_index = minibatch_index + 1 | |
400 | |
401 # we have finished looping through the training set | |
382 epoch = epoch+1 | 402 epoch = epoch+1 |
383 end_time = time.clock() | 403 end_time = time.clock() |
384 if verbose == 1: | 404 if verbose == 1: |
385 print(('Optimization complete. Best validation score of %f %% ' | 405 print(('Optimization complete. Best validation score of %f %% ' |
386 'obtained at iteration %i, with test performance %f %%') % | 406 'obtained at iteration %i, with test performance %f %%') % |
389 print minibatch_index | 409 print minibatch_index |
390 | 410 |
391 #save the model and the weights | 411 #save the model and the weights |
392 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) | 412 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) |
393 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ | 413 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ |
394 learning_rate_list=learning_rate_list) | 414 learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list) |
395 | 415 |
396 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60) | 416 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60) |
397 | 417 |
398 | 418 |
399 if __name__ == '__main__': | 419 if __name__ == '__main__': |
408 verbose = state.verbose,\ | 428 verbose = state.verbose,\ |
409 train_data = state.train_data,\ | 429 train_data = state.train_data,\ |
410 train_labels = state.train_labels,\ | 430 train_labels = state.train_labels,\ |
411 test_data = state.test_data,\ | 431 test_data = state.test_data,\ |
412 test_labels = state.test_labels,\ | 432 test_labels = state.test_labels,\ |
413 lr_t2_factor=state.lr_t2_factor) | 433 lr_t2_factor=state.lr_t2_factor,\ |
434 data_set=state.data_set) | |
414 state.train_error=train_error | 435 state.train_error=train_error |
415 state.validation_error=validation_error | 436 state.validation_error=validation_error |
416 state.test_error=test_error | 437 state.test_error=test_error |
417 state.nb_exemples=nb_exemples | 438 state.nb_exemples=nb_exemples |
418 state.time=time | 439 state.time=time |
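
The changeset keeps two learning-rate schemes: when adaptive_lr==2 the rate is set before every mini-batch with the classic 1/t-style decay classifier.lr.value = tau*initial_lr/(tau+time_n), and when adaptive_lr==1 the rate is multiplied by lr_t2_factor each time the validation error goes up. The sketch below only illustrates those two rules; the function names and example values are assumptions, not code from mlp_nist.py.

```python
# Illustrative sketch of the two learning-rate rules used in the diff above.
# Function names and the example values are hypothetical.

def classic_decay(initial_lr, tau, time_n):
    # adaptive_lr == 2: decay toward 0 as more examples (time_n) are seen
    return tau * initial_lr / (tau + float(time_n))

def shrink_on_validation_increase(current_lr, lr_t2_factor):
    # adaptive_lr == 1: shrink the rate by lr_t2_factor (< 1) whenever
    # the validation error goes up
    return current_lr * lr_t2_factor

if __name__ == '__main__':
    for time_n in (0, 65000, 650000):
        print(classic_decay(0.1, 100000, time_n))    # 0.1, ~0.061, ~0.013
    print(shrink_on_validation_increase(0.1, 0.5))   # 0.05
```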
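The early-stopping change replaces the old patience counter with a divergence counter: divergence_flag is reset whenever the validation error improves and incremented otherwise, and training stops once it reaches max_div_count (3). On every improvement the current best model is written to best_model.npy with numpy.savez, overwriting the previous checkpoint. The loop below is a minimal, self-contained sketch of that control flow; train_one_batch, validate and params are hypothetical stand-ins, not the actual mlp_nist.py training loop.

```python
# Minimal sketch of divergence-counter early stopping with best-model
# checkpointing. train_one_batch, validate and params are stand-ins.
import numpy

def fit(train_one_batch, validate, params,
        max_div_count=3, validation_frequency=100, max_batches=10000):
    best_validation_loss = float('inf')
    divergence_flag = 0                      # consecutive validation increases
    for minibatch_index in range(max_batches):
        train_one_batch(minibatch_index)
        if (minibatch_index + 1) % validation_frequency == 0:
            this_validation_loss = validate()
            if this_validation_loss < best_validation_loss:
                best_validation_loss = this_validation_loss
                divergence_flag = 0          # improving again: reset the counter
                # overwrite the previous checkpoint so only the best model is kept
                numpy.savez('best_model.npy', **params)
            else:
                divergence_flag += 1         # validation error went up again
                if divergence_flag == max_div_count:
                    break                    # diverged max_div_count times in a row
    return best_validation_loss
```

numpy.savez appends a .npz extension when the given name does not already end in one, so the checkpoint written here can be read back with numpy.load('best_model.npy.npz').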