comparison baseline/mlp/mlp_nist.py @ 414:3dba84c0fbc1

saving test score from best validation score in db now
author xaviermuller
date Thu, 29 Apr 2010 17:04:12 -0400
parents 195f95c3d461
children 868f82777839
--- baseline/mlp/mlp_nist.py (413:f2dd75248483)
+++ baseline/mlp/mlp_nist.py (414:3dba84c0fbc1)
@@ -192,10 +192,25 @@
     min_error_count=0.0
     min_exemple_count=0.0
 
     maj_error_count=0.0
     maj_exemple_count=0.0
+
+    vtotal_error_count=0.0
+    vtotal_exemple_count=0.0
+
+    vnb_error_count=0.0
+    vnb_exemple_count=0.0
+
+    vchar_error_count=0.0
+    vchar_exemple_count=0.0
+
+    vmin_error_count=0.0
+    vmin_exemple_count=0.0
+
+    vmaj_error_count=0.0
+    vmaj_exemple_count=0.0
 
 
 
     if data_set==0:
         dataset=datasets.nist_all()
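
Note: the hunk above mirrors every test-set counter (total, nb, char, min, maj) with a v-prefixed validation twin. A minimal sketch of the same bookkeeping with a shared code path for any split; the categories_of helper and the class layout it encodes (digits 0-9, uppercase 10-35, lowercase 36-61, as implied by the loops below) are illustrative assumptions, not part of the changeset:

    from collections import defaultdict

    # Hypothetical helper: map a NIST class index to the categories it
    # belongs to. Class layout assumed from the surrounding code:
    # 0-9 digits, 10-35 uppercase, 36-61 lowercase.
    def categories_of(wanted_class):
        cats = ['total']
        if wanted_class < 10:
            cats.append('nb')
        else:
            cats.append('char')
            cats.append('maj' if wanted_class < 36 else 'min')
        return cats

    def count_errors(pairs):
        """pairs: iterable of (predicted_class, wanted_class)."""
        examples = defaultdict(float)
        errors = defaultdict(float)
        for predicted, wanted in pairs:
            for cat in categories_of(wanted):
                examples[cat] += 1
                if predicted != wanted:
                    errors[cat] += 1
        # percent error per category, as in the print statements below
        return {c: 100.0 * errors[c] / examples[c] for c in examples}

    # count_errors([(3, 3), (12, 38)])
    # -> {'total': 50.0, 'nb': 0.0, 'char': 100.0, 'min': 100.0}

This sketch collapses the per-category restricted argmax and the case-folding of the real loop into a single top-1 comparison; the changeset's loop below is more permissive for the char category.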
@@ -254,16 +269,77 @@
             predicted_class=numpy.argmax(a1_out[36:62])+36
             if(predicted_class!=wanted_class):
                 min_error_count = min_error_count +1
 
 
+
+    vtest_score=0
+    vtemp=0
+    for xt,yt in dataset.valid(1):
+
+        vtotal_exemple_count = vtotal_exemple_count +1
+        #get activation for layer 1
+        a0=numpy.dot(numpy.transpose(W1),numpy.transpose(xt[0])) + b1
+        #add non linear function to layer 1 activation
+        a0_out=numpy.tanh(a0)
+
+        #get activation for output layer
+        a1= numpy.dot(numpy.transpose(W2),a0_out) + b2
+        #add non linear function for output activation (softmax)
+        a1_exp = numpy.exp(a1)
+        sum_a1=numpy.sum(a1_exp)
+        a1_out=a1_exp/sum_a1
+
+        predicted_class=numpy.argmax(a1_out)
+        wanted_class=yt[0]
+        if(predicted_class!=wanted_class):
+            vtotal_error_count = vtotal_error_count +1
+
+        #treat digit error
+        if(wanted_class<10):
+            vnb_exemple_count=vnb_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[0:10])
+            if(predicted_class!=wanted_class):
+                vnb_error_count = vnb_error_count +1
+
+        if(wanted_class>9):
+            vchar_exemple_count=vchar_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[10:62])+10
+            if((predicted_class!=wanted_class) and ((predicted_class+26)!=wanted_class) and ((predicted_class-26)!=wanted_class)):
+                vchar_error_count = vchar_error_count +1
+
+        #majuscule (uppercase letters, classes 10-35)
+        if(wanted_class>9 and wanted_class<36):
+            vmaj_exemple_count=vmaj_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[10:36])+10
+            if(predicted_class!=wanted_class):
+                vmaj_error_count = vmaj_error_count +1
+        #minuscule (lowercase letters, classes 36-61)
+        if(wanted_class>35):
+            vmin_exemple_count=vmin_exemple_count + 1
+            predicted_class=numpy.argmax(a1_out[36:62])+36
+            if(predicted_class!=wanted_class):
+                vmin_error_count = vmin_error_count +1
+
 
     print (('total error = %f') % ((total_error_count/total_exemple_count)*100.0))
     print (('number error = %f') % ((nb_error_count/nb_exemple_count)*100.0))
     print (('char error = %f') % ((char_error_count/char_exemple_count)*100.0))
     print (('min error = %f') % ((min_error_count/min_exemple_count)*100.0))
     print (('maj error = %f') % ((maj_error_count/maj_exemple_count)*100.0))
+
+    print (('valid total error = %f') % ((vtotal_error_count/vtotal_exemple_count)*100.0))
+    print (('valid number error = %f') % ((vnb_error_count/vnb_exemple_count)*100.0))
+    print (('valid char error = %f') % ((vchar_error_count/vchar_exemple_count)*100.0))
+    print (('valid min error = %f') % ((vmin_error_count/vmin_exemple_count)*100.0))
+    print (('valid maj error = %f') % ((vmaj_error_count/vmaj_exemple_count)*100.0))
+
+    print ((' num total = %d,%d') % (total_exemple_count,total_error_count))
+    print ((' num nb = %d,%d') % (nb_exemple_count,nb_error_count))
+    print ((' num min = %d,%d') % (min_exemple_count,min_error_count))
+    print ((' num maj = %d,%d') % (maj_exemple_count,maj_error_count))
+    print ((' num char = %d,%d') % (char_exemple_count,char_error_count))
     return (total_error_count/total_exemple_count)*100.0
 
 
 
 
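
Two notes on the inserted validation loop. First, the char branch accepts a prediction 26 classes away from the target, i.e. a letter recognized in the wrong case does not count as a character error, which is why vchar_error_count can be lower than the sum of the case-specific errors. Second, the loop re-implements the MLP forward pass inline (tanh hidden layer, softmax output), duplicating the test loop above it. A sketch of that computation factored into one NumPy function, assuming the weight shapes implied by the numpy.dot(numpy.transpose(W1), ...) calls (W1 is inputs-by-hidden, W2 is hidden-by-outputs); the max-shift in the softmax is a standard numerical-stability tweak that the inline version omits:

    import numpy

    def mlp_forward(x, W1, b1, W2, b2):
        """Forward pass matching the inline code: tanh hidden layer, softmax output.

        x: 1-D input vector; W1: (n_in, n_hidden); W2: (n_hidden, n_out).
        Returns the vector of class probabilities.
        """
        # hidden layer pre-activation, then tanh nonlinearity
        h = numpy.tanh(numpy.dot(W1.T, x) + b1)
        # output pre-activation, then softmax (shifted by the max for stability)
        a = numpy.dot(W2.T, h) + b2
        e = numpy.exp(a - numpy.max(a))
        return e / numpy.sum(e)

With such a helper, both the test loop and the new validation loop would reduce to a1_out = mlp_forward(xt[0], W1, b1, W2, b2) followed by the counting logic.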
@@ -290,10 +366,11 @@
     configuration = [learning_rate,nb_max_exemples,nb_hidden,adaptive_lr]
 
     #save initial learning rate if classical adaptive lr is used
     initial_lr=learning_rate
     max_div_count=1000
+    optimal_test_error=0
 
 
     total_validation_error_list = []
     total_train_error_list = []
     learning_rate_list=[]
@@ -480,10 +557,11 @@
                    print(('epoch %i, minibatch %i, test error of best '
                           'model %f %%') %
                          (epoch, minibatch_index+1,
                           test_score*100.))
                    sys.stdout.flush()
+                    optimal_test_error=test_score
 
                # if the validation error is going up, we are overfitting (or oscillating)
                # check if we are allowed to continue and if we will adjust the learning rate
                elif this_validation_loss >= best_validation_loss:
 
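
This hunk is the point of the changeset: optimal_test_error (initialized at new line 371 above) snapshots test_score only inside the best-validation branch, so the value returned at the end belongs to the model that minimized validation loss rather than to whatever model was evaluated last. A minimal, runnable sketch of the pattern; validate() and evaluate_test() are hypothetical stand-ins for the real minibatch evaluation loops:

    import random

    def validate():        # stand-in for the real validation-set evaluation
        return random.random()

    def evaluate_test():   # stand-in for the real test-set evaluation
        return random.random()

    best_validation_loss = float('inf')
    optimal_test_error = 0.0

    for epoch in range(10):
        this_validation_loss = validate()
        if this_validation_loss < best_validation_loss:
            best_validation_loss = this_validation_loss
            # snapshot the test score at the best-validation point;
            # this, not the final test score, is what gets reported
            optimal_test_error = evaluate_test()

    print(optimal_test_error)  # test error of the best-validation model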
@@ -549,11 +627,11 @@
     #save the model and the weights
     numpy.savez('model.npy', config=configuration, W1=classifier.W1.value,W2=classifier.W2.value, b1=classifier.b1.value,b2=classifier.b2.value)
     numpy.savez('results.npy',config=configuration,total_train_error_list=total_train_error_list,total_validation_error_list=total_validation_error_list,\
                 learning_rate_list=learning_rate_list, divergence_flag_list=divergence_flag_list)
 
-    return (best_training_error*100.0,best_validation_loss * 100.,test_score*100.,best_iter*batch_size,(end_time-start_time)/60)
+    return (best_training_error*100.0,best_validation_loss * 100.,optimal_test_error*100.,best_iter*batch_size,(end_time-start_time)/60)
 
 
 if __name__ == '__main__':
     mlp_full_mnist()
 
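
Since the model and result curves are persisted with numpy.savez, they can be reloaded for later analysis. One detail worth knowing: numpy.savez appends '.npz' when the filename does not already end with it, so 'model.npy' actually lands on disk as 'model.npy.npz' (and likewise for 'results.npy'). A sketch of reading the arrays back, using the key names from the savez calls above:

    import numpy

    # savez appended '.npz' to the names used in the training script
    model = numpy.load('model.npy.npz')
    W1, b1 = model['W1'], model['b1']
    W2, b2 = model['W2'], model['b2']

    results = numpy.load('results.npy.npz')
    validation_curve = results['total_validation_error_list']
    print(validation_curve[-1])  # last recorded validation error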