# HG changeset patch # User humel # Date 1272468508 14400 # Node ID 83413ac10913a2e44bf37c93b2ba46525164ad11 # Parent 86d5e583e278aeceb44ab99b8a0d315b38475fcb Added more stats printing. Now you dont need to parameters which dataset you are testing, it will detect it automatically diff -r 86d5e583e278 -r 83413ac10913 deep/amt/amt.py --- a/deep/amt/amt.py Wed Apr 28 01:25:52 2010 -0400 +++ b/deep/amt/amt.py Wed Apr 28 11:28:28 2010 -0400 @@ -1,14 +1,17 @@ # Script usage : python amt.py filname.cvs -import csv,numpy,re +import csv,numpy,re,decimal -PATH = 'nist/' -CVSFILE = None +DATASET_PATH = { 'nist' : '/data/lisa/data/ift6266h10/amt_data/nist/', + 'p07' : '/data/lisa/data/ift6266h10/amt_data/p07/', + 'pnist' : '/data/lisa/data/ift6266h10/amt_data/pnist/' } -answer_labels = [ 'Answer.C'+str(i+1) for i in range(10) ] -img_url = 'Input.image_url' +CVSFILE = None +#PATH = None +answer_labels = [ 'answer.c'+str(i+1) for i in range(10) ] +img_url = 'input.image_url' turks_per_batch = 3 - +image_per_batch = 10 def setup_association(): answer_assoc = {} @@ -37,18 +40,25 @@ error_means = numpy.zeros((total_uniq_entries,)) error_variances = numpy.zeros((total_uniq_entries,)) - + PATH = DATASET_PATH[find_dataset(entries[0])] for i in range(total_uniq_entries): for t in range(turks_per_batch): - errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t]) + errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH) error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() - print 'Average test error: ' + str(error_means.mean()*10) +'%' - print 'Error variance : ' + str(error_variances.mean()*10) +'%' + decimal.getcontext().prec = 3 + print 'Total entries : ' + str(len(entries)) + print 'Turks per batch : ' + str(turks_per_batch) + print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%' + print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%' -def get_error(entry): +def find_dataset(entry): + file = parse_filename(entry[img_url]) + return file.split('_')[0] + +def get_error(entry,PATH): file = parse_filename(entry[img_url]) f = open(PATH+file,'r') labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.')