comparison transformations/pipeline.py @ 144:c958941c1b9d

merge
author XavierMuller
date Tue, 23 Feb 2010 18:16:55 -0500
parents 4981c729149c
children 6f3b866c0182
comparison
equal deleted inserted replaced
143:f341a4efb44a 144:c958941c1b9d
53 from slant import Slant 53 from slant import Slant
54 from Occlusion import Occlusion 54 from Occlusion import Occlusion
55 from add_background_image import AddBackground 55 from add_background_image import AddBackground
56 from affine_transform import AffineTransformation 56 from affine_transform import AffineTransformation
57 from ttf2jpg import ttf2jpg 57 from ttf2jpg import ttf2jpg
58 from pycaptcha.Facade import generateCaptcha
58 59
59 if DEBUG: 60 if DEBUG:
60 from visualizer import Visualizer 61 from visualizer import Visualizer
61 # Either put the visualizer as in the MODULES_INSTANCES list 62 # Either put the visualizer as in the MODULES_INSTANCES list
62 # after each module you want to visualize, or in the 63 # after each module you want to visualize, or in the
100 total = self.num_img 101 total = self.num_img
101 num_px = self.image_size[0] * self.image_size[1] 102 num_px = self.image_size[0] * self.image_size[1]
102 103
103 self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8) 104 self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8)
104 # +1 to store complexity 105 # +1 to store complexity
105 self.params = numpy.empty((total, self.num_params_stored+1)) 106 self.params = numpy.empty((total, self.num_params_stored+len(self.modules)))
106 self.res_labels = numpy.empty(total, dtype=numpy.int32) 107 self.res_labels = numpy.empty(total, dtype=numpy.int32)
107 108
108 def run(self, img_iterator, complexity_iterator): 109 def run(self, img_iterator, complexity_iterator):
109 img_size = self.image_size 110 img_size = self.image_size
110 111
111 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 112 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
112 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 113 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0
113 114
114 for img_no, (img, label) in enumerate(img_iterator): 115 for img_no, (img, label) in enumerate(img_iterator):
115 sys.stdout.flush() 116 sys.stdout.flush()
116 complexity = complexity_iterator.next() 117
117
118 global_idx = img_no 118 global_idx = img_no
119 119
120 img = img.reshape(img_size) 120 img = img.reshape(img_size)
121 121
122 param_idx = 1 122 param_idx = 0
123 # store complexity along with other params 123 mod_idx = 0
124 self.params[global_idx, 0] = complexity
125 for mod in self.modules: 124 for mod in self.modules:
126 # This used to be done _per batch_, 125 # This used to be done _per batch_,
127 # ie. out of the "for img" loop 126 # ie. out of the "for img" loop
127 complexity = complexity_iterator.next()
128 # It is better to sample a complexity for each transformation in order to have more variability;
129 # otherwise a lot of images similar to the source are generated (i.e. when the complexity is close to 0, about 1/8 of the time).
130 # We need to save the complexity of each transformation; the sum of these complexities is a good indicator of the overall
131 # complexity.
132 self.params[global_idx, mod_idx] = complexity
133 mod_idx += 1
134
128 p = mod.regenerate_parameters(complexity) 135 p = mod.regenerate_parameters(complexity)
129 self.params[global_idx, param_idx:param_idx+len(p)] = p 136 self.params[global_idx, param_idx+len(self.modules):param_idx+len(p)+len(self.modules)] = p
130 param_idx += len(p) 137 param_idx += len(p)
131 138
132 img = mod.transform_image(img) 139 img = mod.transform_image(img)
133 140
134 if should_hook_after_each: 141 if should_hook_after_each:
211 labels = ft.read(nist.train_labels) 218 labels = ft.read(nist.train_labels)
212 if prob_ocr: 219 if prob_ocr:
213 ocr_img = ft.read(nist.ocr_data) 220 ocr_img = ft.read(nist.ocr_data)
214 ocr_labels = ft.read(nist.ocr_labels) 221 ocr_labels = ft.read(nist.ocr_labels)
215 ttf = ttf2jpg() 222 ttf = ttf2jpg()
223 L = [chr(ord('0')+x) for x in range(10)] + [chr(ord('A')+x) for x in range(26)] + [chr(ord('a')+x) for x in range(26)]
216 224
217 for i in xrange(num_img): 225 for i in xrange(num_img):
218 r = numpy.random.rand() 226 r = numpy.random.rand()
219 if r <= prob_font: 227 if r <= prob_font:
220 yield ttf.generate_image() 228 yield ttf.generate_image()
221 elif r <= prob_font + prob_captcha: 229 elif r <=prob_font + prob_captcha:
222 pass #get captcha 230 (arr, charac) = generateCaptcha(0,1)
231 yield arr.astype(numpy.float32)/255, L.index(charac[0])
223 elif r <= prob_font + prob_captcha + prob_ocr: 232 elif r <= prob_font + prob_captcha + prob_ocr:
224 j = numpy.random.randint(len(ocr_labels)) 233 j = numpy.random.randint(len(ocr_labels))
225 yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j] 234 yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j]
226 else: 235 else:
227 j = numpy.random.randint(len(labels)) 236 j = numpy.random.randint(len(labels))
257 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) 266 -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
258 -c, --ocr-file: path to filetensor (.ft) data file (OCR) 267 -c, --ocr-file: path to filetensor (.ft) data file (OCR)
259 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) 268 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
260 -a, --prob-font: probability of using a raw font image 269 -a, --prob-font: probability of using a raw font image
261 -b, --prob-captcha: probability of using a captcha image 270 -b, --prob-captcha: probability of using a captcha image
262 -e, --prob-ocr: probability of using an ocr image 271 -g, --prob-ocr: probability of using an ocr image
263 ''' 272 '''
264 273
265 # See run_pipeline.py 274 # See run_pipeline.py
266 def get_argv(): 275 def get_argv():
267 with open(ARGS_FILE) as f: 276 with open(ARGS_FILE) as f:
289 prob_ocr = 0.0 298 prob_ocr = 0.0
290 stop_after = None 299 stop_after = None
291 reload_mode = False 300 reload_mode = False
292 301
293 try: 302 try:
294 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:e:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="]) 303 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=",
304 "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="])
295 except getopt.GetoptError, err: 305 except getopt.GetoptError, err:
296 # print help information and exit: 306 # print help information and exit:
297 print str(err) # will print something like "option -a not recognized" 307 print str(err) # will print something like "option -a not recognized"
298 usage() 308 usage()
299 pdb.gimp_quit(0) 309 pdb.gimp_quit(0)
326 ocrlabel_path = a 336 ocrlabel_path = a
327 elif o in ('-a', "--prob-font"): 337 elif o in ('-a', "--prob-font"):
328 prob_font = float(a) 338 prob_font = float(a)
329 elif o in ('-b', "--prob-captcha"): 339 elif o in ('-b', "--prob-captcha"):
330 prob_captcha = float(a) 340 prob_captcha = float(a)
331 elif o in ('-e', "--prob-ocr"): 341 elif o in ('-g', "--prob-ocr"):
332 prob_ocr = float(a) 342 prob_ocr = float(a)
333 else: 343 else:
334 assert False, "unhandled option" 344 assert False, "unhandled option"
335 345
336 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: 346 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: