Mercurial > ift6266
view deep/amt/amt.py @ 401:86d5e583e278
Fixed class number bug
author | humel |
---|---|
date | Wed, 28 Apr 2010 01:25:52 -0400 |
parents | 99905d9bc9dd |
children | 83413ac10913 |
line wrap: on
line source
# Script usage : python amt.py filename.csv
"""Measure worker error rates from a (presumably Mechanical Turk) results CSV.

The CSV is read with ``csv.DictReader``.  Each row must provide the column
named by ``img_url`` (``Input.image_url``) and the answer columns listed in
``answer_labels`` (``Answer.C1`` .. ``Answer.C10``).  Consecutive runs of
``turks_per_batch`` rows are treated as one group; presumably every row of a
group is a different worker answering the same image -- TODO confirm against
the CSV export.  The ground-truth labels for a row are read from the file
``PATH + parse_filename(row[img_url])``.
"""

import csv
import re
import string

import numpy

PATH = 'nist/'
# Path of the results CSV; set from sys.argv[1] when run as a script.
CVSFILE = None
answer_labels = ['Answer.C' + str(i + 1) for i in range(10)]
img_url = 'Input.image_url'
turks_per_batch = 3

# Compiled once: every whitespace run is removed from the label line.
_WHITESPACE_RE = re.compile(r"\s+")


def setup_association():
    """Return a dict mapping each class character to its integer class.

    '0'-'9' map to 0-9, 'A'-'Z' map to 10-35 and 'a'-'z' map to 36-61.
    """
    classes = string.digits + string.ascii_uppercase + string.ascii_lowercase
    return dict((char, idx) for idx, char in enumerate(classes))


answer_assoc = setup_association()


def test_error(csvfile=None, path=None):
    """Print the mean error rate and the mean within-group error variance.

    csvfile -- results CSV to read; defaults to the module-level CVSFILE.
    path    -- directory prefix of the label files, passed to get_error;
               defaults to the module-level PATH.

    Each row is scored with get_error.  The rows are split into consecutive
    groups of ``turks_per_batch`` rows.  The mean and variance of the error
    counts are taken within each group, and then averaged over the groups.

    Returns (mean_error, mean_variance), the two values that are printed.
    Raises Exception if the CSV has no data rows, or if its row count is not
    a multiple of ``turks_per_batch``.
    """
    if csvfile is None:
        csvfile = CVSFILE
    with open(csvfile) as handle:
        entries = list(csv.DictReader(handle, delimiter=','))

    if not entries:
        # Otherwise numpy would silently report nan for an empty file.
        raise Exception('No entries in ' + str(csvfile))
    if len(entries) % turks_per_batch != 0:
        raise Exception('Wrong number of entries or turks_per_batch')

    # Row g of the matrix holds the error counts of group g, in file order.
    errors = numpy.array([get_error(entry, path) for entry in entries],
                         dtype=float).reshape((-1, turks_per_batch))
    error_means = errors.mean(axis=1)
    error_variances = errors.var(axis=1)

    # An error count ranges over 0..len(answer_labels); this converts a
    # count into a percentage (10.0 when there are 10 answer columns).
    scale = 100.0 / len(answer_labels)
    mean_error = error_means.mean() * scale
    # NOTE(review): a variance of counts scaled like this is not really a
    # percentage (a variance of percentages would scale by scale**2).  The
    # factor is kept so the reported numbers stay unchanged.
    mean_variance = error_variances.mean() * scale

    print('Average test error: ' + str(mean_error) + '%')
    print('Error variance : ' + str(mean_variance) + '%')
    return mean_error, mean_variance


# test_error is a script entry point, not a unit test.  This stops
# pytest/nose-style collectors from running it when the module is
# star-imported into a test file.
test_error.__test__ = False


def get_error(entry, path=None):
    """Return how many answers in one CSV row disagree with the labels.

    entry -- one row dict as produced by csv.DictReader.
    path  -- directory prefix of the label files (concatenated as a plain
             string); defaults to the module-level PATH.

    Only the first line of the label file is read.  All whitespace is
    removed from it, the first character and the last two are dropped, and
    the rest is split on '.'.  This presumably parses a printed float array
    such as "[ 1.  2. ... 10.]"; confirm against the label files.

    An answer counts as wrong when any of these holds:
    - it is empty or missing;
    - it is not a known class character;
    - there is no label at its position;
    - the label at its position is not an integer;
    - it differs from its label.
    """
    if path is None:
        path = PATH
    label_file = parse_filename(entry[img_url])
    with open(path + label_file, 'r') as handle:
        first_line = handle.readline()
    labels = _WHITESPACE_RE.sub('', first_line)[1:-2].split('.')

    wrong = 0
    for position, column in enumerate(answer_labels):
        answer = entry[column]
        if not answer:
            # '' for a blank answer, None when DictReader pads a short row.
            wrong += 1
            continue
        try:
            if answer_assoc[answer] != int(labels[position]):
                wrong += 1
        except (KeyError, IndexError, ValueError):
            # Unknown character, missing label, or non-integer label.
            wrong += 1
    return wrong


def parse_filename(string):
    """Map an image URL to the name of its label file.

    Keeps the last '/'-separated component, cuts it at its first '.', and
    appends '.txt', e.g. 'http://host/dir/img.png' -> 'img.txt'.
    """
    filename = string.split('/')[-1]
    return filename.split('.')[0] + '.txt'


if __name__ == '__main__':
    import sys
    if len(sys.argv) < 2:
        sys.exit('Usage: python amt.py filename.csv')
    CVSFILE = sys.argv[1]
    test_error()