comparison data_generation/transformations/pipeline.py @ 167:1f5937e9e530

More moves - transformations into data_generation, added "deep" folder
author Dumitru Erhan <dumitru.erhan@gmail.com>
date Fri, 26 Feb 2010 14:15:38 -0500
parents transformations/pipeline.py@6f3b866c0182
children
comparison
equal deleted inserted replaced
166:17ae5a1a4dd1 167:1f5937e9e530
1 #!/usr/bin/python
2 # coding: utf-8
3
4 from __future__ import with_statement
5
6 # This is intended to be run as a GIMP script
7 #from gimpfu import *
8
9 import sys, os, getopt
10 import numpy
11 import filetensor as ft
12 import random
13
# To debug locally, also call with -s 100 (to stop after ~100)
# (otherwise we allocate all needed memory, might be loonnng and/or crash
# if, lucky like me, you have an age-old laptop creaking from everywhere)
DEBUG = False
# Debug under X (pylab.show()). It can only be switched on together with
# DEBUG: replace the trailing False by True to enable it.
DEBUG_X = DEBUG and False

# UNTESTED YET: directory of debug images to use instead of NIST, to avoid
# loading NIST if you don't have it handy (use with debug_images_iterator(),
# see main()), e.g. '/home/francois/Desktop/debug_images'.
# Only meant to be set when DEBUG is on; to use NIST, leave as None.
DEBUG_IMAGES_PATH = None

# Directory where images are dumped to visualize results
# (create it beforehand, otherwise it'll crash).
DEBUG_OUTPUT_DIR = 'debug_out'

DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft'
DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft'
# File holding the pipeline's command-line arguments, one per line (read by
# get_argv()). The environment variable must be set, otherwise importing this
# script raises KeyError.
ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
39
40 # PARSE COMMAND LINE ARGUMENTS
def get_argv():
    """Return the pipeline arguments stored in ARGS_FILE, one per line.

    Trailing whitespace (including the newline) is stripped from each line.
    """
    with open(ARGS_FILE) as args_file:
        return [line.rstrip() for line in args_file]
45
def usage():
    """Print the command-line help for the pipeline.

    Lists every option accepted by the getopt call that parses ARGS_FILE,
    including the debugging options -s and -r. The three output options
    (-o, -p, -x) are mandatory: the pipeline refuses to run without them.
    """
    # Single parenthesized argument: identical output under the Python 2
    # print statement and the Python 3 print function.
    print('''
Usage: run_pipeline.sh [-m ...] [-z ...] -o ... -p ... -x ... [other options]
    -m, --max-complexity: max complexity to generate for an image
    -z, --probability-zero: probability of using complexity=0 for an image
    -o, --output-file: full path to file to use for output of images
    -p, --params-output-file: path to file to output params to
    -x, --labels-output-file: path to file to output labels to
    -f, --data-file: path to filetensor (.ft) data file (NIST)
    -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
    -c, --ocr-file: path to filetensor (.ft) data file (OCR)
    -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
    -a, --prob-font: probability of using a raw font image
    -b, --prob-captcha: probability of using a captcha image
    -g, --prob-ocr: probability of using an ocr image
    -y, --seed: the job seed
    -s, --stop-after: only generate this many images (for debugging)
    -r, --reload: print the dimensions of existing output files (for debugging)
The -o, -p and -x options are required.
''')
63
# Parse the command line at import time. The arguments come from the file
# named by ARGS_FILE (one argument per line, see get_argv()), not sys.argv.
try:
    opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:y:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=",
                                                                            "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr=", "seed="])
except getopt.GetoptError, err:
    # print help information and exit:
    print str(err) # will print something like "option -q not recognized"
    usage()
    # NOTE(review): `pdb` is not imported in this file; presumably it is
    # GIMP's procedural database provided by the GIMP Python environment.
    # If gimp_quit() ends the process, sys.exit(2) below is never reached
    # and the error exit status is lost -- confirm.
    pdb.gimp_quit(0)
    sys.exit(2)

# Seed both RNGs now, before the transformation modules below are imported
# and instantiated -- presumably so that any random draws they make are
# reproducible for a given job seed.
for o, a in opts:
    if o in ('-y','--seed'):
        random.seed(int(a))
        numpy.random.seed(int(a))

if DEBUG_X:
    # Interactive plotting mode; turned off again at the end of the script.
    import pylab
    pylab.ion()
82
83 from PoivreSel import PoivreSel
84 from thick import Thick
85 from BruitGauss import BruitGauss
86 from DistorsionGauss import DistorsionGauss
87 from PermutPixel import PermutPixel
88 from gimp_script import GIMP1
89 from Rature import Rature
90 from contrast import Contrast
91 from local_elastic_distortions import LocalElasticDistorter
92 from slant import Slant
93 from Occlusion import Occlusion
94 from add_background_image import AddBackground
95 from affine_transform import AffineTransformation
96 from ttf2jpg import ttf2jpg
97 from Facade import generateCaptcha
98
if DEBUG:
    from visualizer import Visualizer
    # Either put the visualizer in the MODULE_INSTANCES list after each
    # module you want to visualize, or in the AFTER_EACH_MODULE_HOOK list
    # (but not both, it's redundant).
    VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False)

### Transformation modules, in the order in which they are applied.
MODULE_INSTANCES = [
    Slant(),
    Thick(),
    AffineTransformation(),
    LocalElasticDistorter(),
    GIMP1(),
    Rature(),
    Occlusion(),
    PermutPixel(),
    DistorsionGauss(),
    AddBackground(),
    PoivreSel(),
    BruitGauss(),
    Contrast(),
]

# Hooks listed here must have an "after_transform_callback(self, image)"
# method (called after each call to transform_image in a module).
AFTER_EACH_MODULE_HOOK = [VISUALIZER] if DEBUG else []

# Hooks listed here must have an "end_transform_callback(self, final_image)"
# method (called after all modules have been called).
END_TRANSFORM_HOOK = [VISUALIZER] if DEBUG else []
120
class Pipeline(object):
    """Apply an ordered chain of transformation modules to a stream of images.

    For every image, each module gets its own complexity (drawn from the
    complexity iterator), regenerates its parameters from it and transforms
    the image. Results are accumulated in preallocated buffers and written
    out with write_output().

    Buffers (one row per image):
      res_data:   (num_img, H*W) uint8 -- transformed image scaled by 255
      params:     (num_img, len(modules) + num_params_stored) float -- the
                  first len(modules) columns hold the complexity used for
                  each module, followed by the concatenation of the values
                  returned by each module's regenerate_parameters()
      res_labels: (num_img,) int32 labels
    """

    def __init__(self, modules, num_img, image_size=(32,32),
                 after_each_hooks=None, end_hooks=None):
        """
        modules: ordered list of objects exposing
            regenerate_parameters(complexity) and transform_image(img).
        num_img: number of images to store; buffers are preallocated for
            exactly this many.
        image_size: (height, width) each input image is reshaped to.
        after_each_hooks / end_hooks: optional lists of hook objects. When
            None (default), the module-level AFTER_EACH_MODULE_HOOK /
            END_TRANSFORM_HOOK lists are used at run time.
        """
        self.modules = modules
        self.num_img = num_img
        self.num_params_stored = 0
        self.image_size = image_size
        self.after_each_hooks = after_each_hooks
        self.end_hooks = end_hooks

        self.init_memory()

    def init_num_params_stored(self):
        """Count how many parameter values the modules store per image.

        regenerate_parameters(0.0) is called once on every module only to
        learn the length of the list it returns.
        """
        self.num_params_stored = sum(
            len(m.regenerate_parameters(0.0)) for m in self.modules)

    def init_memory(self):
        """Preallocate the image, params and label buffers for num_img images."""
        self.init_num_params_stored()

        total = self.num_img
        num_px = self.image_size[0] * self.image_size[1]

        self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8)
        # One complexity column per module (not a single one per image),
        # followed by all the parameters the modules store.
        self.params = numpy.empty(
            (total, len(self.modules) + self.num_params_stored))
        self.res_labels = numpy.empty(total, dtype=numpy.int32)

    def run(self, img_iterator, complexity_iterator):
        """Transform every image and store its pixels, params and label.

        img_iterator: yields (image, label) pairs. It must yield at most
            num_img of them (more would index past the buffers); if it
            yields fewer, the remaining rows stay uninitialized.
        complexity_iterator: object whose next() returns a complexity; it is
            called once per module for every image.
        """
        img_size = self.image_size
        num_px = img_size[0] * img_size[1]
        num_modules = len(self.modules)

        after_each_hooks = self.after_each_hooks
        if after_each_hooks is None:
            after_each_hooks = AFTER_EACH_MODULE_HOOK
        end_hooks = self.end_hooks
        if end_hooks is None:
            end_hooks = END_TRANSFORM_HOOK

        for img_no, (img, label) in enumerate(img_iterator):
            sys.stdout.flush()

            img = img.reshape(img_size)

            # Module parameters are stored after the per-module complexities.
            param_idx = num_modules
            for mod_idx, mod in enumerate(self.modules):
                # Sample a fresh complexity for every transformation (this
                # used to be done once per batch): it gives more variability
                # than one complexity per image, and the per-module values
                # are saved since their sum is a good indicator of the
                # overall complexity.
                complexity = complexity_iterator.next()
                self.params[img_no, mod_idx] = complexity

                p = mod.regenerate_parameters(complexity)
                self.params[img_no, param_idx:param_idx + len(p)] = p
                param_idx += len(p)

                img = mod.transform_image(img)

                for hook in after_each_hooks:
                    hook.after_transform_callback(img)

            # Images are expected to be floats in [0, 1] (nist_supp_iterator
            # divides the sources by 255). Clip before scaling so values a
            # transformation pushed out of range saturate instead of wrapping
            # around when cast to uint8; in-range values are unchanged.
            self.res_data[img_no] = \
                numpy.clip(img.reshape((num_px,)), 0.0, 1.0) * 255
            self.res_labels[img_no] = label

            for hook in end_hooks:
                hook.end_transform_callback(img)

    def write_output(self, output_file_path, params_output_file_path, labels_output_file_path):
        """Write images and labels as filetensors, and params with numpy.save.

        Note: numpy.save appends '.npy' to params_output_file_path when the
        path does not already end with it.
        """
        with open(output_file_path, 'wb') as f:
            ft.write(f, self.res_data)

        numpy.save(params_output_file_path, self.params)

        with open(labels_output_file_path, 'wb') as f:
            ft.write(f, self.res_labels)
200
201
202 ##############################################################################
203 # COMPLEXITY ITERATORS
# They're called once per module for every img (see Pipeline.run), to get the complexity to use for that transformation
205 # they must be infinite (should never throw StopIteration when calling next())
206
# probability of generating 0 complexity, otherwise
# uniform over 0.0-max_complexity
def range_complexity_iterator(probability_zero, max_complexity):
    """Yield an infinite stream of complexities.

    Each value is independently 0.0 with probability probability_zero,
    otherwise drawn uniformly from [0.0, max_complexity).

    Note: as this is a generator, the max_complexity check only runs on the
    first next() call.
    """
    assert max_complexity <= 1.0
    while True:
        # The zero/non-zero decision must be drawn anew for every value;
        # drawing it once made the whole stream either all zeros or never
        # zero.
        if numpy.random.uniform(0.0, 1.0) < probability_zero:
            yield 0.0
        else:
            yield numpy.random.uniform(0.0, max_complexity)
217
218 ##############################################################################
219 # DATA ITERATORS
220 # They can be used to interleave different data sources etc.
221
# NOTE(review): everything below is disabled -- it is kept inside a bare
# string literal, so none of it runs (and _main() cannot use it). If it is
# re-enabled, `Image` (presumably PIL) must be imported: it is not imported
# anywhere in this file.
'''
# Following code (DebugImages and iterator) is untested

def load_image(filepath):
    _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0]
    img = Image.open(filepath)
    img = numpy.asarray(img)
    if len(img.shape) > 2:
        img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
    return (img / 255.0).astype('float')

class DebugImages():
    def __init__(self, images_dir_path):
        import glob, os.path
        self.filelist = glob.glob(os.path.join(images_dir_path, "*.png"))

def debug_images_iterator(debug_images):
    for path in debug_images.filelist:
        yield load_image(path)
'''
242
class NistData(object):
    """Open binary file handles on the NIST and OCR data/label filetensors.

    The handles are left open, positioned at the start of each file, for
    nist_supp_iterator() to read. Call close() to release them.

    Attributes:
      train_data, train_labels: NIST data / labels file objects
      ocr_data, ocr_labels: OCR data / labels file objects
      dim: tuple taken from entry 3 of the NIST data filetensor header
           (presumably the tensor dimensions -- confirm against filetensor)
    """

    def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path):
        self.train_data = open(nist_path, 'rb')
        self.train_labels = open(label_path, 'rb')
        # Read the header only to learn the dimensions...
        self.dim = tuple(ft._read_header(self.train_data)[3])
        # ...then rewind so the data can later be read from the beginning
        # (replaces the previous close-and-reopen of the same file).
        self.train_data.seek(0)
        self.ocr_data = open(ocr_path, 'rb')
        self.ocr_labels = open(ocrlabel_path, 'rb')

    def close(self):
        """Close every file handle opened by the constructor."""
        for f in (self.train_data, self.train_labels,
                  self.ocr_data, self.ocr_labels):
            f.close()
253
# This iterator loads all the data into RAM.
def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img):
    """Yield num_img (image, label) samples mixed from several sources.

    For each sample, r ~ U[0, 1) selects the source:
      r < prob_font                            -> ttf2jpg font image
      r < prob_font + prob_captcha             -> generated captcha
      r < prob_font + prob_captcha + prob_ocr  -> random OCR image
      otherwise                                -> random NIST image
    Captcha, OCR and NIST images are returned as float32 divided by 255.
    Captcha labels are the index of the first character in 0-9, A-Z, a-z.

    nist: a NistData; its NIST (and, if prob_ocr is non-zero, OCR) handles
        are read with ft.read() before sampling starts.
    """
    import string

    img = ft.read(nist.train_data)
    labels = ft.read(nist.train_labels)
    if prob_ocr:
        ocr_img = ft.read(nist.ocr_data)
        ocr_labels = ft.read(nist.ocr_labels)
    ttf = ttf2jpg()
    # Label alphabet: digits, then upper-case, then lower-case letters.
    label_chars = list(string.digits + string.ascii_uppercase +
                       string.ascii_lowercase)

    font_threshold = prob_font
    captcha_threshold = font_threshold + prob_captcha
    ocr_threshold = captcha_threshold + prob_ocr

    for _ in xrange(num_img):
        r = numpy.random.rand()
        # Strict '<': r is in [0, 1), so P(r < p) == p exactly and a source
        # with probability 0 is never selected ('<=' picked it when r == 0.0).
        if r < font_threshold:
            # NOTE(review): assumes generate_image() returns an
            # (image, label) pair like the other branches -- confirm.
            yield ttf.generate_image()
        elif r < captcha_threshold:
            (arr, charac) = generateCaptcha(0,1)
            yield arr.astype(numpy.float32)/255, label_chars.index(charac[0])
        elif r < ocr_threshold:
            j = numpy.random.randint(len(ocr_labels))
            yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j]
        else:
            j = numpy.random.randint(len(labels))
            yield img[j].astype(numpy.float32)/255, labels[j]
277
278
279 # Mostly for debugging, for the moment, just to see if we can
280 # reload the images and parameters.
281 def reload(output_file_path, params_output_file_path):
282 images_ft = open(output_file_path, 'rb')
283 images_ft_dim = tuple(ft._read_header(images_ft)[3])
284
285 print "Images dimensions: ", images_ft_dim
286
287 params = numpy.load(params_output_file_path)
288
289 print "Params dimensions: ", params.shape
290 print params
291
292
293 ##############################################################################
294 # MAIN
295
296
297 # Might be called locally or through dbidispatch. In all cases it should be
298 # passed to the GIMP executable to be able to use GIMP filters.
299 # Ex:
300 def _main():
301 #global DEFAULT_NIST_PATH, DEFAULT_LABEL_PATH, DEFAULT_OCR_PATH, DEFAULT_OCRLABEL_PATH
302 #global getopt, get_argv
303
304 max_complexity = 0.5 # default
305 probability_zero = 0.1 # default
306 output_file_path = None
307 params_output_file_path = None
308 labels_output_file_path = None
309 nist_path = DEFAULT_NIST_PATH
310 label_path = DEFAULT_LABEL_PATH
311 ocr_path = DEFAULT_OCR_PATH
312 ocrlabel_path = DEFAULT_OCRLABEL_PATH
313 prob_font = 0.0
314 prob_captcha = 0.0
315 prob_ocr = 0.0
316 stop_after = None
317 reload_mode = False
318
319 for o, a in opts:
320 if o in ('-m', '--max-complexity'):
321 max_complexity = float(a)
322 assert max_complexity >= 0.0 and max_complexity <= 1.0
323 elif o in ('-r', '--reload'):
324 reload_mode = True
325 elif o in ("-z", "--probability-zero"):
326 probability_zero = float(a)
327 assert probability_zero >= 0.0 and probability_zero <= 1.0
328 elif o in ("-o", "--output-file"):
329 output_file_path = a
330 elif o in ('-p', "--params-output-file"):
331 params_output_file_path = a
332 elif o in ('-x', "--labels-output-file"):
333 labels_output_file_path = a
334 elif o in ('-s', "--stop-after"):
335 stop_after = int(a)
336 elif o in ('-f', "--data-file"):
337 nist_path = a
338 elif o in ('-l', "--label-file"):
339 label_path = a
340 elif o in ('-c', "--ocr-file"):
341 ocr_path = a
342 elif o in ('-d', "--ocrlabel-file"):
343 ocrlabel_path = a
344 elif o in ('-a', "--prob-font"):
345 prob_font = float(a)
346 elif o in ('-b', "--prob-captcha"):
347 prob_captcha = float(a)
348 elif o in ('-g', "--prob-ocr"):
349 prob_ocr = float(a)
350 elif o in ('-y', "--seed"):
351 pass
352 else:
353 assert False, "unhandled option"
354
355 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None:
356 print "Must specify the three output files."
357 usage()
358 pdb.gimp_quit(0)
359 sys.exit(2)
360
361 if reload_mode:
362 reload(output_file_path, params_output_file_path)
363 else:
364 if DEBUG_IMAGES_PATH:
365 '''
366 # This code is yet untested
367 debug_images = DebugImages(DEBUG_IMAGES_PATH)
368 num_img = len(debug_images.filelist)
369 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
370 img_it = debug_images_iterator(debug_images)
371 '''
372 else:
373 nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path)
374 num_img = 819200 # 800 Mb file
375 if stop_after:
376 num_img = stop_after
377 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
378 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img)
379
380 cpx_it = range_complexity_iterator(probability_zero, max_complexity)
381 pl.run(img_it, cpx_it)
382 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path)
383
# Script entry point: runs unconditionally when the file is executed (no
# __name__ guard), since it is meant to be run as a GIMP script.
_main()

if DEBUG_X:
    # Leave the interactive mode enabled at import time and show the figures.
    pylab.ioff()
    pylab.show()

# NOTE(review): `pdb` is not imported in this file; presumably it is GIMP's
# procedural database provided by the GIMP Python environment -- confirm.
pdb.gimp_quit(0)
391