Mercurial > ift6266
diff transformations/pipeline.py @ 67:5e448ea129b3
Ajouté la source (optionnelle) de données OCR Autriche avec une probabilité passée en argument
author | boulanni <nicolas_boulanger@hotmail.com> |
---|---|
date | Tue, 09 Feb 2010 21:33:57 -0500 |
parents | 1afa95285b9c |
children | 95c491bb5662 |
line wrap: on
line diff
--- a/transformations/pipeline.py Tue Feb 09 18:45:35 2010 -0500 +++ b/transformations/pipeline.py Tue Feb 09 21:33:57 2010 -0500 @@ -14,7 +14,7 @@ # To debug locally, also call with -s 100 (to stop after ~100) # (otherwise we allocate all needed memory, might be loonnng and/or crash # if, lucky like me, you have an age-old laptop creaking from everywhere) -DEBUG = True +DEBUG = False DEBUG_X = False if DEBUG: DEBUG_X = False # Debug under X (pylab.show()) @@ -33,6 +33,8 @@ DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' +DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft' +DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft' ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] if DEBUG_X: @@ -188,28 +190,34 @@ ''' class NistData(): - def __init__(self, nist_path, label_path): + def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path): self.train_data = open(nist_path, 'rb') self.train_labels = open(label_path, 'rb') self.dim = tuple(ft._read_header(self.train_data)[3]) # in order to seek to the beginning of the file self.train_data.close() self.train_data = open(nist_path, 'rb') - + self.ocr_data = open(ocr_path, 'rb') + self.ocr_labels = open(ocrlabel_path, 'rb') -def nist_supp_iterator(nist, prob_font, prob_captcha, num_img): - subtensor = slice(0, num_img) - img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255 - labels = ft.read(nist.train_labels, subtensor) +def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img): + img = ft.read(nist.train_data).astype(numpy.float32)/255 + labels = ft.read(nist.train_labels) + if prob_ocr: + ocr_img = ft.read(nist.ocr_data).astype(numpy.float32)/255 + ocr_labels = ft.read(nist.ocr_labels) for i in xrange(num_img): r = numpy.random.rand() - if r<= prob_font: + if r <= prob_font: pass #get font - elif r<= prob_font + prob_captcha: + elif r <= prob_font + prob_captcha: pass #get captcha + elif r <= prob_font + prob_captcha + prob_ocr: + j = numpy.random.randint(len(ocr_labels)) + yield ocr_img[j], ocr_labels[j] else: - j = numpy.random.randint(num_img) + j = numpy.random.randint(len(labels)) yield img[j], labels[j] @@ -237,11 +245,14 @@ -z, --probability-zero: probability of using complexity=0 for an image -o, --output-file: full path to file to use for output of images -p, --params-output-file: path to file to output params to - -r, --labels-output-file: path to file to output labels to + -x, --labels-output-file: path to file to output labels to -f, --data-file: path to filetensor (.ft) data file (NIST) -l, --label-file: path to filetensor (.ft) labels file (NIST labels) + -c, --ocr-file: path to filetensor (.ft) data file (OCR) + -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) -a, --prob-font: probability of using a raw font image -b, --prob-captcha: probability of using a captcha image + -e, --prob-ocr: probability of using an ocr image ''' # See run_pipeline.py @@ -254,6 +265,9 @@ # passed to the GIMP executable to be able to use GIMP filters. # Ex: def _main(): + #global DEFAULT_NIST_PATH, DEFAULT_LABEL_PATH, DEFAULT_OCR_PATH, DEFAULT_OCRLABEL_PATH + #global getopt, get_argv + max_complexity = 0.5 # default probability_zero = 0.1 # default output_file_path = None @@ -261,17 +275,21 @@ labels_output_file_path = None nist_path = DEFAULT_NIST_PATH label_path = DEFAULT_LABEL_PATH + ocr_path = DEFAULT_OCR_PATH + ocrlabel_path = DEFAULT_OCRLABEL_PATH prob_font = 0.0 prob_captcha = 0.0 + prob_ocr = 0.0 stop_after = None reload_mode = False try: - opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="]) + opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:e:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="]) except getopt.GetoptError, err: # print help information and exit: print str(err) # will print something like "option -a not recognized" usage() + pdb.gimp_quit(0) sys.exit(2) for o, a in opts: @@ -287,7 +305,7 @@ output_file_path = a elif o in ('-p', "--params-output-file"): params_output_file_path = a - elif o in ('-r', "--labels-output-file"): + elif o in ('-x', "--labels-output-file"): labels_output_file_path = a elif o in ('-s', "--stop-after"): stop_after = int(a) @@ -295,17 +313,23 @@ nist_path = a elif o in ('-l', "--label-file"): label_path = a + elif o in ('-c', "--ocr-file"): + ocr_path = a + elif o in ('-d', "--ocrlabel-file"): + ocrlabel_path = a elif o in ('-a', "--prob-font"): prob_font = float(a) elif o in ('-b', "--prob-captcha"): prob_captcha = float(a) + elif o in ('-e', "--prob-ocr"): + prob_ocr = float(a) else: assert False, "unhandled option" if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: print "Must specify the three output files." - print usage() + pdb.gimp_quit(0) sys.exit(2) if reload_mode: @@ -320,12 +344,12 @@ img_it = debug_images_iterator(debug_images) ''' else: - nist = NistData(nist_path, label_path) - num_img = nist.dim[0] + nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path) + num_img = 819200 # 800 Mb file if stop_after: num_img = stop_after pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) - img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img) + img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img) cpx_it = range_complexity_iterator(probability_zero, max_complexity) pl.run(img_it, cpx_it)