Mercurial > ift6266
comparison deep/amt/amt.py @ 399:99905d9bc9dd
Initial commit for calculating the test error of the AMT classifier
author | humel |
---|---|
date | Wed, 28 Apr 2010 00:38:31 -0400 |
parents | |
children | 86d5e583e278 |
comparison
equal
deleted
inserted
replaced
398:eb42bed0c13b | 399:99905d9bc9dd |
---|---|
# Script usage: python amt.py filename.csv
2 | |
3 import csv,numpy,re | |
4 | |
# Directory prefix under which the per-image label files are looked up
# (joined with the name produced by parse_filename in get_error).
PATH = 'nist/'
# Path of the AMT results CSV; assigned from the command line in __main__.
CVSFILE = None

# CSV columns holding the ten answers one worker gave for one image.
answer_labels = ['Answer.C%d' % n for n in range(1, 11)]
# CSV column holding the URL of the image shown to the worker.
img_url = 'Input.image_url'
# Number of consecutive CSV rows (workers) that form one batch in test_error.
turks_per_batch = 3
11 | |
12 | |
def test_error(cvs_file=None):
    """Compute and print the average AMT worker test error.

    Every CSV row is scored with get_error(). Consecutive rows are then
    grouped into batches of ``turks_per_batch`` rows (presumably one batch
    per image). The function prints the mean over batches of the per-batch
    mean error and of the per-batch error variance.

    :param cvs_file: path of the AMT results CSV; defaults to the
        module-level ``CVSFILE`` (set from the command line in __main__).
    :raises ValueError: if the CSV has no data rows, or if its row count is
        not a multiple of ``turks_per_batch``.
    """
    if cvs_file is None:
        cvs_file = CVSFILE
    # Read everything up front so the file handle is closed before scoring.
    with open(cvs_file) as csv_handle:
        entries = list(csv.DictReader(csv_handle, delimiter=','))

    if not entries:
        raise ValueError('No entries found in ' + str(cvs_file))
    if len(entries) % turks_per_batch != 0:
        raise ValueError('Wrong number of entries or turks_per_batch')

    # Floor division keeps this an int on both Python 2 and Python 3.
    total_uniq_entries = len(entries) // turks_per_batch

    errors = numpy.array([get_error(entry) for entry in entries], dtype=float)
    # Row r holds the turks_per_batch consecutive entries of batch r, which is
    # the same grouping as slicing errors[r*k:(r+1)*k].
    batches = errors.reshape((total_uniq_entries, turks_per_batch))
    error_means = batches.mean(axis=1)
    error_variances = batches.var(axis=1)

    # get_error() counts wrong answers out of len(answer_labels) == 10, so
    # multiplying the mean count by 10 turns it into a percentage.
    # NOTE(review): the variance is computed on raw counts, so scaling it by
    # 10 and suffixing '%' is not really a percentage. Confirm intent.
    print('Average test error: ' + str(error_means.mean() * 10) + '%')
    print('Error variance : ' + str(error_variances.mean() * 10) + '%')
37 | |
38 | |
def get_error(entry):
    """Count how many of one worker's answers disagree with the true labels.

    The true labels are read from the first line of the label file whose
    name parse_filename() derives from the entry's image URL, under PATH.

    :param entry: one CSV row as a dict (as produced by csv.DictReader). The
        image URL is read from column ``img_url`` and the answers from the
        ``answer_labels`` columns.
    :returns: number of answers (0..len(answer_labels)) that are wrong or
        could not be interpreted.
    :raises KeyError: if a required column is missing from ``entry``. The
        old bare ``except`` silently counted such answers as wrong.
    """
    label_path = PATH + parse_filename(entry[img_url])
    with open(label_path, 'r') as label_file:
        first_line = label_file.readline()
    # Remove all whitespace, then drop the leading character and the two
    # trailing ones, and split on '.'.
    # NOTE(review): this appears to expect a line like "[ 1.  2. ... 10.]",
    # which would become "1.2...10". Confirm against the label files.
    labels = re.sub(r"\s+", "", first_line)[1:-2].split('.')

    wrong = 0
    for i, column in enumerate(answer_labels):
        raw_answer = entry[column]
        try:
            answer = int(raw_answer)
        except (ValueError, TypeError):
            # Non-numeric answers are compared by their character code.
            # NOTE(review): assumes the label files store ord() codes for
            # non-digit classes. Confirm.
            try:
                answer = ord(raw_answer)
            except TypeError:
                # Empty, multi-character or missing (None) answer: wrong.
                wrong += 1
                continue
        if answer != int(labels[i]):
            wrong += 1
    return wrong
57 | |
58 | |
def parse_filename(string):
    """Map an image URL or path to the name of its label file.

    Keeps the last '/'-separated component of *string*, cuts it at its
    first '.', and appends '.txt'. For example,
    'http://host/nist/123.png' becomes '123.txt'.
    """
    basename = string.rsplit('/', 1)[-1]
    stem = basename.partition('.')[0]
    return '{0}.txt'.format(stem)
62 | |
if __name__ == '__main__':
    import sys
    if len(sys.argv) < 2:
        # Exit with a usage message instead of an IndexError traceback.
        sys.exit('Usage: python amt.py filename.csv')
    CVSFILE = sys.argv[1]
    test_error()