deep/amt/amt.py @ 402:83413ac10913
Added more stats printing. You no longer need to specify which dataset you are testing; it is detected automatically.
author | humel
date | Wed, 28 Apr 2010 11:28:28 -0400
parents | 86d5e583e278
children | a11692910312
# Script usage : python amt.py filename.csv
import csv,numpy,re,decimal

# Root directories of the label files for each dataset; the dataset is
# detected automatically from the image filenames found in the CSV.
DATASET_PATH = { 'nist'  : '/data/lisa/data/ift6266h10/amt_data/nist/',
                 'p07'   : '/data/lisa/data/ift6266h10/amt_data/p07/',
                 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }

CVSFILE = None
#PATH = None

# Column names of the 10 answer fields and of the image URL in the AMT results CSV.
answer_labels = [ 'answer.c'+str(i+1) for i in range(10) ]
img_url = 'input.image_url'
turks_per_batch = 3   # number of turkers answering each batch of images
image_per_batch = 10  # number of characters per batch

def setup_association():
    # Map the characters a turker can type to the 62 class indices:
    # '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61.
    answer_assoc = {}
    for i in range(0,10):
        answer_assoc[str(i)] = i
    for i in range(10,36):
        answer_assoc[chr(i+55)] = i
    for i in range(36,62):
        answer_assoc[chr(i+61)] = i
    return answer_assoc

answer_assoc = setup_association()

def test_error():
    reader = csv.DictReader(open(CVSFILE), delimiter=',')
    entries = [ turk for turk in reader ]

    if len(entries) % turks_per_batch != 0 :
        raise Exception('Wrong number of entries or turks_per_batch')

    total_uniq_entries = len(entries) / turks_per_batch

    errors = numpy.zeros((len(entries),))
    error_means = numpy.zeros((total_uniq_entries,))
    error_variances = numpy.zeros((total_uniq_entries,))

    # Detect which dataset this CSV refers to from the first entry's image URL.
    PATH = DATASET_PATH[find_dataset(entries[0])]

    # Per-batch mean and variance of the error count across the turkers
    # who answered that batch.
    for i in range(total_uniq_entries):
        for t in range(turks_per_batch):
            errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t], PATH)
        error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
        error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()

    decimal.getcontext().prec = 3
    print 'Total entries : ' + str(len(entries))
    print 'Turks per batch : ' + str(turks_per_batch)
    print 'Average test error : ' + str(error_means.mean()*image_per_batch) + '%'
    print 'Error variance : ' + str(error_variances.mean()*image_per_batch) + '%'

def find_dataset(entry):
    # The dataset name is the prefix of the image filename, before the first '_'.
    fname = parse_filename(entry[img_url])
    return fname.split('_')[0]

def get_error(entry, PATH):
    # Read the ground-truth labels for this batch from the matching .txt file.
    fname = parse_filename(entry[img_url])
    f = open(PATH+fname, 'r')
    labels = re.sub(r"\s+", "", f.readline()).strip()[1:-2].split('.')
    f.close()
    # Count how many of the 10 answers differ from the ground truth;
    # missing or unrecognised answers count as errors.
    test_error = 0
    for i in range(len(answer_labels)):
        answer = entry[answer_labels[i]]
        if len(answer) != 0:
            try:
                if answer_assoc[answer] != int(labels[i]):
                    test_error += 1
            except:
                test_error += 1
        else:
            test_error += 1
    return test_error

def parse_filename(string):
    # Strip the URL down to its base name and replace the extension by .txt
    filename = string.split('/')[-1]
    return filename.split('.')[0]+'.txt'

if __name__ == '__main__':
    import sys
    CVSFILE = sys.argv[1]
    test_error()
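For reference, a minimal sketch of the dataset auto-detection mentioned in the commit message, assuming the CSV's input.image_url values end in a filename such as 'pnist_<something>.jpg' (the URL below is made up): the base filename is kept with a .txt extension, and the prefix before the first underscore is used as the key into DATASET_PATH.

    # Hypothetical image URL; only the trailing filename pattern matters.
    url = 'http://hit-server.example/images/pnist_batch_042.jpg'

    label_file = url.split('/')[-1].split('.')[0] + '.txt'   # 'pnist_batch_042.txt'
    dataset    = label_file.split('_')[0]                    # 'pnist'
    # DATASET_PATH[dataset] then points at the directory holding the label file.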