diff deep/stacked_dae/v_sylvain/sgd_optimization.py @ 361:b599886e3655

Ajout d'une fonctionnalite utile avec le programme voir_erreurs.py afin de voir les exemples ainsi que la prediction du modele donne dans le fichier config.py
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Thu, 22 Apr 2010 13:17:19 -0400
parents cfb79f9fd1a4
children f24b10e43a6f
line wrap: on
line diff
--- a/deep/stacked_dae/v_sylvain/sgd_optimization.py	Thu Apr 22 10:34:26 2010 -0400
+++ b/deep/stacked_dae/v_sylvain/sgd_optimization.py	Thu Apr 22 13:17:19 2010 -0400
@@ -377,6 +377,147 @@
         train_losses2 = [test_model(x,y) for x,y in iter2]
         train_score2 = numpy.mean(train_losses2)
         print "Training error is: " + str(train_score2)
+    
+    #To see the prediction of the model, the real answer and the image to judge    
+    def see_error(self, dataset):
+        import pylab
+        #The function to know the prediction
+        test_model = \
+            theano.function(
+                [self.classifier.x,self.classifier.y], self.classifier.logLayer.y_pred)
+        user = []
+        nb_total = 0     #total number of exemples seen
+        nb_error = 0   #total number of errors
+        for x,y in dataset.test(1):
+            nb_total += 1
+            pred = self.translate(test_model(x,y))
+            rep =  self.translate(y)
+            error = pred != rep
+            print 'prediction: ' + str(pred) +'\t answer: ' + str(rep) + '\t right: ' + str(not(error))
+            pylab.imshow(x.reshape((32,32)))
+            pylab.draw()
+            if error:
+                nb_error += 1
+                user.append(int(raw_input("1 = The error is normal, 0 = The error is not normal : ")))
+                print '\t\t character is hard to distinguish: ' + str(user[-1])
+            else:
+                time.sleep(3)
+        print '\n Over the '+str(nb_total)+' exemples, there is '+str(nb_error)+' errors. \nThe percentage of errors is'+ str(float(nb_error)/float(nb_total))
+        print 'The percentage of errors done by the model that an human will also do: ' + str(numpy.mean(user))
+        
+            
+            
+            
+    #To translate the numeric prediction in character if necessary     
+    def translate(self,y):
+        
+        if y <= 9:
+            return y[0]
+        elif y == 10:
+            return 'A'
+        elif y == 11:
+            return 'B'
+        elif y == 12:
+            return 'C'
+        elif y == 13:
+            return 'D'
+        elif y == 14:
+            return 'E'
+        elif y == 15:
+            return 'F'
+        elif y == 16:
+            return 'G'
+        elif y == 17:
+            return 'H'
+        elif y == 18:
+            return 'I'
+        elif y == 19:
+            return 'J'
+        elif y == 20:
+            return 'K'
+        elif y == 21:
+            return 'L'
+        elif y == 22:
+            return 'M'
+        elif y == 23:
+            return 'N'
+        elif y == 24:
+            return 'O'
+        elif y == 25:
+            return 'P'
+        elif y == 26:
+            return 'Q'
+        elif y == 27:
+            return 'R'
+        elif y == 28:
+            return 'S'
+        elif y == 28:
+            return 'T'
+        elif y == 30:
+            return 'U'
+        elif y == 31:
+            return 'V'
+        elif y == 32:
+            return 'W'
+        elif y == 33:
+            return 'X'
+        elif y == 34:
+            return 'Y'
+        elif y == 35:
+            return 'Z'
+            
+        elif y == 36:
+            return 'a'
+        elif y == 37:
+            return 'b'
+        elif y == 38:
+            return 'c'
+        elif y == 39:
+            return 'd'
+        elif y == 40:
+            return 'e'
+        elif y == 41:
+            return 'f'
+        elif y == 42:
+            return 'g'
+        elif y == 43:
+            return 'h'
+        elif y == 44:
+            return 'i'
+        elif y == 45:
+            return 'j'
+        elif y == 46:
+            return 'k'
+        elif y == 47:
+            return 'l'
+        elif y == 48:
+            return 'm'
+        elif y == 49:
+            return 'n'
+        elif y == 50:
+            return 'o'
+        elif y == 51:
+            return 'p'
+        elif y == 52:
+            return 'q'
+        elif y == 53:
+            return 'r'
+        elif y == 54:
+            return 's'
+        elif y == 55:
+            return 't'
+        elif y == 56:
+            return 'u'
+        elif y == 57:
+            return 'v'
+        elif y == 58:
+            return 'w'
+        elif y == 59:
+            return 'x'
+        elif y == 60:
+            return 'y'
+        elif y == 61:
+            return 'z'