Mercurial > ift6266
view deep/amt/amt.py @ 399:99905d9bc9dd
Initial commit for calculating the test error of the AMT classifier
author | humel |
---|---|
date | Wed, 28 Apr 2010 00:38:31 -0400 |
parents | |
children | 86d5e583e278 |
line wrap: on
line source
# Script usage: python amt.py filename.csv
#
# Computes the test error of the AMT classifier (presumably Amazon Mechanical
# Turk workers -- confirm) from a CSV export.  Every CSV row holds one
# worker's answers for one image; the reference labels for an image are read
# from a text file in PATH named after the image.

import csv
import os
import re

import numpy

# Directory containing the reference label files (one ``<image>.txt`` each).
PATH = 'nist/'
# CSV file to evaluate; set from the command line when run as a script.
CVSFILE = None

# CSV column names holding the answers given for one image.
answer_labels = ['Answer.C' + str(i + 1) for i in range(10)]
# CSV column name holding the URL of the image that was shown.
img_url = 'Input.image_url'
# Number of consecutive CSV rows that are grouped together.
turks_per_batch = 3


def test_error(csv_file=None):
    """Print the average test error and error variance for a CSV file.

    Rows are taken in groups of ``turks_per_batch`` consecutive entries
    (assumed to be the answers of different workers for the same image --
    verify against the CSV export).  The mean and variance of the error
    count are computed inside each group and then averaged over all groups.

    Parameters:
        csv_file: path of the CSV file to read.  Defaults to the module
            level ``CVSFILE``.

    Raises:
        ValueError: if no CSV file was given and ``CVSFILE`` is unset.
        Exception: if the file has no entries or the number of entries is
            not a multiple of ``turks_per_batch``.
    """
    if csv_file is None:
        csv_file = CVSFILE
    if csv_file is None:
        raise ValueError('No CSV file given: pass csv_file or set CVSFILE')

    # The context manager guarantees the handle is closed even on error.
    with open(csv_file) as handle:
        entries = list(csv.DictReader(handle, delimiter=','))

    # An empty file would otherwise yield NaN statistics further down.
    if not entries or len(entries) % turks_per_batch != 0:
        raise Exception('Wrong number of entries or turks_per_batch')

    # Floor division keeps this an int on both Python 2 and Python 3.
    total_uniq_entries = len(entries) // turks_per_batch

    errors = numpy.array([get_error(entry) for entry in entries], dtype=float)
    # One row per group of consecutive entries, one column per worker.
    per_group = errors.reshape(total_uniq_entries, turks_per_batch)
    error_means = per_group.mean(axis=1)
    error_variances = per_group.var(axis=1)

    # Converts an error count out of len(answer_labels) into a percentage
    # (a factor of 10 for the default ten answers).
    scale = 100.0 / len(answer_labels)
    print('Average test error: ' + str(error_means.mean() * scale) + '%')
    # NOTE(review): the variance is scaled by the same factor as the mean,
    # as before; a variance in squared percent would need scale ** 2.
    print('Error variance : ' + str(error_variances.mean() * scale) + '%')


# The name starts with ``test_`` but this is not a unit test; keep test
# runners such as pytest from collecting and calling it.
test_error.__test__ = False


def _parse_labels(line):
    """Split the first line of a label file into a list of label strings.

    All whitespace is removed, then the first character and the last two
    characters are dropped and the rest is split on ``.``; e.g.
    ``"[ 1.  2.  3.]"`` becomes ``['1', '2', '3']``.
    """
    compact = re.sub(r"\s+", "", line).strip()
    return compact[1:-2].split('.')


def _answer_value(raw):
    """Return the integer value of one CSV answer, or None if unusable.

    The answer is parsed as an integer first; if that fails, the code point
    of a single-character answer is used instead.
    """
    try:
        return int(raw)
    except (TypeError, ValueError):
        pass
    try:
        return ord(raw)
    except TypeError:
        # Missing, empty or multi-character answer.
        return None


def get_error(entry, path=None):
    """Count the wrong answers in one CSV entry.

    Parameters:
        entry: mapping of CSV column names to values for one row.  It must
            contain the ``img_url`` column; missing answer columns are
            counted as errors.
        path: directory holding the label files.  Defaults to the module
            level ``PATH``.

    Returns:
        The number of answers (between 0 and ``len(answer_labels)``) that
        are missing, unparsable or different from the reference label.
    """
    if path is None:
        path = PATH
    label_file = os.path.join(path, parse_filename(entry[img_url]))
    with open(label_file, 'r') as handle:
        labels = _parse_labels(handle.readline())

    mistakes = 0
    for position, column in enumerate(answer_labels):
        answer = _answer_value(entry.get(column))
        if answer is None or answer != int(labels[position]):
            mistakes += 1
    return mistakes


def parse_filename(string):
    """Map an image URL or path to the name of its label file.

    Keeps the part after the last ``/``, cuts it at the first ``.`` and
    appends ``.txt``.
    """
    basename = string.split('/')[-1]
    return basename.split('.')[0] + '.txt'


if __name__ == '__main__':
    import sys

    if len(sys.argv) < 2:
        sys.exit('Usage: python amt.py filename.csv')
    CVSFILE = sys.argv[1]
    test_error()