Mercurial > ift6266
changeset 361:b599886e3655
Ajout d'une fonctionnalite utile avec le programme voir_erreurs.py afin de voir les exemples ainsi que la prediction du modele donne dans le fichier config.py
author | SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca> |
---|---|
date | Thu, 22 Apr 2010 13:17:19 -0400 |
parents | f37c0705649d |
children | 793e89fcdab7 |
files | deep/stacked_dae/v_sylvain/sgd_optimization.py |
diffstat | 1 files changed, 141 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- a/deep/stacked_dae/v_sylvain/sgd_optimization.py Thu Apr 22 10:34:26 2010 -0400 +++ b/deep/stacked_dae/v_sylvain/sgd_optimization.py Thu Apr 22 13:17:19 2010 -0400 @@ -377,6 +377,147 @@ train_losses2 = [test_model(x,y) for x,y in iter2] train_score2 = numpy.mean(train_losses2) print "Training error is: " + str(train_score2) + + #To see the prediction of the model, the real answer and the image to judge + def see_error(self, dataset): + import pylab + #The function to know the prediction + test_model = \ + theano.function( + [self.classifier.x,self.classifier.y], self.classifier.logLayer.y_pred) + user = [] + nb_total = 0 #total number of exemples seen + nb_error = 0 #total number of errors + for x,y in dataset.test(1): + nb_total += 1 + pred = self.translate(test_model(x,y)) + rep = self.translate(y) + error = pred != rep + print 'prediction: ' + str(pred) +'\t answer: ' + str(rep) + '\t right: ' + str(not(error)) + pylab.imshow(x.reshape((32,32))) + pylab.draw() + if error: + nb_error += 1 + user.append(int(raw_input("1 = The error is normal, 0 = The error is not normal : "))) + print '\t\t character is hard to distinguish: ' + str(user[-1]) + else: + time.sleep(3) + print '\n Over the '+str(nb_total)+' exemples, there is '+str(nb_error)+' errors. \nThe percentage of errors is'+ str(float(nb_error)/float(nb_total)) + print 'The percentage of errors done by the model that an human will also do: ' + str(numpy.mean(user)) + + + + + #To translate the numeric prediction in character if necessary + def translate(self,y): + + if y <= 9: + return y[0] + elif y == 10: + return 'A' + elif y == 11: + return 'B' + elif y == 12: + return 'C' + elif y == 13: + return 'D' + elif y == 14: + return 'E' + elif y == 15: + return 'F' + elif y == 16: + return 'G' + elif y == 17: + return 'H' + elif y == 18: + return 'I' + elif y == 19: + return 'J' + elif y == 20: + return 'K' + elif y == 21: + return 'L' + elif y == 22: + return 'M' + elif y == 23: + return 'N' + elif y == 24: + return 'O' + elif y == 25: + return 'P' + elif y == 26: + return 'Q' + elif y == 27: + return 'R' + elif y == 28: + return 'S' + elif y == 28: + return 'T' + elif y == 30: + return 'U' + elif y == 31: + return 'V' + elif y == 32: + return 'W' + elif y == 33: + return 'X' + elif y == 34: + return 'Y' + elif y == 35: + return 'Z' + + elif y == 36: + return 'a' + elif y == 37: + return 'b' + elif y == 38: + return 'c' + elif y == 39: + return 'd' + elif y == 40: + return 'e' + elif y == 41: + return 'f' + elif y == 42: + return 'g' + elif y == 43: + return 'h' + elif y == 44: + return 'i' + elif y == 45: + return 'j' + elif y == 46: + return 'k' + elif y == 47: + return 'l' + elif y == 48: + return 'm' + elif y == 49: + return 'n' + elif y == 50: + return 'o' + elif y == 51: + return 'p' + elif y == 52: + return 'q' + elif y == 53: + return 'r' + elif y == 54: + return 's' + elif y == 55: + return 't' + elif y == 56: + return 'u' + elif y == 57: + return 'v' + elif y == 58: + return 'w' + elif y == 59: + return 'x' + elif y == 60: + return 'y' + elif y == 61: + return 'z'