view deep/amt/amt.py @ 402:83413ac10913

Added more stats printing. Now you don't need to specify which dataset you are testing; it will be detected automatically.
author humel
date Wed, 28 Apr 2010 11:28:28 -0400
parents 86d5e583e278
children a11692910312

# Script usage : python amt.py filename.csv

import csv, numpy, re

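# Maps each dataset name to the directory holding its label (.txt) files; the
# name is taken from the image file name prefix (see find_dataset below).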
DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

CSVFILE = None
#PATH = None
answer_labels = [ 'answer.c'+str(i+1) for i in range(10) ]
img_url = 'input.image_url'
turks_per_batch = 3
image_per_batch = 10

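# Build the mapping from an answer character to its class index:
# '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61.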
def setup_association():
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)]=i
    for i in range(10,36):
        answer_assoc[chr(i+55)]=i
    for i in range(36,62):
        answer_assoc[chr(i+61)]=i
    return answer_assoc

answer_assoc = setup_association()

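# Read the Mechanical Turk result CSV, group its rows into batches and print
# the average per-batch test error and its variance.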
def test_error():
    reader = csv.DictReader(open(CSVFILE), delimiter=',')
    entries = [ turk for turk in reader ]

    if len(entries) % turks_per_batch != 0 :
        raise Exception('Wrong number of entries or turks_per_batch')

    total_uniq_entries = len(entries) / turks_per_batch   # number of distinct image batches

    errors = numpy.zeros((len(entries),))
    error_means = numpy.zeros((total_uniq_entries,))
    error_variances = numpy.zeros((total_uniq_entries,))

    PATH = DATASET_PATH[find_dataset(entries[0])]
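    # Each group of turks_per_batch consecutive rows corresponds to the same
    # batch of image_per_batch images; compute the error mean and variance of
    # every batch across its workers.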
    for i in range(total_uniq_entries):
        for t in range(turks_per_batch):
            errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH)
        error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    # Per-batch error counts are out of image_per_batch images; with 10 images
    # per batch, multiplying by image_per_batch expresses them as percentages.
    print 'Total entries      : ' + str(len(entries))
    print 'Turks per batch    : ' + str(turks_per_batch)
    print 'Average test error : %.3g%%' % (error_means.mean()*image_per_batch)
    print 'Error variance     : %.3g%%' % (error_variances.mean()*image_per_batch)


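# The dataset name is the prefix of the image file name (the part before the
# first underscore), e.g. 'nist', 'p07' or 'pnist'.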
def find_dataset(entry):
    file = parse_filename(entry[img_url])
    return file.split('_')[0]
    
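# Count how many of the answers in one CSV row (one worker, one batch of
# image_per_batch images) differ from the reference labels.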
def get_error(entry,PATH):
    file = parse_filename(entry[img_url])
    f = open(PATH+file,'r')
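    # The first line of the label file holds the reference labels; remove
    # whitespace, drop the leading '[' and the last two characters, then split
    # on '.'.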
    labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
    f.close()
    test_error = 0
    for i in range(len(answer_labels)):
        answer = entry[answer_labels[i]]
        if len(answer) != 0:
            try:
                if answer_assoc[answer] != int(labels[i]):
                    test_error+=1
            except:
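                # An unrecognized answer character or a malformed label counts
                # as an error.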
                test_error+=1
        else:
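            # A blank answer also counts as an error.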
            test_error+=1
    return test_error


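# Map an image URL to the name of its label file: keep the basename and swap
# the extension for '.txt'.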
def parse_filename(string):
    filename = string.split('/')[-1]
    return filename.split('.')[0]+'.txt'

if __name__ =='__main__':
    import sys
    CSVFILE = sys.argv[1]
    test_error()