changeset 439:5ca2936f2062

Bug fix in per-class a priori error : should be lower than before for both upper and lower
author Guillaume Sicard <guitch21@gmail.com>
date Mon, 03 May 2010 09:42:36 -0400
parents a6d339033d03
children 89258bb41e4c
files deep/stacked_dae/v_sylvain/nist_apriori_error.py
diffstat 1 files changed, 14 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/deep/stacked_dae/v_sylvain/nist_apriori_error.py	Mon May 03 07:46:18 2010 -0400
+++ b/deep/stacked_dae/v_sylvain/nist_apriori_error.py	Mon May 03 09:42:36 2010 -0400
@@ -1,3 +1,4 @@
+# -*- coding: utf-8 -*-
 __docformat__ = 'restructedtext en'
 
 import pdb
@@ -65,6 +66,8 @@
     
     total_error_count=0
     total_exemple_count=0
+    total_error_count_wap=0
+
     if part == 0:
         iter = dataset.train(1)
     if part == 1:
@@ -111,8 +114,8 @@
 
         #get grouped based error
         #with a priori
-        if(y>9 and y<35):
-            predicted_class=numpy.argmax(out[0,10:35])+10
+        if(y>9 and y<36):
+            predicted_class=numpy.argmax(out[0,10:36])+10
             if(predicted_class!=y):
                 total_error_count+=1
                 
@@ -120,14 +123,20 @@
             predicted_class=numpy.argmax(out[0,0:10])
             if(predicted_class!=y):
                 total_error_count+=1
-        if(y>34):
-            predicted_class=numpy.argmax(out[0,35:])+35
+        if(y>35):
+            predicted_class=numpy.argmax(out[0,36:])+36
             if(predicted_class!=y):
                 total_error_count+=1
-                
+	#without a priori
+	predicted_class=numpy.argmax(out)
+	if(predicted_class!=y):
+	  total_error_count_wap+=1
+
     print '\t total exemples count: '+str(total_exemple_count)
     print '\t total error count: '+str(total_error_count)
     print '\t percentage of error: '+str(total_error_count*100.0/total_exemple_count*1.0)+' %'
+    print '\t total error count without a priori: '+str(total_error_count_wap)
+    print '\t percentage of error without a priori: '+str(total_error_count_wap*100.0/total_exemple_count*1.0)+' %'
     
 
 def sigmoid(value):