# HG changeset patch # User humel # Date 1272565732 14400 # Node ID 6478eef4f8aa006032eebaba17eb48629adeddcd # Parent 4f69d915d1421fe6fbfb7fd1d367a1a8edfb3e04 Added support for calculating the test error over different set of classes (lower,upper,digits,all,36) diff -r 4f69d915d142 -r 6478eef4f8aa deep/amt/amt.py --- a/deep/amt/amt.py Thu Apr 29 13:18:15 2010 -0400 +++ b/deep/amt/amt.py Thu Apr 29 14:28:52 2010 -0400 @@ -1,4 +1,36 @@ -# Script usage : python amt.py filname.cvs +# Script usage : python amt.py filname.cvs type +""" +[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all +Testing on : all +Total entries : 300.0 +Turks per batch : 3 +Average test error : 45.3333333333% +Error variance : 7.77777777778% +[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36 +Testing on : 36 +Total entries : 300.0 +Turks per batch : 3 +Average test error : 51.6666666667% +Error variance : 3.33333333333% +[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper +Testing on : upper +Total entries : 63.0 +Turks per batch : 3 +Average test error : 53.9682539683% +Error variance : 1.77777777778% +[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower +Testing on : lower +Total entries : 135.0 +Turks per batch : 3 +Average test error : 37.037037037% +Error variance : 3.77777777778% +[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits +Testing on : digits +Total entries : 102.0 +Turks per batch : 3 +Average test error : 50.9803921569% +Error variance : 1.33333333333% +""" import csv,numpy,re,decimal @@ -12,8 +44,9 @@ img_url = 'Input.image_url' turks_per_batch = 3 image_per_batch = 10 +TYPE = None -def setup_association(): +def all_classes_assoc(): answer_assoc = {} for i in range(0,10): answer_assoc[str(i)]=i @@ -23,9 +56,74 @@ answer_assoc[chr(i+61)]=i return answer_assoc -answer_assoc = setup_association() +def upper_classes_assoc(): + answer_assoc = {} + for i in range(10,36): + answer_assoc[chr(i+55)]=i + return answer_assoc + +def lower_classes_assoc(): + answer_assoc = {} + for i in range(36,62): + answer_assoc[chr(i+61)]=i + return answer_assoc + +def digit_classes_assoc(): + answer_assoc = {} + for i in range(0,10): + answer_assoc[str(i)]=i + return answer_assoc + +def tsix_classes_assoc(): + answer_assoc = {} + for i in range(10,36): + answer_assoc[chr(i+55)]=i + answer_assoc[chr(i+87)]=i + return answer_assoc + +def upper_label_assoc(ulabel): + for i in range(len(ulabel)): + if ulabel[i] < 10 or ulabel[i] > 35 : + ulabel[i] = -1 + return ulabel -def test_error(): +def lower_label_assoc(ulabel): + for i in range(len(ulabel)): + if ulabel[i] < 36 or ulabel[i] > 61 : + ulabel[i] = -1 + return ulabel + +def tsix_label_assoc(ulabel): + for i in range(len(ulabel)): + if ulabel[i] > 35 and ulabel[i] < 62 : + ulabel[i] = ulabel[i] - 26 + return ulabel + +def digit_label_assoc(ulabel): + for i in range(len(ulabel)): + if ulabel[i] < 0 or ulabel[i] > 9 : + ulabel[i] = -1 + + return ulabel + +def classes_answer(type): + if type == 'all': + return all_classes_assoc() + elif type == '36': + return tsix_classes_assoc() + elif type == 'lower': + return lower_classes_assoc() + elif type == 'upper': + return upper_classes_assoc() + elif type == 'digits': + return digit_classes_assoc() + else: + raise ('Inapropriate option for the type of classification :' + type) + + +def test_error(assoc_type=TYPE): + answer_assoc = classes_answer(assoc_type) + turks = [] reader = csv.DictReader(open(CVSFILE), delimiter=',') entries = [ turk for turk in reader ] @@ -37,20 +135,21 @@ total_uniq_entries = len(entries) / turks_per_batch errors = numpy.zeros((len(entries),)) + num_examples = numpy.zeros((len(entries),)) error_means = numpy.zeros((total_uniq_entries,)) error_variances = numpy.zeros((total_uniq_entries,)) - PATH = DATASET_PATH[find_dataset(entries[0])] for i in range(total_uniq_entries): for t in range(turks_per_batch): - errors[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],PATH) - error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() + errors[i*turks_per_batch+t],num_examples[i*turks_per_batch+t] = get_error(entries[i*turks_per_batch+t],assoc_type) + #error_means[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].mean() error_variances[i] = errors[i*turks_per_batch:(i+1)*turks_per_batch].var() - - decimal.getcontext().prec = 3 - print 'Total entries : ' + str(len(entries)) + + percentage_error = 100. * errors.sum() / num_examples.sum() + print 'Testing on : ' + str(assoc_type) + print 'Total entries : ' + str(num_examples.sum()) print 'Turks per batch : ' + str(turks_per_batch) - print 'Average test error : ' + str(error_means.mean()*image_per_batch) +'%' + print 'Average test error : ' + str(percentage_error) +'%' print 'Error variance : ' + str(error_variances.mean()*image_per_batch) +'%' @@ -58,24 +157,43 @@ file = parse_filename(entry[img_url]) return file.split('_')[0] -def get_error(entry,PATH): - file = parse_filename(entry[img_url]) - f = open(PATH+file,'r') - labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') - f.close() +def get_error(entry, type): + answer_assoc = classes_answer(type) + labels = get_labels(entry,type) test_error = 0 + cnt = 0 for i in range(len(answer_labels)): - answer = entry[answer_labels[i]] - if len(answer) != 0: + if labels[i] == -1: + cnt+=1 + else: + answer = entry[answer_labels[i]] try: - if answer_assoc[answer] != int(labels[i]): + if answer_assoc[answer] != labels[i]: test_error+=1 except: test_error+=1 - else: - test_error+=1 - return test_error + return test_error,image_per_batch-cnt + +def get_labels(entry,type): + file = parse_filename(entry[img_url]) + path = DATASET_PATH[find_dataset(entry)] + f = open(path+file,'r') + str_labels = re.sub("\s+", "",f.readline()).strip()[1:-2].split('.') + unrestricted_labels = [ int(element) for element in str_labels ] + if type == 'all': + return unrestricted_labels + elif type == '36': + return tsix_label_assoc(unrestricted_labels) + elif type == 'lower': + return lower_label_assoc(unrestricted_labels) + elif type == 'upper': + return upper_label_assoc(unrestricted_labels) + elif type == 'digits': + return digit_label_assoc(unrestricted_labels) + else: + raise ('Inapropriate option for the type of classification :' + str(type)) + def parse_filename(string): filename = string.split('/')[-1] @@ -84,4 +202,4 @@ if __name__ =='__main__': import sys CVSFILE = sys.argv[1] - test_error() + test_error(sys.argv[2])