Mercurial > ift6266
comparison deep/amt/amt.py @ 402:83413ac10913
Added more stats printing. Now you don't need to pass as a parameter which dataset you are testing; it will detect it automatically
author | humel |
---|---|
date | Wed, 28 Apr 2010 11:28:28 -0400 |
parents | 86d5e583e278 |
children | a11692910312 |
comparison
equal
deleted
inserted
replaced
401:86d5e583e278 | 402:83413ac10913 |
---|---|
1 # Script usage : python amt.py filname.cvs | 1 # Script usage : python amt.py filname.cvs |
2 | 2 |
3 import csv,numpy,re | 3 import csv,numpy,re,decimal |
4 | 4 |
5 PATH = 'nist/' | 5 DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/', |
6 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/', | |
7 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' } | |
8 | |
6 CVSFILE = None | 9 CVSFILE = None |
7 | 10 #PATH = None |
8 answer_labels = [ 'Answer.C'+str(i+1) for i in range(10) ] | 11 answer_labels = [ 'answer.c'+str(i+1) for i in range(10) ] |
9 img_url = 'Input.image_url' | 12 img_url = 'input.image_url' |
10 turks_per_batch = 3 | 13 turks_per_batch = 3 |
11 | 14 image_per_batch = 10 |
12 | 15 |
13 def setup_association(): | 16 def setup_association(): |
14 answer_assoc = {} | 17 answer_assoc = {} |
15 for i in range(0,10): | 18 for i in range(0,10): |
16 answer_assoc[str(i)]=i | 19 answer_assoc[str(i)]=i |
35 | 38 |
36 errors = numpy.zeros((len(entries),)) | 39 errors = numpy.zeros((len(entries),)) |
37 error_means = numpy.zeros((total_uniq_entries,)) | 40 error_means = numpy.zeros((total_uniq_entries,)) |
38 error_variances = numpy.zeros((total_uniq_entries,)) | 41 error_variances = numpy.zeros((total_uniq_entries,)) |
39 | 42 |
40 | 43 PATH = DATASET_PATH[find_dataset(entries[0])] |
41 for i in range(total_uniq_entries): | 44 for i in range(total_uniq_entries): |
42 for t in range(turks_per_batch): | 45 for t in range(turks_per_batch): |
43 errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t]) | 46 errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH) |
44 error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() | 47 error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() |
45 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() | 48 error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() |
46 | 49 |
47 print 'Average test error: ' + str(error_means.mean()*10) +'%' | 50 decimal.getcontext().prec = 3 |
48 print 'Error variance : ' + str(error_variances.mean()*10) +'%' | 51 print 'Total entries : ' + str(len(entries)) |
52 print 'Turks per batch : ' + str(turks_per_batch) | |
53 print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%' | |
54 print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%' | |
49 | 55 |
50 | 56 |
51 def get_error(entry): | 57 def find_dataset(entry): |
58 file = parse_filename(entry[img_url]) | |
59 return file.split('_')[0] | |
60 | |
61 def get_error(entry,PATH): | |
52 file = parse_filename(entry[img_url]) | 62 file = parse_filename(entry[img_url]) |
53 f = open(PATH+file,'r') | 63 f = open(PATH+file,'r') |
54 labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') | 64 labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') |
55 f.close() | 65 f.close() |
56 test_error = 0 | 66 test_error = 0 |