diff transformations/pipeline.py @ 67:5e448ea129b3

Ajouté la source (optionnelle) de données OCR Autriche avec une probabilité passée en argument
author boulanni <nicolas_boulanger@hotmail.com>
date Tue, 09 Feb 2010 21:33:57 -0500
parents 1afa95285b9c
children 95c491bb5662
line wrap: on
line diff
--- a/transformations/pipeline.py	Tue Feb 09 18:45:35 2010 -0500
+++ b/transformations/pipeline.py	Tue Feb 09 21:33:57 2010 -0500
@@ -14,7 +14,7 @@
 # To debug locally, also call with -s 100 (to stop after ~100)
 # (otherwise we allocate all needed memory, might be loonnng and/or crash
 # if, lucky like me, you have an age-old laptop creaking from everywhere)
-DEBUG = True
+DEBUG = False
 DEBUG_X = False
 if DEBUG:
     DEBUG_X = False # Debug under X (pylab.show())
@@ -33,6 +33,8 @@
 
 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
+DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft'
+DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft'
 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
 
 if DEBUG_X:
@@ -188,28 +190,34 @@
 '''
 
 class NistData():
-    def __init__(self, nist_path, label_path):
+    def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path):
         self.train_data = open(nist_path, 'rb')
         self.train_labels = open(label_path, 'rb')
         self.dim = tuple(ft._read_header(self.train_data)[3])
         # in order to seek to the beginning of the file
         self.train_data.close()
         self.train_data = open(nist_path, 'rb')
-
+        self.ocr_data = open(ocr_path, 'rb')
+        self.ocr_labels = open(ocrlabel_path, 'rb')
 
-def nist_supp_iterator(nist, prob_font, prob_captcha, num_img):
-    subtensor = slice(0, num_img)
-    img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255
-    labels = ft.read(nist.train_labels, subtensor)
+def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img):
+    img = ft.read(nist.train_data).astype(numpy.float32)/255
+    labels = ft.read(nist.train_labels)
+    if prob_ocr:
+        ocr_img = ft.read(nist.ocr_data).astype(numpy.float32)/255
+        ocr_labels = ft.read(nist.ocr_labels)
 
     for i in xrange(num_img):
         r = numpy.random.rand()
-        if r<= prob_font:
+        if r <= prob_font:
             pass #get font
-        elif r<= prob_font + prob_captcha:
+        elif r <= prob_font + prob_captcha:
             pass #get captcha
+        elif r <= prob_font + prob_captcha + prob_ocr:
+            j = numpy.random.randint(len(ocr_labels))
+            yield ocr_img[j], ocr_labels[j]
         else:
-            j = numpy.random.randint(num_img)
+            j = numpy.random.randint(len(labels))
             yield img[j], labels[j]
 
 
@@ -237,11 +245,14 @@
     -z, --probability-zero: probability of using complexity=0 for an image
     -o, --output-file: full path to file to use for output of images
     -p, --params-output-file: path to file to output params to
-    -r, --labels-output-file: path to file to output labels to
+    -x, --labels-output-file: path to file to output labels to
     -f, --data-file: path to filetensor (.ft) data file (NIST)
     -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
+    -c, --ocr-file: path to filetensor (.ft) data file (OCR)
+    -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
     -a, --prob-font: probability of using a raw font image
     -b, --prob-captcha: probability of using a captcha image
+    -e, --prob-ocr: probability of using an ocr image
     '''
 
 # See run_pipeline.py
@@ -254,6 +265,9 @@
 # passed to the GIMP executable to be able to use GIMP filters.
 # Ex: 
 def _main():
+    #global DEFAULT_NIST_PATH, DEFAULT_LABEL_PATH, DEFAULT_OCR_PATH, DEFAULT_OCRLABEL_PATH
+    #global getopt, get_argv
+
     max_complexity = 0.5 # default
     probability_zero = 0.1 # default
     output_file_path = None
@@ -261,17 +275,21 @@
     labels_output_file_path = None
     nist_path = DEFAULT_NIST_PATH
     label_path = DEFAULT_LABEL_PATH
+    ocr_path = DEFAULT_OCR_PATH
+    ocrlabel_path = DEFAULT_OCRLABEL_PATH
     prob_font = 0.0
     prob_captcha = 0.0
+    prob_ocr = 0.0
     stop_after = None
     reload_mode = False
 
     try:
-        opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="])
+        opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:e:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="])
     except getopt.GetoptError, err:
         # print help information and exit:
         print str(err) # will print something like "option -a not recognized"
         usage()
+        pdb.gimp_quit(0)
         sys.exit(2)
 
     for o, a in opts:
@@ -287,7 +305,7 @@
             output_file_path = a
         elif o in ('-p', "--params-output-file"):
             params_output_file_path = a
-        elif o in ('-r', "--labels-output-file"):
+        elif o in ('-x', "--labels-output-file"):
             labels_output_file_path = a
         elif o in ('-s', "--stop-after"):
             stop_after = int(a)
@@ -295,17 +313,23 @@
             nist_path = a
         elif o in ('-l', "--label-file"):
             label_path = a
+        elif o in ('-c', "--ocr-file"):
+            ocr_path = a
+        elif o in ('-d', "--ocrlabel-file"):
+            ocrlabel_path = a
         elif o in ('-a', "--prob-font"):
             prob_font = float(a)
         elif o in ('-b', "--prob-captcha"):
             prob_captcha = float(a)
+        elif o in ('-e', "--prob-ocr"):
+            prob_ocr = float(a)
         else:
             assert False, "unhandled option"
 
     if output_file_path == None or params_output_file_path == None or labels_output_file_path == None:
         print "Must specify the three output files."
-        print
         usage()
+        pdb.gimp_quit(0)
         sys.exit(2)
 
     if reload_mode:
@@ -320,12 +344,12 @@
             img_it = debug_images_iterator(debug_images)
             '''
         else:
-            nist = NistData(nist_path, label_path)
-            num_img = nist.dim[0]
+            nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path)
+            num_img = 819200 # 800 Mb file
             if stop_after:
                 num_img = stop_after
             pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
-            img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img)
+            img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img)
 
         cpx_it = range_complexity_iterator(probability_zero, max_complexity)
         pl.run(img_it, cpx_it)