comparison deep/amt/amt.py @ 399:99905d9bc9dd

Initial commit for calculating the test error of the AMT classifier
author humel
date Wed, 28 Apr 2010 00:38:31 -0400
parents
children 86d5e583e278
comparison
equal deleted inserted replaced
398:eb42bed0c13b 399:99905d9bc9dd
# Script usage: python amt.py filename.csv
2
3 import csv,numpy,re
4
# Directory holding the ground-truth label files; get_error() opens
# PATH + '<image name>.txt' for every CSV row.
PATH = 'nist/'
# Path of the results CSV file; assigned from the command line in __main__.
CVSFILE = None

# CSV column names holding the per-character answers of one row:
# 'Answer.C1' ... 'Answer.C10'.
answer_labels = ['Answer.C%d' % (n,) for n in range(1, 11)]
# CSV column with the image URL; its file name selects the label file.
img_url = 'Input.image_url'
# Number of consecutive CSV rows grouped into one batch by test_error().
turks_per_batch = 3
11
12
def test_error(cvsfile=None):
    """Print the average error of the workers and the average batch variance.

    Reads the results CSV, scores every row with get_error(), and splits
    the rows into consecutive batches of ``turks_per_batch`` rows
    (presumably the workers who answered the same image -- TODO confirm).
    For each batch the mean and variance of the error counts are taken.
    The averages over all batches are then printed as percentages.

    Args:
        cvsfile: path of the results CSV. Defaults to the module-level
            CVSFILE, which is set from the command line.

    Returns:
        (average_error, error_variance): the two printed numbers.

    Raises:
        Exception: if the number of rows is not a multiple of
            turks_per_batch.
    """
    if cvsfile is None:
        cvsfile = CVSFILE
    # Read all rows up front so the file is closed as soon as parsing ends.
    with open(cvsfile) as csv_file:
        entries = list(csv.DictReader(csv_file, delimiter=','))

    if len(entries) % turks_per_batch != 0:
        raise Exception('Wrong number of entries or turks_per_batch')

    # Floor division keeps this an int on Python 3 too; reshape() below
    # (like the original numpy.zeros/range calls) rejects a float.
    total_uniq_entries = len(entries) // turks_per_batch

    # One error count per row, scored in file order.
    errors = numpy.array([get_error(entry) for entry in entries], dtype=float)
    # Row i of `batches` holds the turks_per_batch consecutive rows of batch i.
    batches = errors.reshape((total_uniq_entries, turks_per_batch))
    error_means = batches.mean(axis=1)
    error_variances = batches.var(axis=1)

    # get_error() counts wrong answers out of len(answer_labels) characters,
    # so this factor turns a count into a percentage (10.0 for 10 columns,
    # the value that used to be hard-coded).
    to_percent = 100.0 / len(answer_labels)
    average_error = error_means.mean() * to_percent
    # NOTE(review): the variance is scaled by the same linear factor as the
    # mean. A variance in squared percentage points would need
    # to_percent ** 2. The factor is kept so the printed number is unchanged.
    error_variance = error_variances.mean() * to_percent

    print('Average test error: ' + str(average_error) + '%')
    print('Error variance : ' + str(error_variance) + '%')
    return average_error, error_variance
37
38
def get_error(entry):
    """Count the answers in one CSV row that disagree with the true labels.

    The label file is PATH + parse_filename(entry[img_url]). Its first line
    is parsed into a list of label strings. Answer column
    answer_labels[i] is compared with label i. An answer is read as an
    integer if possible, otherwise as the code point of a single character.
    An answer that is neither (empty, missing or multi-character) counts as
    an error.

    Args:
        entry: one row from csv.DictReader (a dict keyed by column name).

    Returns:
        Number of wrong or unreadable answers, from 0 to len(answer_labels).

    Raises:
        KeyError: if the row has no img_url column.
        IOError/OSError: if the label file cannot be opened.
        IndexError/ValueError: if the label line has too few entries or a
            non-integer entry.
    """
    label_path = PATH + parse_filename(entry[img_url])
    with open(label_path, 'r') as label_file:
        first_line = label_file.readline()
    # Remove all whitespace, drop the first character and the last two,
    # then split on '.'. Presumably the line is a printed float array
    # such as "[ 1.  2.  0.]" -> ['1', '2', '0'] -- TODO confirm.
    labels = re.sub(r"\s+", "", first_line).strip()[1:-2].split('.')

    mismatches = 0
    for i, key in enumerate(answer_labels):
        # A missing column yields None, which both int() and ord() reject
        # with TypeError, so it is counted as wrong below.
        raw = entry.get(key)
        try:
            answer = int(raw)
        except (TypeError, ValueError):
            try:
                # Non-numeric answer: use the character's code point.
                answer = ord(raw)
            except TypeError:
                mismatches += 1
                continue
        if answer != int(labels[i]):
            mismatches += 1
    return mismatches
57
58
def parse_filename(string):
    """Return the label-file name that corresponds to an image URL or path.

    Keeps the text after the last '/', cuts it at the first '.', and
    appends '.txt', e.g. 'http://host/dir/img_12.png' -> 'img_12.txt'.
    """
    basename = string.rpartition('/')[2]
    stem = basename.partition('.')[0]
    return stem + '.txt'
62
if __name__ == '__main__':
    import sys
    # The results CSV path is the first command-line argument. Fail with a
    # usage message instead of an IndexError traceback when it is missing.
    if len(sys.argv) < 2:
        sys.exit('Usage: python amt.py filename.csv')
    CVSFILE = sys.argv[1]
    test_error()