Mercurial > ift6266
diff deep/amt/amt.py @ 399:99905d9bc9dd
Initial commit for calculating the test error of the AMT classifier
author | humel |
---|---|
date | Wed, 28 Apr 2010 00:38:31 -0400 |
parents | |
children | 86d5e583e278 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/deep/amt/amt.py Wed Apr 28 00:38:31 2010 -0400 @@ -0,0 +1,66 @@ +# Script usage : python amt.py filname.cvs + +import csv,numpy,re + +PATH = 'nist/' +CVSFILE = None + +answer_labels = [ 'Answer.C'+str(i+1) for i in range(10) ] +img_url = 'Input.image_url' +turks_per_batch = 3 + + +def test_error(): + turks = [] + reader = csv.DictReader(open(CVSFILE), delimiter=',') + entries = [ turk for turk in reader ] + + errors = numpy.zeros((len(entries),)) + if len(entries) % turks_per_batch != 0 : + raise Exception('Wrong number of entries or turks_per_batch') + + total_uniq_entries = len(entries) / turks_per_batch + + errors = numpy.zeros((len(entries),)) + error_means = numpy.zeros((total_uniq_entries,)) + error_variances = numpy.zeros((total_uniq_entries,)) + + + for i in range(total_uniq_entries): + for t in range(turks_per_batch): + errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t]) + error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() + error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() + + print 'Average test error: ' + str(error_means.mean()*10) +'%' + print 'Error variance : ' + str(error_variances.mean()*10) +'%' + + +def get_error(entry): + file = parse_filename(entry[img_url]) + f = open(PATH+file,'r') + labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') + f.close() + test_error = 0 + for i in range(len(answer_labels)): + try: + answer = int(entry[answer_labels[i]]) + except: + try : + answer = ord(entry[answer_labels[i]]) + except: + test_error+=1 + continue + if answer != int(labels[i]): + test_error+=1 + return test_error + + +def parse_filename(string): + filename = string.split('/')[-1] + return filename.split('.')[0]+'.txt' + +if __name__ =='__main__': + import sys + CVSFILE = sys.argv[1] + test_error()