comparison data_generation/pipeline/pipeline.py @ 254:dd2df78fcf47

added option to pipeline and gimp_script to produce NIST-friendly data
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Wed, 17 Mar 2010 13:57:15 -0400
parents 992ca8035a4d
children 6d16a2bf142b
comparison
equal deleted inserted replaced
246:2024368a8d3d 254:dd2df78fcf47
8 8
9 import sys, os, getopt 9 import sys, os, getopt
10 import numpy 10 import numpy
11 import ift6266.data_generation.transformations.filetensor as ft 11 import ift6266.data_generation.transformations.filetensor as ft
12 import random 12 import random
13 import copy
13 14
14 # To debug locally, also call with -s 100 (to stop after ~100) 15 # To debug locally, also call with -s 100 (to stop after ~100)
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash 16 # (otherwise we allocate all needed memory, might be loonnng and/or crash
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) 17 # if, lucky like me, you have an age-old laptop creaking from everywhere)
17 DEBUG = False 18 DEBUG = False
57 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) 58 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
58 -a, --prob-font: probability of using a raw font image 59 -a, --prob-font: probability of using a raw font image
59 -b, --prob-captcha: probability of using a captcha image 60 -b, --prob-captcha: probability of using a captcha image
60 -g, --prob-ocr: probability of using an ocr image 61 -g, --prob-ocr: probability of using an ocr image
61 -y, --seed: the job seed 62 -y, --seed: the job seed
63 -t, --type: [default: 0:full transformations], 1:Nist-friendly transformations
62 ''' 64 '''
63 65
64 try: 66 try:
65 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:y:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", 67 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:y:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=",
66 "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr=", "seed="]) 68 "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr=", "seed="])
73 75
74 for o, a in opts: 76 for o, a in opts:
75 if o in ('-y','--seed'): 77 if o in ('-y','--seed'):
76 random.seed(int(a)) 78 random.seed(int(a))
77 numpy.random.seed(int(a)) 79 numpy.random.seed(int(a))
80
81 for o, a in opts:
82 if o in ('-t','--type'):
83 type_pipeline = int(a)
84 else:
85 type_pipeline = 0
78 86
79 if DEBUG_X: 87 if DEBUG_X:
80 import pylab 88 import pylab
81 pylab.ion() 89 pylab.ion()
82 90
102 # after each module you want to visualize, or in the 110 # after each module you want to visualize, or in the
103 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant) 111 # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant)
104 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False) 112 VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False)
105 113
106 ###---------------------order of transformation module 114 ###---------------------order of transformation module
107 MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()] 115 if type_pipeline == 0:
116 MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()]
117 stop_idx = 0
118 if type_pipeline == 1:
119 MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(False),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()]
120 stop_idx = 5
121 #we disable the transformations corresponding to MODULE_INSTANCES[stop_idx:] but we still need to apply them on dummy images
122 #in order to be sure to have the same random generator state as with the default pipeline.
123 #This is not optimal (we do more computation than necessary) but it is a quick hack to produce results similar to the previous generation
124
125
108 126
109 # These should have a "after_transform_callback(self, image)" method 127 # These should have a "after_transform_callback(self, image)" method
110 # (called after each call to transform_image in a module) 128 # (called after each call to transform_image in a module)
111 AFTER_EACH_MODULE_HOOK = [] 129 AFTER_EACH_MODULE_HOOK = []
112 if DEBUG: 130 if DEBUG:
153 171
154 for img_no, (img, label) in enumerate(img_iterator): 172 for img_no, (img, label) in enumerate(img_iterator):
155 sys.stdout.flush() 173 sys.stdout.flush()
156 174
157 global_idx = img_no 175 global_idx = img_no
158 176
159 img = img.reshape(img_size) 177 img = img.reshape(img_size)
160 178
161 param_idx = 0 179 param_idx = 0
162 mod_idx = 0 180 mod_idx = 0
163 for mod in self.modules: 181 for mod in self.modules:
172 mod_idx += 1 190 mod_idx += 1
173 191
174 p = mod.regenerate_parameters(complexity) 192 p = mod.regenerate_parameters(complexity)
175 self.params[global_idx, param_idx+len(self.modules):param_idx+len(p)+len(self.modules)] = p 193 self.params[global_idx, param_idx+len(self.modules):param_idx+len(p)+len(self.modules)] = p
176 param_idx += len(p) 194 param_idx += len(p)
177 195
178 img = mod.transform_image(img) 196 if not(stop_idx) or stop_idx > mod_idx:
197 img = mod.transform_image(img)
198 else:
199 tmp = mod.transform_image(copy.copy(img))
200 #this is done to keep the global random generator state the same as in the default pipeline
201 #the transformation is applied to a copy rather than to the original image, in case the module transforms in place
179 202
180 if should_hook_after_each: 203 if should_hook_after_each:
181 for hook in AFTER_EACH_MODULE_HOOK: 204 for hook in AFTER_EACH_MODULE_HOOK:
182 hook.after_transform_callback(img) 205 hook.after_transform_callback(img)
183 206
347 prob_captcha = float(a) 370 prob_captcha = float(a)
348 elif o in ('-g', "--prob-ocr"): 371 elif o in ('-g', "--prob-ocr"):
349 prob_ocr = float(a) 372 prob_ocr = float(a)
350 elif o in ('-y', "--seed"): 373 elif o in ('-y', "--seed"):
351 pass 374 pass
375 elif o in ('-t', "--type"):
376 type_pipeline = int(a)
352 else: 377 else:
353 assert False, "unhandled option" 378 assert False, "unhandled option"
354 379
355 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: 380 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None:
356 print "Must specify the three output files." 381 print "Must specify the three output files."