view baseline/mlp/mlp_get_error_from_model.py @ 222:4cfd0eb438af

Add mnist to datasets (and supporting code).
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 11 Mar 2010 14:41:31 -0500
parents e390b0454515
children 9b6e0af062af
line wrap: on
line source

__docformat__ = 'restructedtext en'

import pdb
import numpy as np
import pylab
import time 
import pylearn
from pylearn.io import filetensor as ft

# Root directory under which the NIST "by_class" data files are looked up.
data_path = '/data/lisa/data/nist/by_class/'
# Test-set inputs and labels, given relative to data_path.  Both are opened
# and parsed with the filetensor reader (ft.read) in read_test_data.
test_data = 'all/all_test_data.ft'
test_labels = 'all/all_test_labels.ft'

def read_test_data(mlp_model):
    """Load the test set and the parameters of a saved MLP.

    The test inputs and labels are read from the filetensor files named by
    the module-level ``data_path``, ``test_data`` and ``test_labels``.

    :param mlp_model: file name (or file object) handed to ``np.load``; the
        loaded object must be indexable by the keys ``'W1'``, ``'W2'``,
        ``'b1'`` and ``'b2'`` (presumably an ``.npz`` archive -- confirm
        against the code that saves the model).
    :returns: the tuple ``(W1, b1, W2, b2, raw_test_data, raw_test_labels)``,
        in the order expected by ``get_total_test_error``.
    """
    # The .ft files hold binary tensors, so open them in binary mode.  The
    # with-blocks guarantee the handles are closed even if parsing fails.
    with open(data_path + test_data, 'rb') as data_file:
        raw_test_data = ft.read(data_file)
    with open(data_path + test_labels, 'rb') as labels_file:
        raw_test_labels = ft.read(labels_file)

    # Read the parameters of the chosen model.
    params = np.load(mlp_model)
    W1 = params['W1']
    W2 = params['W2']
    b1 = params['b1']
    b2 = params['b2']

    return (W1, b1, W2, b2, raw_test_data, raw_test_labels)
    
    
    

def _error_percentage(error_count, example_count):
    """Return the error rate in percent, or NaN when the group is empty."""
    if example_count == 0:
        return float('nan')
    return error_count * 100.0 / example_count


def get_total_test_error(everything):
    """Compute overall and per-group classification errors of an MLP.

    The network is a one-hidden-layer perceptron: a tanh hidden layer
    followed by a softmax output layer.  The predicted class is the index of
    the largest output unit.

    Examples are grouped by their true label:

    * ``nb``   -- labels 0 to 9
    * ``min``  -- labels 10 to 34
    * ``maj``  -- labels 35 and above
    * ``char`` -- labels 10 and above (``min`` and ``maj`` together)

    NOTE(review): the 10-34 / 35+ split gives the ``min`` group 25 classes.
    If each letter group is meant to hold 26 classes the boundary should
    probably be 36 -- confirm against the label encoding of the data set.

    :param everything: the sequence
        ``(W1, b1, W2, b2, test_data, test_labels)`` as returned by
        ``read_test_data``.  ``test_data`` holds one example per row and is
        divided by 255.0 before use; ``W1`` and ``W2`` are laid out as
        (inputs, outputs) for their layer.
    :returns: a 15-tuple made of, in this order,

        * the example counts for total, nb, char, min, maj;
        * the error counts for total, nb, char, min, maj;
        * the error percentages for total, nb, char, min, maj.

        The percentage of a group that contains no example is NaN.
    :raises ValueError: if there are fewer examples than labels.
    """
    W1, b1, W2, b2 = everything[:4]
    labels = np.asarray(everything[5]).ravel()
    # Only the first labels.size rows are ever looked at, exactly like the
    # former per-example loop; rescale just those.
    inputs = np.asarray(everything[4])[:labels.size] / 255.0

    if labels.size:
        # Forward pass for all examples at once (one example per row).
        hidden = np.tanh(np.dot(inputs, W1) + b1)
        activations = np.dot(hidden, W2) + b2
        # Softmax is strictly increasing, so the winning class is the argmax
        # of the raw activations.  Skipping the exponential also avoids the
        # inf/inf = NaN results it produces for large activations.
        predictions = np.argmax(activations, axis=1)
    else:
        predictions = np.zeros(0, dtype=int)

    if predictions.shape != labels.shape:
        raise ValueError('got %d examples for %d labels'
                         % (predictions.size, labels.size))

    wrong = predictions != labels

    # Boolean masks selecting the examples of each label group.
    is_nb = labels < 10
    is_char = labels > 9
    is_min = is_char & (labels < 35)
    is_maj = labels > 34

    total_exemple_count = int(labels.size)
    nb_exemple_count = int(np.sum(is_nb))
    char_exemple_count = int(np.sum(is_char))
    min_exemple_count = int(np.sum(is_min))
    maj_exemple_count = int(np.sum(is_maj))

    total_error_count = int(np.sum(wrong))
    nb_error_count = int(np.sum(wrong & is_nb))
    char_error_count = int(np.sum(wrong & is_char))
    min_error_count = int(np.sum(wrong & is_min))
    maj_error_count = int(np.sum(wrong & is_maj))

    return (total_exemple_count, nb_exemple_count, char_exemple_count,
            min_exemple_count, maj_exemple_count,
            total_error_count, nb_error_count, char_error_count,
            min_error_count, maj_error_count,
            _error_percentage(total_error_count, total_exemple_count),
            _error_percentage(nb_error_count, nb_exemple_count),
            _error_percentage(char_error_count, char_exemple_count),
            _error_percentage(min_error_count, min_exemple_count),
            _error_percentage(maj_error_count, maj_exemple_count))