Mercurial > ift6266
comparison data_generation/transformations/pipeline.py @ 167:1f5937e9e530
More moves - transformations into data_generation, added "deep" folder
author | Dumitru Erhan <dumitru.erhan@gmail.com> |
---|---|
date | Fri, 26 Feb 2010 14:15:38 -0500 |
parents | transformations/pipeline.py@6f3b866c0182 |
children |
comparison
equal
deleted
inserted
replaced
166:17ae5a1a4dd1 | 167:1f5937e9e530 |
---|---|
1 #!/usr/bin/python | |
2 # coding: utf-8 | |
3 | |
4 from __future__ import with_statement | |
5 | |
6 # This is intended to be run as a GIMP script | |
7 #from gimpfu import * | |
8 | |
9 import sys, os, getopt | |
10 import numpy | |
11 import filetensor as ft | |
12 import random | |
13 | |
# Debugging switches.
#
# To debug locally, also pass -s 100 (stop after ~100 images); otherwise all
# the memory needed for the full dataset is allocated up front, which can be
# slow and/or crash on an old, memory-starved machine.
DEBUG = False
# Debug under X (pylab.show()). It stays False even when DEBUG is on; flip it
# by hand to get on-screen plots.
DEBUG_X = False

# UNTESTED: directory of debug images to use instead of NIST (see the disabled
# DebugImages code further down), e.g. '/home/francois/Desktop/debug_images'.
# Leave it as None to use NIST.
DEBUG_IMAGES_PATH = None

# Directory where images are dumped to visualize results. It must already
# exist, otherwise the visualizer crashes.
DEBUG_OUTPUT_DIR = 'debug_out'

# Default input files (overridable with -f / -l / -c / -d).
DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft'
DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft'

# Command-line arguments are handed over through a temporary file (one
# argument per line) whose path is taken from this environment variable.
# A KeyError is raised at import time when the variable is not set.
ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
39 | |
# PARSE COMMAND LINE ARGUMENTS
def get_argv():
    """Return the arguments stored in ARGS_FILE, one per line.

    Trailing whitespace (including the newline) is stripped from each line.
    """
    with open(ARGS_FILE) as args_file:
        return [line.rstrip() for line in args_file]
45 | |
def usage():
    """Print the command-line help for run_pipeline.sh.

    Lists every option accepted by the getopt call, including -s and -r,
    which the help text previously omitted.
    """
    print('''
Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...]
    -m, --max-complexity: max complexity to generate for an image
    -z, --probability-zero: probability of using complexity=0 for an image
    -o, --output-file: full path to file to use for output of images
    -p, --params-output-file: path to file to output params to
    -x, --labels-output-file: path to file to output labels to
    -f, --data-file: path to filetensor (.ft) data file (NIST)
    -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
    -c, --ocr-file: path to filetensor (.ft) data file (OCR)
    -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
    -a, --prob-font: probability of using a raw font image
    -b, --prob-captcha: probability of using a captcha image
    -g, --prob-ocr: probability of using an ocr image
    -y, --seed: the job seed
    -s, --stop-after: only generate this many images (for debugging)
    -r, --reload: reload the output and params files and print their dimensions
''')
63 | |
# The command line is parsed at import time. Any -y/--seed is applied here,
# before the transformation modules below are imported and instantiated.
try:
    opts, args = getopt.getopt(
        get_argv(),
        "rm:z:o:p:x:s:f:l:c:d:a:b:g:y:",
        ["reload", "max-complexity=", "probability-zero=", "output-file=",
         "params-output-file=", "labels-output-file=", "stop-after=",
         "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=",
         "prob-font=", "prob-captcha=", "prob-ocr=", "seed="])
except getopt.GetoptError as err:
    # Show the error (e.g. "option -a not recognized") and the help, then quit.
    print(str(err))
    usage()
    pdb.gimp_quit(0)
    sys.exit(2)

for opt_name, opt_value in opts:
    if opt_name in ('-y', '--seed'):
        seed = int(opt_value)
        random.seed(seed)
        numpy.random.seed(seed)

if DEBUG_X:
    # Interactive plotting mode for debugging under X.
    import pylab
    pylab.ion()
82 | |
83 from PoivreSel import PoivreSel | |
84 from thick import Thick | |
85 from BruitGauss import BruitGauss | |
86 from DistorsionGauss import DistorsionGauss | |
87 from PermutPixel import PermutPixel | |
88 from gimp_script import GIMP1 | |
89 from Rature import Rature | |
90 from contrast import Contrast | |
91 from local_elastic_distortions import LocalElasticDistorter | |
92 from slant import Slant | |
93 from Occlusion import Occlusion | |
94 from add_background_image import AddBackground | |
95 from affine_transform import AffineTransformation | |
96 from ttf2jpg import ttf2jpg | |
97 from Facade import generateCaptcha | |
98 | |
if DEBUG:
    from visualizer import Visualizer
    # The visualizer can either be inserted into MODULE_INSTANCES right after
    # each module to inspect, or registered in the hook lists below -- not
    # both, which would be redundant.
    VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False)

# Transformation modules, in the order they are applied to every image.
MODULE_INSTANCES = [
    Slant(),
    Thick(),
    AffineTransformation(),
    LocalElasticDistorter(),
    GIMP1(),
    Rature(),
    Occlusion(),
    PermutPixel(),
    DistorsionGauss(),
    AddBackground(),
    PoivreSel(),
    BruitGauss(),
    Contrast(),
]

# AFTER_EACH_MODULE_HOOK: objects with an after_transform_callback(self, image)
#   method, called after every transform_image() call of every module.
# END_TRANSFORM_HOOK: objects with an end_transform_callback(self, final_image)
#   method, called once all modules have processed an image.
if DEBUG:
    AFTER_EACH_MODULE_HOOK = [VISUALIZER]
    END_TRANSFORM_HOOK = [VISUALIZER]
else:
    AFTER_EACH_MODULE_HOOK = []
    END_TRANSFORM_HOOK = []
120 | |
class Pipeline():
    """Push images through a fixed sequence of transformation modules and
    collect the results in preallocated arrays.

    Each module must provide ``regenerate_parameters(complexity)``, which
    returns the sequence of parameters to store, and ``transform_image(img)``.

    Row ``i`` of ``self.params`` describes the i-th image:
      * columns ``[0, len(modules))`` hold the complexity used for each module;
      * the remaining ``num_params_stored`` columns hold every module's
        parameters, concatenated in module order.
    """

    def __init__(self, modules, num_img, image_size=(32, 32)):
        """Allocate storage for ``num_img`` images of ``image_size`` pixels."""
        self.modules = modules
        self.num_img = num_img
        self.num_params_stored = 0
        self.image_size = image_size

        self.init_memory()

    def init_num_params_stored(self):
        """Count the parameters that the modules store per image.

        Each module's ``regenerate_parameters(0.0)`` is called once, only to
        measure the length of what it returns.
        """
        self.num_params_stored = 0
        for m in self.modules:
            self.num_params_stored += len(m.regenerate_parameters(0.0))

    def init_memory(self):
        """Allocate the image, parameter and label arrays.

        The arrays are zero-filled, so rows that the image iterator never
        reaches hold zeros instead of uninitialized memory.
        """
        self.init_num_params_stored()

        total = self.num_img
        num_px = self.image_size[0] * self.image_size[1]

        self.res_data = numpy.zeros((total, num_px), dtype=numpy.uint8)
        # One complexity column per module, followed by the module parameters.
        self.params = numpy.zeros(
            (total, len(self.modules) + self.num_params_stored))
        self.res_labels = numpy.zeros(total, dtype=numpy.int32)

    def run(self, img_iterator, complexity_iterator):
        """Transform every image of ``img_iterator`` and store the results.

        :param img_iterator: yields ``(img, label)`` pairs. ``img`` must be
            reshapeable to ``image_size``. Pixel values are presumably in
            [0, 1]: they are scaled by 255 and clipped to [0, 255] before
            being stored as uint8.
        :param complexity_iterator: infinite iterator. One value is drawn
            per module per image.
        """
        img_size = self.image_size
        num_px = img_size[0] * img_size[1]
        num_modules = len(self.modules)

        should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
        should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0

        for global_idx, (img, label) in enumerate(img_iterator):
            sys.stdout.flush()

            img = img.reshape(img_size)

            # Module parameters start after the per-module complexity columns.
            param_idx = num_modules
            for mod_idx, mod in enumerate(self.modules):
                # A complexity is sampled per transformation (not per batch or
                # per image) for more variability. The per-module complexities
                # are all stored; their sum indicates the overall complexity.
                complexity = next(complexity_iterator)
                self.params[global_idx, mod_idx] = complexity

                p = mod.regenerate_parameters(complexity)
                self.params[global_idx, param_idx:param_idx + len(p)] = p
                param_idx += len(p)

                img = mod.transform_image(img)

                if should_hook_after_each:
                    for hook in AFTER_EACH_MODULE_HOOK:
                        hook.after_transform_callback(img)

            # Clip first: casting out-of-range floats to uint8 wraps around.
            pixels = numpy.clip(img.reshape((num_px,)) * 255, 0, 255)
            self.res_data[global_idx] = pixels.astype(numpy.uint8)
            self.res_labels[global_idx] = label

            if should_hook_at_the_end:
                for hook in END_TRANSFORM_HOOK:
                    hook.end_transform_callback(img)

    def write_output(self, output_file_path, params_output_file_path, labels_output_file_path):
        """Write the images and labels as filetensors and the parameters with
        numpy.save.

        numpy.save appends '.npy' to ``params_output_file_path`` when the path
        does not already end with that extension.
        """
        with open(output_file_path, 'wb') as f:
            ft.write(f, self.res_data)

        numpy.save(params_output_file_path, self.params)

        with open(labels_output_file_path, 'wb') as f:
            ft.write(f, self.res_labels)
200 | |
201 | |
##############################################################################
# COMPLEXITY ITERATORS
# Pipeline.run() draws one complexity per transformation module per image.
# They must be infinite (should never raise StopIteration when next() is called).


def range_complexity_iterator(probability_zero, max_complexity):
    """Yield an endless stream of independent complexity values.

    Each value is 0.0 with probability ``probability_zero``, and otherwise
    uniform over [0.0, max_complexity).

    :raises AssertionError: on the first ``next()`` if max_complexity > 1.0.
    """
    assert max_complexity <= 1.0
    while True:
        # The zero-vs-uniform coin must be re-drawn for every value. Drawing it
        # once outside the loop would make a whole run all zeros or zero-free.
        if numpy.random.uniform(0.0, 1.0) < probability_zero:
            yield 0.0
        else:
            yield numpy.random.uniform(0.0, max_complexity)
217 | |
##############################################################################
# DATA ITERATORS
# They can be used to interleave different data sources etc.

# Disabled, untested debug loader (DebugImages and its iterator), kept for
# reference only. Before re-enabling it, note that:
#   * `Image` (PIL) is not imported anywhere in this file;
#   * debug_images_iterator() yields bare images, whereas Pipeline.run()
#     unpacks (image, label) pairs.
#
# def load_image(filepath):
#     _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0]
#     img = Image.open(filepath)
#     img = numpy.asarray(img)
#     if len(img.shape) > 2:
#         img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
#     return (img / 255.0).astype('float')
#
# class DebugImages():
#     def __init__(self, images_dir_path):
#         import glob, os.path
#         self.filelist = glob.glob(os.path.join(images_dir_path, "*.png"))
#
# def debug_images_iterator(debug_images):
#     for path in debug_images.filelist:
#         yield load_image(path)
242 | |
class NistData():
    """Open (binary) handles on the NIST and OCR filetensor files.

    Attributes: ``train_data``, ``train_labels``, ``ocr_data``, ``ocr_labels``
    (file objects positioned at the start of each file) and ``dim`` (element
    3 of the data file's filetensor header, which this file treats as the
    image dimensions). Call ``close()`` when the data is no longer needed.
    """

    def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path):
        self.train_data = open(nist_path, 'rb')
        self.train_labels = open(label_path, 'rb')
        self.dim = tuple(ft._read_header(self.train_data)[3])
        # Rewind after reading the header so a later ft.read() sees the whole
        # file (this used to be done by closing and reopening it).
        self.train_data.seek(0)
        self.ocr_data = open(ocr_path, 'rb')
        self.ocr_labels = open(ocrlabel_path, 'rb')

    def close(self):
        """Close every file handle opened by the constructor."""
        for handle in (self.train_data, self.train_labels,
                       self.ocr_data, self.ocr_labels):
            handle.close()
253 | |
# This iterator loads all the data into RAM.
def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img):
    """Yield ``num_img`` (image, label) pairs drawn at random from several sources.

    For every sample, ``r = numpy.random.rand()`` picks the source:
      * r < prob_font                            -> ttf2jpg font image
      * r < prob_font + prob_captcha             -> generateCaptcha image
      * r < prob_font + prob_captcha + prob_ocr  -> random OCR image
      * otherwise                                -> random NIST image
    Strict comparisons ensure that a source with probability 0 is never
    chosen. Captcha, OCR and NIST images are converted to float32 and divided
    by 255. OCR files are only read when prob_ocr is non-zero.
    """
    import itertools
    import string

    img = ft.read(nist.train_data)
    labels = ft.read(nist.train_labels)
    if prob_ocr:
        ocr_img = ft.read(nist.ocr_data)
        ocr_labels = ft.read(nist.ocr_labels)
    # Built unconditionally, as before: the constructor's side effects (e.g.
    # on the random state) are not visible here, so keep them identical.
    ttf = ttf2jpg()
    # Class index of a character: '0'-'9' -> 0-9, 'A'-'Z' -> 10-35,
    # 'a'-'z' -> 36-61.
    classes = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)

    font_max = prob_font
    captcha_max = font_max + prob_captcha
    ocr_max = captcha_max + prob_ocr

    for _ in itertools.repeat(None, num_img):
        r = numpy.random.rand()
        if r < font_max:
            yield ttf.generate_image()
        elif r < captcha_max:
            (arr, charac) = generateCaptcha(0, 1)
            yield arr.astype(numpy.float32) / 255, classes.index(charac[0])
        elif r < ocr_max:
            j = numpy.random.randint(len(ocr_labels))
            yield ocr_img[j].astype(numpy.float32) / 255, ocr_labels[j]
        else:
            j = numpy.random.randint(len(labels))
            yield img[j].astype(numpy.float32) / 255, labels[j]
277 | |
278 | |
# Mostly for debugging, for the moment, just to see if we can
# reload the images and parameters.
def reload(output_file_path, params_output_file_path):
    """Print the dimensions of a previous run's images file, then the
    dimensions and contents of its parameter array.

    ``output_file_path`` is the images filetensor. Only its header is read,
    and the file is closed afterwards. ``params_output_file_path`` is the
    path that was given to numpy.save, which appends '.npy' when the path
    lacks that extension. The suffixed file is therefore used when the
    exact path does not exist.
    """
    with open(output_file_path, 'rb') as images_ft:
        images_ft_dim = tuple(ft._read_header(images_ft)[3])

    print("Images dimensions:  %s" % (images_ft_dim,))

    if (not params_output_file_path.endswith('.npy')
            and not os.path.exists(params_output_file_path)):
        params_output_file_path += '.npy'
    params = numpy.load(params_output_file_path)

    print("Params dimensions:  %s" % (params.shape,))
    print(params)
291 | |
292 | |
##############################################################################
# MAIN


# Might be called locally or through dbidispatch. In all cases it should be
# passed to the GIMP executable to be able to use GIMP filters.
def _main():
    """Run the pipeline as configured by the module-level ``opts``.

    With -r/--reload, the previously written images and parameter files are
    only inspected. Otherwise, images are drawn from the NIST/OCR/font/captcha
    sources, pushed through MODULE_INSTANCES, and the images, parameters and
    labels are written to the three (mandatory) output files. Invalid
    arguments print the usage, call pdb.gimp_quit(0) and exit with status 2.
    """
    max_complexity = 0.5  # default
    probability_zero = 0.1  # default
    output_file_path = None
    params_output_file_path = None
    labels_output_file_path = None
    nist_path = DEFAULT_NIST_PATH
    label_path = DEFAULT_LABEL_PATH
    ocr_path = DEFAULT_OCR_PATH
    ocrlabel_path = DEFAULT_OCRLABEL_PATH
    prob_font = 0.0
    prob_captcha = 0.0
    prob_ocr = 0.0
    stop_after = None
    reload_mode = False

    def fail(message):
        # Same exit sequence as the other command-line errors in this file.
        # Unlike `assert`, this check is not stripped under python -O.
        print(message)
        usage()
        pdb.gimp_quit(0)
        sys.exit(2)

    for o, a in opts:
        if o in ('-m', '--max-complexity'):
            max_complexity = float(a)
            if not 0.0 <= max_complexity <= 1.0:
                fail("--max-complexity must be between 0 and 1.")
        elif o in ('-r', '--reload'):
            reload_mode = True
        elif o in ("-z", "--probability-zero"):
            probability_zero = float(a)
            if not 0.0 <= probability_zero <= 1.0:
                fail("--probability-zero must be between 0 and 1.")
        elif o in ("-o", "--output-file"):
            output_file_path = a
        elif o in ('-p', "--params-output-file"):
            params_output_file_path = a
        elif o in ('-x', "--labels-output-file"):
            labels_output_file_path = a
        elif o in ('-s', "--stop-after"):
            stop_after = int(a)
        elif o in ('-f', "--data-file"):
            nist_path = a
        elif o in ('-l', "--label-file"):
            label_path = a
        elif o in ('-c', "--ocr-file"):
            ocr_path = a
        elif o in ('-d', "--ocrlabel-file"):
            ocrlabel_path = a
        elif o in ('-a', "--prob-font"):
            prob_font = float(a)
        elif o in ('-b', "--prob-captcha"):
            prob_captcha = float(a)
        elif o in ('-g', "--prob-ocr"):
            prob_ocr = float(a)
        elif o in ('-y', "--seed"):
            # Already applied at module level, before the modules were built.
            pass
        else:
            assert False, "unhandled option"

    if (output_file_path is None or params_output_file_path is None
            or labels_output_file_path is None):
        fail("Must specify the three output files.")

    if reload_mode:
        reload(output_file_path, params_output_file_path)
        return

    if DEBUG_IMAGES_PATH:
        # The DebugImages loader is disabled in this file. Fail explicitly here
        # rather than with a NameError on `pl` further down.
        raise NotImplementedError(
            "DEBUG_IMAGES_PATH is set but debug image loading is disabled")

    nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path)
    num_img = 819200  # 800 Mb file
    if stop_after is not None:
        # `is not None` so that an explicit -s 0 is honoured.
        num_img = stop_after
    pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32, 32))
    img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img)

    cpx_it = range_complexity_iterator(probability_zero, max_complexity)
    pl.run(img_it, cpx_it)
    pl.write_output(output_file_path, params_output_file_path, labels_output_file_path)
383 | |
# Script entry point (this file is intended to be run as a GIMP script).
import traceback

try:
    _main()
except Exception:
    # Every other exit path in this file calls pdb.gimp_quit(); do the same
    # for unexpected errors, presumably so GIMP is not left running. The
    # traceback is printed first in case gimp_quit() ends the process.
    traceback.print_exc()
    pdb.gimp_quit(0)
    raise

if DEBUG_X:
    # Back to blocking mode so the debug figures stay on screen.
    pylab.ioff()
    pylab.show()

pdb.gimp_quit(0)
391 |