Mercurial > ift6266
comparison baseline/mlp/mlp_nist.py @ 414:3dba84c0fbc1
saving test score from best validation score in db now
author | xaviermuller |
---|---|
date | Thu, 29 Apr 2010 17:04:12 -0400 |
parents | 195f95c3d461 |
children | 868f82777839 |
comparison
equal
deleted
inserted
replaced
413:f2dd75248483 | 414:3dba84c0fbc1 |
---|---|
192 min_error_count=0.0 | 192 min_error_count=0.0 |
193 min_exemple_count=0.0 | 193 min_exemple_count=0.0 |
194 | 194 |
195 maj_error_count=0.0 | 195 maj_error_count=0.0 |
196 maj_exemple_count=0.0 | 196 maj_exemple_count=0.0 |
197 | |
198 vtotal_error_count=0.0 | |
199 vtotal_exemple_count=0.0 | |
200 | |
201 vnb_error_count=0.0 | |
202 vnb_exemple_count=0.0 | |
203 | |
204 vchar_error_count=0.0 | |
205 vchar_exemple_count=0.0 | |
206 | |
207 vmin_error_count=0.0 | |
208 vmin_exemple_count=0.0 | |
209 | |
210 vmaj_error_count=0.0 | |
211 vmaj_exemple_count=0.0 | |
197 | 212 |
198 | 213 |
199 | 214 |
200 if data_set==0: | 215 if data_set==0: |
201 dataset=datasets.nist_all() | 216 dataset=datasets.nist_all() |
254 predicted_class=numpy.argmax(a1_out[36:62])+36 | 269 predicted_class=numpy.argmax(a1_out[36:62])+36 |
255 if(predicted_class!=wanted_class): | 270 if(predicted_class!=wanted_class): |
256 min_error_count = min_error_count +1 | 271 min_error_count = min_error_count +1 |
257 | 272 |
258 | 273 |
274 | |
275 vtest_score=0 | |
276 vtemp=0 | |
277 for xt,yt in dataset.valid(1): | |
278 | |
279 vtotal_exemple_count = vtotal_exemple_count +1 | |
280 #get activation for layer 1 | |
281 a0=numpy.dot(numpy.transpose(W1),numpy.transpose(xt[0])) + b1 | |
282 #add non linear function to layer 1 activation | |
283 a0_out=numpy.tanh(a0) | |
284 | |
285 #get activation for output layer | |
286 a1= numpy.dot(numpy.transpose(W2),a0_out) + b2 | |
287 #add non linear function for output activation (softmax) | |
288 a1_exp = numpy.exp(a1) | |
289 sum_a1=numpy.sum(a1_exp) | |
290 a1_out=a1_exp/sum_a1 | |
291 | |
292 predicted_class=numpy.argmax(a1_out) | |
293 wanted_class=yt[0] | |
294 if(predicted_class!=wanted_class): | |
295 vtotal_error_count = vtotal_error_count +1 | |
296 | |
297 #treat digit error | |
298 if(wanted_class<10): | |
299 vnb_exemple_count=vnb_exemple_count + 1 | |
300 predicted_class=numpy.argmax(a1_out[0:10]) | |
301 if(predicted_class!=wanted_class): | |
302 vnb_error_count = vnb_error_count +1 | |
303 | |
304 if(wanted_class>9): | |
305 vchar_exemple_count=vchar_exemple_count + 1 | |
306 predicted_class=numpy.argmax(a1_out[10:62])+10 | |
307 if((predicted_class!=wanted_class) and ((predicted_class+26)!=wanted_class) and ((predicted_class-26)!=wanted_class)): | |
308 vchar_error_count = vchar_error_count +1 | |
309 | |
310 #minuscule | |
311 if(wanted_class>9 and wanted_class<36): | |
312 vmaj_exemple_count=vmaj_exemple_count + 1 | |
313 predicted_class=numpy.argmax(a1_out[10:35])+10 | |
314 if(predicted_class!=wanted_class): | |
315 vmaj_error_count = vmaj_error_count +1 | |
316 #majuscule | |
317 if(wanted_class>35): | |
318 vmin_exemple_count=vmin_exemple_count + 1 | |
319 predicted_class=numpy.argmax(a1_out[36:62])+36 | |
320 if(predicted_class!=wanted_class): | |
321 vmin_error_count = vmin_error_count +1 | |
322 | |
259 | 323 |
260 print (('total error = %f') % ((total_error_count/total_exemple_count)*100.0)) | 324 print (('total error = %f') % ((total_error_count/total_exemple_count)*100.0)) |
261 print (('number error = %f') % ((nb_error_count/nb_exemple_count)*100.0)) | 325 print (('number error = %f') % ((nb_error_count/nb_exemple_count)*100.0)) |
262 print (('char error = %f') % ((char_error_count/char_exemple_count)*100.0)) | 326 print (('char error = %f') % ((char_error_count/char_exemple_count)*100.0)) |
263 print (('min error = %f') % ((min_error_count/min_exemple_count)*100.0)) | 327 print (('min error = %f') % ((min_error_count/min_exemple_count)*100.0)) |
264 print (('maj error = %f') % ((maj_error_count/maj_exemple_count)*100.0)) | 328 print (('maj error = %f') % ((maj_error_count/maj_exemple_count)*100.0)) |
329 | |
330 print (('valid total error = %f') % ((vtotal_error_count/vtotal_exemple_count)*100.0)) | |
331 print (('valid number error = %f') % ((vnb_error_count/vnb_exemple_count)*100.0)) | |
332 print (('valid char error = %f') % ((vchar_error_count/vchar_exemple_count)*100.0)) | |
333 print (('valid min error = %f') % ((vmin_error_count/vmin_exemple_count)*100.0)) | |
334 print (('valid maj error = %f') % ((vmaj_error_count/vmaj_exemple_count)*100.0)) | |
335 | |
336 print ((' num total = %d,%d') % (total_exemple_count,total_error_count)) | |
337 print ((' num nb = %d,%d') % (nb_exemple_count,nb_error_count)) | |
338 print ((' num min = %d,%d') % (min_exemple_count,min_error_count)) | |
339 print ((' num maj = %d,%d') % (maj_exemple_count,maj_error_count)) | |
340 print ((' num char = %d,%d') % (char_exemple_count,char_error_count)) | |
265 return (total_error_count/total_exemple_count)*100.0 | 341 return (total_error_count/total_exemple_count)*100.0 |
266 | 342 |
267 | 343 |
268 | 344 |
269 | 345 |
290 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr] | 366 configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr] |
291 | 367 |
292 #save initial learning rate if classical adaptive lr is used | 368 #save initial learning rate if classical adaptive lr is used |
293 initial_lr=learning_rate | 369 initial_lr=learning_rate |
294 max_div_count=1000 | 370 max_div_count=1000 |
371 optimal_test_error=0 | |
295 | 372 |
296 | 373 |
297 total_validation_error_list = [] | 374 total_validation_error_list = [] |
298 total_train_error_list = [] | 375 total_train_error_list = [] |
299 learning_rate_list=[] | 376 learning_rate_list=[] |
480 print(('epoch %i, minibatch %i, test error of best ' | 557 print(('epoch %i, minibatch %i, test error of best ' |
481 'model %f %%') % | 558 'model %f %%') % |
482 (epoch, minibatch_index+1, | 559 (epoch, minibatch_index+1, |
483 test_score*100.)) | 560 test_score*100.)) |
484 sys.stdout.flush() | 561 sys.stdout.flush() |
562 optimal_test_error=test_score | |
485 | 563 |
486 # if the validation error is going up, we are overfitting (or oscillating) | 564 # if the validation error is going up, we are overfitting (or oscillating) |
487 # check if we are allowed to continue and if we will adjust the learning rate | 565 # check if we are allowed to continue and if we will adjust the learning rate |
488 elif this_validation_loss >= best_validation_loss: | 566 elif this_validation_loss >= best_validation_loss: |
489 | 567 |
549 #save the model and the weights | 627 #save the model and the weights |
550 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) | 628 numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value) |
551 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ | 629 numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\ |
552 learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list) | 630 learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list) |
553 | 631 |
554 return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60) | 632 return (best_training_error*100.0,best_validation_loss * 100.,optimal_test_error*100.,best_iter*batch_size,(end_time-start_time)/60) |
555 | 633 |
556 | 634 |
557 if __name__ == '__main__': | 635 if __name__ == '__main__': |
558 mlp_full_mnist() | 636 mlp_full_mnist() |
559 | 637 |