comparison deep/amt/amt.py @ 402:83413ac10913

Added more stats printing. Now you don't need to specify which dataset you are testing; it will be detected automatically.
author humel
date Wed, 28 Apr 2010 11:28:28 -0400
parents 86d5e583e278
children a11692910312
comparison
equal deleted inserted replaced
401:86d5e583e278 402:83413ac10913
1 # Script usage : python amt.py filname.cvs 1 # Script usage : python amt.py filname.cvs
2 2
3 import csv,numpy,re 3 import csv,numpy,re,decimal
4 4
5 PATH = 'nist/' 5 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/',
6 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/',
7 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' }
8
6 CVSFILE = None 9 CVSFILE = None
7 10 #PATH = None
8 answer_labels = [ 'Answer.C'+str(i+1) for i in range(10) ] 11 answer_labels = [ 'answer.c'+str(i+1) for i in range(10) ]
9 img_url = 'Input.image_url' 12 img_url = 'input.image_url'
10 turks_per_batch = 3 13 turks_per_batch = 3
11 14 image_per_batch = 10
12 15
13 def setup_association(): 16 def setup_association():
14 answer_assoc = {} 17 answer_assoc = {}
15 for i in range(0,10): 18 for i in range(0,10):
16 answer_assoc[str(i)]=i 19 answer_assoc[str(i)]=i
35 38
36 errors = numpy.zeros((len(entries),)) 39 errors = numpy.zeros((len(entries),))
37 error_means = numpy.zeros((total_uniq_entries,)) 40 error_means = numpy.zeros((total_uniq_entries,))
38 error_variances = numpy.zeros((total_uniq_entries,)) 41 error_variances = numpy.zeros((total_uniq_entries,))
39 42
40 43 PATH = DATASET_PATH[find_dataset(entries[0])]
41 for i in range(total_uniq_entries): 44 for i in range(total_uniq_entries):
42 for t in range(turks_per_batch): 45 for t in range(turks_per_batch):
43 errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t]) 46 errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH)
44 error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() 47 error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean()
45 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() 48 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var()
46 49
47 print 'Average test error: ' + str(error_means.mean()*10) +'%' 50 decimal.getcontext().prec = 3
48 print 'Error variance : ' + str(error_variances.mean()*10) +'%' 51 print 'Total entries : ' + str(len(entries))
52 print 'Turks per batch : ' + str(turks_per_batch)
53 print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%'
54 print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%'
49 55
50 56
51 def get_error(entry): 57 def find_dataset(entry):
58 file = parse_filename(entry[img_url])
59 return file.split('_')[0]
60
61 def get_error(entry,PATH):
52 file = parse_filename(entry[img_url]) 62 file = parse_filename(entry[img_url])
53 f = open(PATH+file,'r') 63 f = open(PATH+file,'r')
54 labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') 64 labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')
55 f.close() 65 f.close()
56 test_error = 0 66 test_error = 0