Mercurial > ift6266
comparison transformations/pipeline.py @ 67:5e448ea129b3
Ajouté la source (optionnelle) de données OCR Autriche avec une probabilité passée en argument
author | boulanni <nicolas_boulanger@hotmail.com> |
---|---|
date | Tue, 09 Feb 2010 21:33:57 -0500 |
parents | 1afa95285b9c |
children | 95c491bb5662 |
comparison
legend: equal | deleted | inserted | replaced
66:bf83682c827b | 67:5e448ea129b3 |
---|---|
12 import random | 12 import random |
13 | 13 |
14 # To debug locally, also call with -s 100 (to stop after ~100) | 14 # To debug locally, also call with -s 100 (to stop after ~100) |
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash | 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash |
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) | 16 # if, lucky like me, you have an age-old laptop creaking from everywhere) |
17 DEBUG = True | 17 DEBUG = False |
18 DEBUG_X = False | 18 DEBUG_X = False |
19 if DEBUG: | 19 if DEBUG: |
20 DEBUG_X = False # Debug under X (pylab.show()) | 20 DEBUG_X = False # Debug under X (pylab.show()) |
21 | 21 |
22 DEBUG_IMAGES_PATH = None | 22 DEBUG_IMAGES_PATH = None |
31 # (create it, otherwise it'll crash) | 31 # (create it, otherwise it'll crash) |
32 DEBUG_OUTPUT_DIR = 'debug_out' | 32 DEBUG_OUTPUT_DIR = 'debug_out' |
33 | 33 |
34 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' | 34 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' |
35 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' | 35 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' |
36 DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft' | |
37 DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft' | |
36 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] | 38 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] |
37 | 39 |
38 if DEBUG_X: | 40 if DEBUG_X: |
39 import pylab | 41 import pylab |
40 pylab.ion() | 42 pylab.ion() |
186 for path in debug_images.filelist: | 188 for path in debug_images.filelist: |
187 yield load_image(path) | 189 yield load_image(path) |
188 ''' | 190 ''' |
189 | 191 |
190 class NistData(): | 192 class NistData(): |
191 def __init__(self, nist_path, label_path): | 193 def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path): |
192 self.train_data = open(nist_path, 'rb') | 194 self.train_data = open(nist_path, 'rb') |
193 self.train_labels = open(label_path, 'rb') | 195 self.train_labels = open(label_path, 'rb') |
194 self.dim = tuple(ft._read_header(self.train_data)[3]) | 196 self.dim = tuple(ft._read_header(self.train_data)[3]) |
195 # in order to seek to the beginning of the file | 197 # in order to seek to the beginning of the file |
196 self.train_data.close() | 198 self.train_data.close() |
197 self.train_data = open(nist_path, 'rb') | 199 self.train_data = open(nist_path, 'rb') |
198 | 200 self.ocr_data = open(ocr_path, 'rb') |
199 | 201 self.ocr_labels = open(ocrlabel_path, 'rb') |
200 def nist_supp_iterator(nist, prob_font, prob_captcha, num_img): | 202 |
201 subtensor = slice(0, num_img) | 203 def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img): |
202 img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255 | 204 img = ft.read(nist.train_data).astype(numpy.float32)/255 |
203 labels = ft.read(nist.train_labels, subtensor) | 205 labels = ft.read(nist.train_labels) |
206 if prob_ocr: | |
207 ocr_img = ft.read(nist.ocr_data).astype(numpy.float32)/255 | |
208 ocr_labels = ft.read(nist.ocr_labels) | |
204 | 209 |
205 for i in xrange(num_img): | 210 for i in xrange(num_img): |
206 r = numpy.random.rand() | 211 r = numpy.random.rand() |
207 if r<= prob_font: | 212 if r <= prob_font: |
208 pass #get font | 213 pass #get font |
209 elif r<= prob_font + prob_captcha: | 214 elif r <= prob_font + prob_captcha: |
210 pass #get captcha | 215 pass #get captcha |
216 elif r <= prob_font + prob_captcha + prob_ocr: | |
217 j = numpy.random.randint(len(ocr_labels)) | |
218 yield ocr_img[j], ocr_labels[j] | |
211 else: | 219 else: |
212 j = numpy.random.randint(num_img) | 220 j = numpy.random.randint(len(labels)) |
213 yield img[j], labels[j] | 221 yield img[j], labels[j] |
214 | 222 |
215 | 223 |
216 # Mostly for debugging, for the moment, just to see if we can | 224 # Mostly for debugging, for the moment, just to see if we can |
217 # reload the images and parameters. | 225 # reload the images and parameters. |
235 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] | 243 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] |
236 -m, --max-complexity: max complexity to generate for an image | 244 -m, --max-complexity: max complexity to generate for an image |
237 -z, --probability-zero: probability of using complexity=0 for an image | 245 -z, --probability-zero: probability of using complexity=0 for an image |
238 -o, --output-file: full path to file to use for output of images | 246 -o, --output-file: full path to file to use for output of images |
239 -p, --params-output-file: path to file to output params to | 247 -p, --params-output-file: path to file to output params to |
240 -r, --labels-output-file: path to file to output labels to | 248 -x, --labels-output-file: path to file to output labels to |
241 -f, --data-file: path to filetensor (.ft) data file (NIST) | 249 -f, --data-file: path to filetensor (.ft) data file (NIST) |
242 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) | 250 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) |
251 -c, --ocr-file: path to filetensor (.ft) data file (OCR) | |
252 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) | |
243 -a, --prob-font: probability of using a raw font image | 253 -a, --prob-font: probability of using a raw font image |
244 -b, --prob-captcha: probability of using a captcha image | 254 -b, --prob-captcha: probability of using a captcha image |
255 -e, --prob-ocr: probability of using an ocr image | |
245 ''' | 256 ''' |
246 | 257 |
247 # See run_pipeline.py | 258 # See run_pipeline.py |
248 def get_argv(): | 259 def get_argv(): |
249 with open(ARGS_FILE) as f: | 260 with open(ARGS_FILE) as f: |
252 | 263 |
253 # Might be called locally or through dbidispatch. In all cases it should be | 264 # Might be called locally or through dbidispatch. In all cases it should be |
254 # passed to the GIMP executable to be able to use GIMP filters. | 265 # passed to the GIMP executable to be able to use GIMP filters. |
255 # Ex: | 266 # Ex: |
256 def _main(): | 267 def _main(): |
268 #global DEFAULT_NIST_PATH, DEFAULT_LABEL_PATH, DEFAULT_OCR_PATH, DEFAULT_OCRLABEL_PATH | |
269 #global getopt, get_argv | |
270 | |
257 max_complexity = 0.5 # default | 271 max_complexity = 0.5 # default |
258 probability_zero = 0.1 # default | 272 probability_zero = 0.1 # default |
259 output_file_path = None | 273 output_file_path = None |
260 params_output_file_path = None | 274 params_output_file_path = None |
261 labels_output_file_path = None | 275 labels_output_file_path = None |
262 nist_path = DEFAULT_NIST_PATH | 276 nist_path = DEFAULT_NIST_PATH |
263 label_path = DEFAULT_LABEL_PATH | 277 label_path = DEFAULT_LABEL_PATH |
278 ocr_path = DEFAULT_OCR_PATH | |
279 ocrlabel_path = DEFAULT_OCRLABEL_PATH | |
264 prob_font = 0.0 | 280 prob_font = 0.0 |
265 prob_captcha = 0.0 | 281 prob_captcha = 0.0 |
282 prob_ocr = 0.0 | |
266 stop_after = None | 283 stop_after = None |
267 reload_mode = False | 284 reload_mode = False |
268 | 285 |
269 try: | 286 try: |
270 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="]) | 287 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:e:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="]) |
271 except getopt.GetoptError, err: | 288 except getopt.GetoptError, err: |
272 # print help information and exit: | 289 # print help information and exit: |
273 print str(err) # will print something like "option -a not recognized" | 290 print str(err) # will print something like "option -a not recognized" |
274 usage() | 291 usage() |
292 pdb.gimp_quit(0) | |
275 sys.exit(2) | 293 sys.exit(2) |
276 | 294 |
277 for o, a in opts: | 295 for o, a in opts: |
278 if o in ('-m', '--max-complexity'): | 296 if o in ('-m', '--max-complexity'): |
279 max_complexity = float(a) | 297 max_complexity = float(a) |
285 assert probability_zero >= 0.0 and probability_zero <= 1.0 | 303 assert probability_zero >= 0.0 and probability_zero <= 1.0 |
286 elif o in ("-o", "--output-file"): | 304 elif o in ("-o", "--output-file"): |
287 output_file_path = a | 305 output_file_path = a |
288 elif o in ('-p', "--params-output-file"): | 306 elif o in ('-p', "--params-output-file"): |
289 params_output_file_path = a | 307 params_output_file_path = a |
290 elif o in ('-r', "--labels-output-file"): | 308 elif o in ('-x', "--labels-output-file"): |
291 labels_output_file_path = a | 309 labels_output_file_path = a |
292 elif o in ('-s', "--stop-after"): | 310 elif o in ('-s', "--stop-after"): |
293 stop_after = int(a) | 311 stop_after = int(a) |
294 elif o in ('-f', "--data-file"): | 312 elif o in ('-f', "--data-file"): |
295 nist_path = a | 313 nist_path = a |
296 elif o in ('-l', "--label-file"): | 314 elif o in ('-l', "--label-file"): |
297 label_path = a | 315 label_path = a |
316 elif o in ('-c', "--ocr-file"): | |
317 ocr_path = a | |
318 elif o in ('-d', "--ocrlabel-file"): | |
319 ocrlabel_path = a | |
298 elif o in ('-a', "--prob-font"): | 320 elif o in ('-a', "--prob-font"): |
299 prob_font = float(a) | 321 prob_font = float(a) |
300 elif o in ('-b', "--prob-captcha"): | 322 elif o in ('-b', "--prob-captcha"): |
301 prob_captcha = float(a) | 323 prob_captcha = float(a) |
324 elif o in ('-e', "--prob-ocr"): | |
325 prob_ocr = float(a) | |
302 else: | 326 else: |
303 assert False, "unhandled option" | 327 assert False, "unhandled option" |
304 | 328 |
305 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: | 329 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: |
306 print "Must specify the three output files." | 330 print "Must specify the three output files." |
307 print | |
308 usage() | 331 usage() |
332 pdb.gimp_quit(0) | |
309 sys.exit(2) | 333 sys.exit(2) |
310 | 334 |
311 if reload_mode: | 335 if reload_mode: |
312 reload(output_file_path, params_output_file_path) | 336 reload(output_file_path, params_output_file_path) |
313 else: | 337 else: |
318 num_img = len(debug_images.filelist) | 342 num_img = len(debug_images.filelist) |
319 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) | 343 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) |
320 img_it = debug_images_iterator(debug_images) | 344 img_it = debug_images_iterator(debug_images) |
321 ''' | 345 ''' |
322 else: | 346 else: |
323 nist = NistData(nist_path, label_path) | 347 nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path) |
324 num_img = nist.dim[0] | 348 num_img = 819200 # 800 Mb file |
325 if stop_after: | 349 if stop_after: |
326 num_img = stop_after | 350 num_img = stop_after |
327 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) | 351 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) |
328 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img) | 352 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img) |
329 | 353 |
330 cpx_it = range_complexity_iterator(probability_zero, max_complexity) | 354 cpx_it = range_complexity_iterator(probability_zero, max_complexity) |
331 pl.run(img_it, cpx_it) | 355 pl.run(img_it, cpx_it) |
332 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path) | 356 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path) |
333 | 357 |