comparison transformations/pipeline.py @ 67:5e448ea129b3

Ajouté la source (optionnelle) de données OCR Autriche avec une probabilité passée en argument
author boulanni <nicolas_boulanger@hotmail.com>
date Tue, 09 Feb 2010 21:33:57 -0500
parents 1afa95285b9c
children 95c491bb5662
comparison
equal deleted inserted replaced
66:bf83682c827b 67:5e448ea129b3
12 import random 12 import random
13 13
14 # To debug locally, also call with -s 100 (to stop after ~100) 14 # To debug locally, also call with -s 100 (to stop after ~100)
15 # (otherwise we allocate all needed memory, might be loonnng and/or crash 15 # (otherwise we allocate all needed memory, might be loonnng and/or crash
16 # if, lucky like me, you have an age-old laptop creaking from everywhere) 16 # if, lucky like me, you have an age-old laptop creaking from everywhere)
17 DEBUG = True 17 DEBUG = False
18 DEBUG_X = False 18 DEBUG_X = False
19 if DEBUG: 19 if DEBUG:
20 DEBUG_X = False # Debug under X (pylab.show()) 20 DEBUG_X = False # Debug under X (pylab.show())
21 21
22 DEBUG_IMAGES_PATH = None 22 DEBUG_IMAGES_PATH = None
31 # (create it, otherwise it'll crash) 31 # (create it, otherwise it'll crash)
32 DEBUG_OUTPUT_DIR = 'debug_out' 32 DEBUG_OUTPUT_DIR = 'debug_out'
33 33
34 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' 34 DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
35 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' 35 DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
36 DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft'
37 DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft'
36 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] 38 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
37 39
38 if DEBUG_X: 40 if DEBUG_X:
39 import pylab 41 import pylab
40 pylab.ion() 42 pylab.ion()
186 for path in debug_images.filelist: 188 for path in debug_images.filelist:
187 yield load_image(path) 189 yield load_image(path)
188 ''' 190 '''
189 191
190 class NistData(): 192 class NistData():
191 def __init__(self, nist_path, label_path): 193 def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path):
192 self.train_data = open(nist_path, 'rb') 194 self.train_data = open(nist_path, 'rb')
193 self.train_labels = open(label_path, 'rb') 195 self.train_labels = open(label_path, 'rb')
194 self.dim = tuple(ft._read_header(self.train_data)[3]) 196 self.dim = tuple(ft._read_header(self.train_data)[3])
195 # in order to seek to the beginning of the file 197 # in order to seek to the beginning of the file
196 self.train_data.close() 198 self.train_data.close()
197 self.train_data = open(nist_path, 'rb') 199 self.train_data = open(nist_path, 'rb')
198 200 self.ocr_data = open(ocr_path, 'rb')
199 201 self.ocr_labels = open(ocrlabel_path, 'rb')
200 def nist_supp_iterator(nist, prob_font, prob_captcha, num_img): 202
201 subtensor = slice(0, num_img) 203 def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img):
202 img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255 204 img = ft.read(nist.train_data).astype(numpy.float32)/255
203 labels = ft.read(nist.train_labels, subtensor) 205 labels = ft.read(nist.train_labels)
206 if prob_ocr:
207 ocr_img = ft.read(nist.ocr_data).astype(numpy.float32)/255
208 ocr_labels = ft.read(nist.ocr_labels)
204 209
205 for i in xrange(num_img): 210 for i in xrange(num_img):
206 r = numpy.random.rand() 211 r = numpy.random.rand()
207 if r<= prob_font: 212 if r <= prob_font:
208 pass #get font 213 pass #get font
209 elif r<= prob_font + prob_captcha: 214 elif r <= prob_font + prob_captcha:
210 pass #get captcha 215 pass #get captcha
216 elif r <= prob_font + prob_captcha + prob_ocr:
217 j = numpy.random.randint(len(ocr_labels))
218 yield ocr_img[j], ocr_labels[j]
211 else: 219 else:
212 j = numpy.random.randint(num_img) 220 j = numpy.random.randint(len(labels))
213 yield img[j], labels[j] 221 yield img[j], labels[j]
214 222
215 223
216 # Mostly for debugging, for the moment, just to see if we can 224 # Mostly for debugging, for the moment, just to see if we can
217 # reload the images and parameters. 225 # reload the images and parameters.
235 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] 243 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...]
236 -m, --max-complexity: max complexity to generate for an image 244 -m, --max-complexity: max complexity to generate for an image
237 -z, --probability-zero: probability of using complexity=0 for an image 245 -z, --probability-zero: probability of using complexity=0 for an image
238 -o, --output-file: full path to file to use for output of images 246 -o, --output-file: full path to file to use for output of images
239 -p, --params-output-file: path to file to output params to 247 -p, --params-output-file: path to file to output params to
240 -r, --labels-output-file: path to file to output labels to 248 -x, --labels-output-file: path to file to output labels to
241 -f, --data-file: path to filetensor (.ft) data file (NIST) 249 -f, --data-file: path to filetensor (.ft) data file (NIST)
242 -l, --label-file: path to filetensor (.ft) labels file (NIST labels) 250 -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
251 -c, --ocr-file: path to filetensor (.ft) data file (OCR)
252 -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
243 -a, --prob-font: probability of using a raw font image 253 -a, --prob-font: probability of using a raw font image
244 -b, --prob-captcha: probability of using a captcha image 254 -b, --prob-captcha: probability of using a captcha image
255 -e, --prob-ocr: probability of using an ocr image
245 ''' 256 '''
246 257
247 # See run_pipeline.py 258 # See run_pipeline.py
248 def get_argv(): 259 def get_argv():
249 with open(ARGS_FILE) as f: 260 with open(ARGS_FILE) as f:
252 263
253 # Might be called locally or through dbidispatch. In all cases it should be 264 # Might be called locally or through dbidispatch. In all cases it should be
254 # passed to the GIMP executable to be able to use GIMP filters. 265 # passed to the GIMP executable to be able to use GIMP filters.
255 # Ex: 266 # Ex:
256 def _main(): 267 def _main():
268 #global DEFAULT_NIST_PATH, DEFAULT_LABEL_PATH, DEFAULT_OCR_PATH, DEFAULT_OCRLABEL_PATH
269 #global getopt, get_argv
270
257 max_complexity = 0.5 # default 271 max_complexity = 0.5 # default
258 probability_zero = 0.1 # default 272 probability_zero = 0.1 # default
259 output_file_path = None 273 output_file_path = None
260 params_output_file_path = None 274 params_output_file_path = None
261 labels_output_file_path = None 275 labels_output_file_path = None
262 nist_path = DEFAULT_NIST_PATH 276 nist_path = DEFAULT_NIST_PATH
263 label_path = DEFAULT_LABEL_PATH 277 label_path = DEFAULT_LABEL_PATH
278 ocr_path = DEFAULT_OCR_PATH
279 ocrlabel_path = DEFAULT_OCRLABEL_PATH
264 prob_font = 0.0 280 prob_font = 0.0
265 prob_captcha = 0.0 281 prob_captcha = 0.0
282 prob_ocr = 0.0
266 stop_after = None 283 stop_after = None
267 reload_mode = False 284 reload_mode = False
268 285
269 try: 286 try:
270 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="]) 287 opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:e:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr="])
271 except getopt.GetoptError, err: 288 except getopt.GetoptError, err:
272 # print help information and exit: 289 # print help information and exit:
273 print str(err) # will print something like "option -a not recognized" 290 print str(err) # will print something like "option -a not recognized"
274 usage() 291 usage()
292 pdb.gimp_quit(0)
275 sys.exit(2) 293 sys.exit(2)
276 294
277 for o, a in opts: 295 for o, a in opts:
278 if o in ('-m', '--max-complexity'): 296 if o in ('-m', '--max-complexity'):
279 max_complexity = float(a) 297 max_complexity = float(a)
285 assert probability_zero >= 0.0 and probability_zero <= 1.0 303 assert probability_zero >= 0.0 and probability_zero <= 1.0
286 elif o in ("-o", "--output-file"): 304 elif o in ("-o", "--output-file"):
287 output_file_path = a 305 output_file_path = a
288 elif o in ('-p', "--params-output-file"): 306 elif o in ('-p', "--params-output-file"):
289 params_output_file_path = a 307 params_output_file_path = a
290 elif o in ('-r', "--labels-output-file"): 308 elif o in ('-x', "--labels-output-file"):
291 labels_output_file_path = a 309 labels_output_file_path = a
292 elif o in ('-s', "--stop-after"): 310 elif o in ('-s', "--stop-after"):
293 stop_after = int(a) 311 stop_after = int(a)
294 elif o in ('-f', "--data-file"): 312 elif o in ('-f', "--data-file"):
295 nist_path = a 313 nist_path = a
296 elif o in ('-l', "--label-file"): 314 elif o in ('-l', "--label-file"):
297 label_path = a 315 label_path = a
316 elif o in ('-c', "--ocr-file"):
317 ocr_path = a
318 elif o in ('-d', "--ocrlabel-file"):
319 ocrlabel_path = a
298 elif o in ('-a', "--prob-font"): 320 elif o in ('-a', "--prob-font"):
299 prob_font = float(a) 321 prob_font = float(a)
300 elif o in ('-b', "--prob-captcha"): 322 elif o in ('-b', "--prob-captcha"):
301 prob_captcha = float(a) 323 prob_captcha = float(a)
324 elif o in ('-e', "--prob-ocr"):
325 prob_ocr = float(a)
302 else: 326 else:
303 assert False, "unhandled option" 327 assert False, "unhandled option"
304 328
305 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: 329 if output_file_path == None or params_output_file_path == None or labels_output_file_path == None:
306 print "Must specify the three output files." 330 print "Must specify the three output files."
307 print
308 usage() 331 usage()
332 pdb.gimp_quit(0)
309 sys.exit(2) 333 sys.exit(2)
310 334
311 if reload_mode: 335 if reload_mode:
312 reload(output_file_path, params_output_file_path) 336 reload(output_file_path, params_output_file_path)
313 else: 337 else:
318 num_img = len(debug_images.filelist) 342 num_img = len(debug_images.filelist)
319 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) 343 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
320 img_it = debug_images_iterator(debug_images) 344 img_it = debug_images_iterator(debug_images)
321 ''' 345 '''
322 else: 346 else:
323 nist = NistData(nist_path, label_path) 347 nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path)
324 num_img = nist.dim[0] 348 num_img = 819200 # 800 Mb file
325 if stop_after: 349 if stop_after:
326 num_img = stop_after 350 num_img = stop_after
327 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) 351 pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
328 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img) 352 img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img)
329 353
330 cpx_it = range_complexity_iterator(probability_zero, max_complexity) 354 cpx_it = range_complexity_iterator(probability_zero, max_complexity)
331 pl.run(img_it, cpx_it) 355 pl.run(img_it, cpx_it)
332 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path) 356 pl.write_output(output_file_path, params_output_file_path, labels_output_file_path)
333 357