Mercurial > ift6266
comparison transformations/pipeline.py @ 144:c958941c1b9d
merge
author | XavierMuller |
---|---|
date | Tue, 23 Feb 2010 18:16:55 -0500 |
parents | 4981c729149c |
children | 6f3b866c0182 |
comparison
equal
deleted
inserted
replaced
143:f341a4efb44a | 144:c958941c1b9d |
---|---|
53 from slant import Slant | 53 from slant import Slant |
54 from Occlusion import Occlusion | 54 from Occlusion import Occlusion |
55 from add_background_image import AddBackground | 55 from add_background_image import AddBackground |
56 from affine_transform import AffineTransformation | 56 from affine_transform import AffineTransformation |
57 from ttf2jpg import ttf2jpg | 57 from ttf2jpg import ttf2jpg |
| | 58 from pycaptcha.Facade import generateCaptcha |
58 | 59 |
59 if DEBUG: | 60 if DEBUG: |
60 from visualizer import Visualizer | 61 from visualizer import Visualizer |
61 # Either put the visualizer as in the MODULES_INSTANCES list | 62 # Either put the visualizer as in the MODULES_INSTANCES list |
62 # after each module you want to visualize, or in the | 63 # after each module you want to visualize, or in the |
100 total = self.num_img | 101 total = self.num_img |
101 num_px = self.image_size[0] * self.image_size[1] | 102 num_px = self.image_size[0] * self.image_size[1] |
102 | 103 |
103 self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8) | 104 self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8) |
104 # +1 to store complexity | 105 # +1 to store complexity |
105 self.params = numpy.empty((total, self.num_params_stored+1)) | 106 self.params = numpy.empty((total, self.num_params_stored+len(self.modules))) |
106 self.res_labels = numpy.empty(total, dtype=numpy.int32) | 107 self.res_labels = numpy.empty(total, dtype=numpy.int32) |
107 | 108 |
108 def run(self, img_iterator, complexity_iterator): | 109 def run(self, img_iterator, complexity_iterator): |
109 img_size = self.image_size | 110 img_size = self.image_size |
110 | 111 |
111 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 | 112 should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 |
112 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 | 113 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 |
113 | 114 |
114 for img_no, (img, label) in enumerate(img_iterator): | 115 for img_no, (img, label) in enumerate(img_iterator): |
115 sys.stdout.flush() | 116 sys.stdout.flush() |
116 complexity = complexity_iterator.next() | 117 |
117 | |
118 global_idx = img_no | 118 global_idx = img_no |
119 | 119 |
120 img = img.reshape(img_size) | 120 img = img.reshape(img_size) |
121 | 121 |
122 param_idx = 1 | 122 param_idx = 0 |
123 # store complexity along with other params | 123 mod_idx = 0 |
124 self.params[global_idx, 0] = complexity | |
125 for mod in self.modules: | 124 for mod in self.modules: |
126 # This used to be done _per batch_, | 125 # This used to be done _per batch_, |
127 # ie. out of the "for img" loop | 126 # ie. out of the "for img" loop |
| | 127 complexity = complexity_iterator.next() |
| | 128 #better to do a complexity sampling for each transformations in order to have more variability |
| | 129 #otherwise a lot of images similar to the source are generated (i.e. when complexity is close to 0 (1/8 of the time)) |
| | 130 #we need to save the complexity of each transformations and the sum of these complexity is a good indicator of the overall |
| | 131 #complexity |
| | 132 self.params[global_idx, mod_idx] = complexity |
| | 133 mod_idx += 1 |
| | 134 |
128 p = mod.regenerate_parameters(complexity) | 135 p = mod.regenerate_parameters(complexity) |
129 self.params[global_idx, param_idx:param_idx+len(p)] = p | 136 self.params[global_idx, param_idx+len(self.modules):param_idx+len(p)+len(self.modules)] = p |
130 param_idx += len(p) | 137 param_idx += len(p) |
131 | 138 |
132 img = mod.transform_image(img) | 139 img = mod.transform_image(img) |
133 | 140 |
134 if should_hook_after_each: | 141 if should_hook_after_each: |
211 labels = ft.read(nist.train_labels) | 218 labels = ft.read(nist.train_labels) |
212 if prob_ocr: | 219 if prob_ocr: |
213 ocr_img = ft.read(nist.ocr_data) | 220 ocr_img = ft.read(nist.ocr_data) |
214 ocr_labels = ft.read(nist.ocr_labels) | 221 ocr_labels = ft.read(nist.ocr_labels) |
215 ttf = ttf2jpg() | 222 ttf = ttf2jpg() |
| | 223 L = [chr(ord('0')+x) for x in range(10)] + [chr(ord('A')+x) for x in range(26)] + [chr(ord('a')+x) for x in range(26)] |
216 | 224 |
217 for i in xrange(num_img): | 225 for i in xrange(num_img): |
218 r = numpy.random.rand() | 226 r = numpy.random.rand() |
219 if r <= prob_font: | 227 if r <= prob_font: |
220 yield ttf.generate_image() | 228 yield ttf.generate_image() |
221 elif r <= prob_font + prob_captcha: | 229 elif r <=prob_font + prob_captcha: |
222 pass #get captcha | 230 (arr, charac) = generateCaptcha(0,1) |
| | 231 yield arr.astype(numpy.float32)/255, L.index(charac[0]) |
223 elif r <= prob_font + prob_captcha + prob_ocr: | 232 elif r <= prob_font + prob_captcha + prob_ocr: |
224 j = numpy.random.randint(len(ocr_labels)) | 233 j = numpy.random.randint(len(ocr_labels)) |
225 yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j] | 234 yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j] |
226 else: | 235 else: |
227 j = numpy.random.randint(len(labels)) | 236 j = numpy.random.randint(len(labels)) |
257 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) | 266 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) |
258 -c, --ocr-file: path to filetensor (.ft) data file (OCR) | 267 -c, --ocr-file: path to filetensor (.ft) data file (OCR) |
259 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) | 268 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) |
260 -a, --prob-font: probability of using a raw font image | 269 -a, --prob-font: probability of using a raw font image |
261 -b, --prob-captcha: probability of using a captcha image | 270 -b, --prob-captcha: probability of using a captcha image |
262 -e, --prob-ocr: probability of using an ocr image | 271 -g, --prob-ocr: probability of using an ocr image |
263 ''' | 272 ''' |
264 | 273 |
265 # See run_pipeline.py | 274 # See run_pipeline.py |
266 def get_argv(): | 275 def get_argv(): |
267 with open(ARGS_FILE) as f: | 276 with open(ARGS_FILE) as f: |
289 prob_ocr = 0.0 | 298 prob_ocr = 0.0 |
290 stop_after = None | 299 stop_after = None |
291 reload_mode = False | 300 reload_mode = False |
292 | 301 |
293 try: | 302 try: |
294 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:e:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="]) | 303 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", |
| | 304 "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="]) |
295 except getopt.GetoptError, err: | 305 except getopt.GetoptError, err: |
296 # print help information and exit: | 306 # print help information and exit: |
297 print str(err) # will print something like "option -a not recognized" | 307 print str(err) # will print something like "option -a not recognized" |
298 usage() | 308 usage() |
299 pdb.gimp_quit(0) | 309 pdb.gimp_quit(0) |
326 ocrlabel_path = a | 336 ocrlabel_path = a |
327 elif o in ('-a', "--prob-font"): | 337 elif o in ('-a', "--prob-font"): |
328 prob_font = float(a) | 338 prob_font = float(a) |
329 elif o in ('-b', "--prob-captcha"): | 339 elif o in ('-b', "--prob-captcha"): |
330 prob_captcha = float(a) | 340 prob_captcha = float(a) |
331 elif o in ('-e', "--prob-ocr"): | 341 elif o in ('-g', "--prob-ocr"): |
332 prob_ocr = float(a) | 342 prob_ocr = float(a) |
333 else: | 343 else: |
334 assert False, "unhandled option" | 344 assert False, "unhandled option" |
335 | 345 |
336 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: | 346 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: |