Mercurial > ift6266
view deep/amt/amt.py @ 437:479f2f518fc9
added Training with More Classes than Necessary
author | Guillaume Sicard <guitch21@gmail.com> |
---|---|
date | Mon, 03 May 2010 06:17:54 -0400 |
parents | 6478eef4f8aa |
children | 5777b5041ac9 |
line wrap: on
line source
# Script usage : python amt.py filename.csv type
#
# Scores Mechanical Turk (AMT) worker answers stored in a results CSV against
# the ground-truth label files found under DATASET_PATH and prints the average
# error rate.  `type` selects the class set: all, 36, upper, lower or digits.
#
# NOTE(review): in the transcript below, the '36' run was presumably produced
# while digit answers were still rejected for that class set (every digit was
# scored as an error), so its error figure is likely overstated.
"""
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv all
Testing on : all
Total entries : 300.0
Turks per batch : 3
Average test error : 45.3333333333%
Error variance : 7.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv 36
Testing on : 36
Total entries : 300.0
Turks per batch : 3
Average test error : 51.6666666667%
Error variance : 3.33333333333%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv upper
Testing on : upper
Total entries : 63.0
Turks per batch : 3
Average test error : 53.9682539683%
Error variance : 1.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv lower
Testing on : lower
Total entries : 135.0
Turks per batch : 3
Average test error : 37.037037037%
Error variance : 3.77777777778%
[rifaisal@timide ../fix/ift6266/deep/amt]$ python amt.py pnist.csv digits
Testing on : digits
Total entries : 102.0
Turks per batch : 3
Average test error : 50.9803921569%
Error variance : 1.33333333333%
"""
import csv
import numpy
import re
import decimal  # NOTE(review): not used by this script

# Root directory of the ground-truth label files, keyed by the dataset prefix
# of the image file name (see find_dataset). Paths must end with '/'.
DATASET_PATH = {
    'nist': '/data/lisa/data/ift6266h10/amt_data/nist/',
    'p07': '/data/lisa/data/ift6266h10/amt_data/p07/',
    'pnist': '/data/lisa/data/ift6266h10/amt_data/pnist/',
}

# Path of the results CSV; assigned from the command line in __main__ and used
# by test_error when no explicit csv_file is given.
CVSFILE = None

# CSV columns holding the answers of one row: 'Answer.c1' .. 'Answer.c10'.
answer_labels = ['Answer.c' + str(i + 1) for i in range(10)]
# CSV column holding the URL of the image shown to the worker.
img_url = 'Input.image_url'
# Number of consecutive CSV rows (workers) that answered the same image.
turks_per_batch = 3
# Number of answers / ground-truth labels per row (one per answer column).
image_per_batch = 10
TYPE = None


# ---------------------------------------------------------------------------
# Answer character -> class id tables.
# Class ids: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35, 'a'-'z' -> 36-61.
# ---------------------------------------------------------------------------

def digit_classes_assoc():
    """Return {'0': 0, ..., '9': 9}."""
    return dict((str(i), i) for i in range(0, 10))


def upper_classes_assoc():
    """Return {'A': 10, ..., 'Z': 35}."""
    return dict((chr(i + 55), i) for i in range(10, 36))


def lower_classes_assoc():
    """Return {'a': 36, ..., 'z': 61}."""
    return dict((chr(i + 61), i) for i in range(36, 62))


def all_classes_assoc():
    """Return the 62-class table: digits, upper-case and lower-case letters."""
    answer_assoc = digit_classes_assoc()
    answer_assoc.update(upper_classes_assoc())
    answer_assoc.update(lower_classes_assoc())
    return answer_assoc


def tsix_classes_assoc():
    """Return the case-insensitive 36-class table.

    Digits map to 0-9 and both 'A'-'Z' and 'a'-'z' map to 10-35.
    Digits must be included, because tsix_label_assoc keeps digit labels
    (0-9) as scorable labels.
    """
    answer_assoc = digit_classes_assoc()
    for i in range(10, 36):
        answer_assoc[chr(i + 55)] = i  # 'A'-'Z'
        answer_assoc[chr(i + 87)] = i  # 'a'-'z', folded onto the upper-case ids
    return answer_assoc


# ---------------------------------------------------------------------------
# Ground-truth label filters. Each one modifies the list in place and returns
# it; a label of -1 means "not part of this class set, do not score".
# ---------------------------------------------------------------------------

def _mask_outside(ulabel, low, high):
    """Replace, in place, every label outside [low, high] with -1; return ulabel."""
    for i, label in enumerate(ulabel):
        if label < low or label > high:
            ulabel[i] = -1
    return ulabel


def upper_label_assoc(ulabel):
    """Keep only upper-case labels (10-35); mask the others with -1."""
    return _mask_outside(ulabel, 10, 35)


def lower_label_assoc(ulabel):
    """Keep only lower-case labels (36-61); mask the others with -1."""
    return _mask_outside(ulabel, 36, 61)


def digit_label_assoc(ulabel):
    """Keep only digit labels (0-9); mask the others with -1."""
    return _mask_outside(ulabel, 0, 9)


def tsix_label_assoc(ulabel):
    """Fold lower-case labels (36-61) onto upper-case ids (10-35), in place."""
    for i, label in enumerate(ulabel):
        if 35 < label < 62:
            ulabel[i] = label - 26
    return ulabel


_CLASS_ASSOCS = {
    'all': all_classes_assoc,
    '36': tsix_classes_assoc,
    'lower': lower_classes_assoc,
    'upper': upper_classes_assoc,
    'digits': digit_classes_assoc,
}

_LABEL_FILTERS = {
    'all': lambda ulabel: ulabel,
    '36': tsix_label_assoc,
    'lower': lower_label_assoc,
    'upper': upper_label_assoc,
    'digits': digit_label_assoc,
}


def _lookup_type(table, type):
    """Return table[type]; raise ValueError if type is not a known class set."""
    if type not in table:
        raise ValueError('Inapropriate option for the type of classification :' + str(type))
    return table[type]


def classes_answer(type):
    """Return the answer-character -> class-id table for class set `type`.

    Raises:
        ValueError: `type` is not one of all, 36, lower, upper, digits.
    """
    return _lookup_type(_CLASS_ASSOCS, type)()


def test_error(assoc_type=TYPE, csv_file=None):
    """Score every row of the results CSV and print summary statistics.

    Rows are grouped into batches of `turks_per_batch` consecutive rows. The
    printed error is the total number of wrong answers divided by the total
    number of scored answers.

    Args:
        assoc_type: class set name (see classes_answer).
        csv_file: path of the results CSV; defaults to the module-level CVSFILE.

    Raises:
        ValueError: unknown assoc_type.
        Exception: the row count is not a multiple of turks_per_batch.
    """
    classes_answer(assoc_type)  # fail fast on an unknown class set, before any I/O
    if csv_file is None:
        csv_file = CVSFILE
    with open(csv_file) as f:
        entries = list(csv.DictReader(f, delimiter=','))
    if len(entries) % turks_per_batch != 0:
        raise Exception('Wrong number of entries or turks_per_batch')
    total_uniq_entries = len(entries) // turks_per_batch

    errors = numpy.zeros((len(entries),))
    num_examples = numpy.zeros((len(entries),))
    error_variances = numpy.zeros((total_uniq_entries,))
    for i in range(total_uniq_entries):
        start, stop = i * turks_per_batch, (i + 1) * turks_per_batch
        for row in range(start, stop):
            errors[row], num_examples[row] = get_error(entries[row], assoc_type)
        # Spread of the error counts among the workers of this batch.
        error_variances[i] = errors[start:stop].var()

    percentage_error = 100. * errors.sum() / num_examples.sum()
    print('Testing on : ' + str(assoc_type))
    print('Total entries : ' + str(num_examples.sum()))
    print('Turks per batch : ' + str(turks_per_batch))
    print('Average test error : ' + str(percentage_error) + '%')
    print('Error variance : ' + str(error_variances.mean() * image_per_batch) + '%')


# The name starts with "test", so pytest/nose would otherwise collect this
# I/O-bound script entry point as a unit test and call it with no arguments.
test_error.__test__ = False


def find_dataset(entry):
    """Return the dataset key: the image file name's prefix before the first '_'."""
    filename = parse_filename(entry[img_url])
    return filename.split('_')[0]


def get_error(entry, type):
    """Score one CSV row.

    Returns:
        (errors, scored): the number of wrong answers, and the number of answer
        columns whose ground-truth label is not masked (-1) for this class set.
        An answer that is not a known character for the class set is wrong.
    """
    answer_assoc = classes_answer(type)
    labels = get_labels(entry, type)
    n_errors = 0
    skipped = 0
    for i, column in enumerate(answer_labels):
        if labels[i] == -1:
            skipped += 1
            continue
        try:
            if answer_assoc[entry[column]] != labels[i]:
                n_errors += 1
        except KeyError:
            # Unknown / empty answer for this class set counts as a mistake.
            n_errors += 1
    return n_errors, image_per_batch - skipped


def _parse_labels(line):
    """Parse a label line such as '[ 1.  2. 36.]' into [1, 2, 36].

    All whitespace is removed, the leading '[' and trailing '.]' are dropped,
    and the rest is split on the '.' that terminates each value.
    """
    compact = re.sub(r"\s+", "", line)
    return [int(element) for element in compact[1:-2].split('.')]


def get_labels(entry, type):
    """Read the ground-truth labels of a row's image and filter them for `type`.

    The labels are read from the first line of the matching .txt file under
    DATASET_PATH.

    Raises:
        ValueError: `type` is not a known class set.
    """
    filename = parse_filename(entry[img_url])
    path = DATASET_PATH[find_dataset(entry)]
    with open(path + filename, 'r') as f:
        first_line = f.readline()
    unrestricted_labels = _parse_labels(first_line)
    return _lookup_type(_LABEL_FILTERS, type)(unrestricted_labels)


def parse_filename(string):
    """Map an image URL to its label file name: '.../name.png' -> 'name.txt'."""
    filename = string.split('/')[-1]
    return filename.split('.')[0] + '.txt'


if __name__ == '__main__':
    import sys
    if len(sys.argv) < 3:
        sys.exit('usage: python amt.py filename.csv {all,36,upper,lower,digits}')
    CVSFILE = sys.argv[1]
    test_error(sys.argv[2])