Mercurial > ift6266
diff data_generation/transformations/pipeline.py @ 167:1f5937e9e530
More moves - transformations into data_generation, added "deep" folder
author | Dumitru Erhan <dumitru.erhan@gmail.com> |
---|---|
date | Fri, 26 Feb 2010 14:15:38 -0500 |
parents | transformations/pipeline.py@6f3b866c0182 |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/data_generation/transformations/pipeline.py Fri Feb 26 14:15:38 2010 -0500 @@ -0,0 +1,391 @@ +#!/usr/bin/python +# coding: utf-8 + +from __future__ import with_statement + +# This is intended to be run as a GIMP script +#from gimpfu import * + +import sys, os, getopt +import numpy +import filetensor as ft +import random + +# To debug locally, also call with -s 100 (to stop after ~100) +# (otherwise we allocate all needed memory, might be loonnng and/or crash +# if, lucky like me, you have an age-old laptop creaking from everywhere) +DEBUG = False +DEBUG_X = False +if DEBUG: + DEBUG_X = False # Debug under X (pylab.show()) + +DEBUG_IMAGES_PATH = None +if DEBUG: + # UNTESTED YET + # To avoid loading NIST if you don't have it handy + # (use with debug_images_iterator(), see main()) + # To use NIST, leave as = None + DEBUG_IMAGES_PATH = None#'/home/francois/Desktop/debug_images' + +# Directory where to dump images to visualize results +# (create it, otherwise it'll crash) +DEBUG_OUTPUT_DIR = 'debug_out' + +DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' +DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' +DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft' +DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft' +ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] + +# PARSE COMMAND LINE ARGUMENTS +def get_argv(): + with open(ARGS_FILE) as f: + args = [l.rstrip() for l in f.readlines()] + return args + +def usage(): + print ''' +Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] + -m, --max-complexity: max complexity to generate for an image + -z, --probability-zero: probability of using complexity=0 for an image + -o, --output-file: full path to file to use for output of images + -p, --params-output-file: path to file to output params to + -x, --labels-output-file: path to file to output labels to + -f, --data-file: path to filetensor (.ft) data file (NIST) + -l, --label-file: path to filetensor (.ft) labels file (NIST labels) + -c, --ocr-file: path to filetensor (.ft) data file (OCR) + -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels) + -a, --prob-font: probability of using a raw font image + -b, --prob-captcha: probability of using a captcha image + -g, --prob-ocr: probability of using an ocr image + -y, --seed: the job seed + ''' + +try: + opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:y:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", +"stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr=", "seed="]) +except getopt.GetoptError, err: + # print help information and exit: + print str(err) # will print something like "option -a not recognized" + usage() + pdb.gimp_quit(0) + sys.exit(2) + +for o, a in opts: + if o in ('-y','--seed'): + random.seed(int(a)) + numpy.random.seed(int(a)) + +if DEBUG_X: + import pylab + pylab.ion() + +from PoivreSel import PoivreSel +from thick import Thick +from BruitGauss import BruitGauss +from DistorsionGauss import DistorsionGauss +from PermutPixel import PermutPixel +from gimp_script import GIMP1 +from Rature import Rature +from contrast import Contrast +from local_elastic_distortions import LocalElasticDistorter +from slant import Slant +from Occlusion import Occlusion +from add_background_image import AddBackground +from affine_transform import AffineTransformation +from ttf2jpg import ttf2jpg +from Facade import generateCaptcha + +if DEBUG: + from visualizer import Visualizer + # Either put the visualizer as in the MODULES_INSTANCES list + # after each module you want to visualize, or in the + # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant) + VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR, on_screen=False) + +###---------------------order of transformation module +MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()] + +# These should have a "after_transform_callback(self, image)" method +# (called after each call to transform_image in a module) +AFTER_EACH_MODULE_HOOK = [] +if DEBUG: + AFTER_EACH_MODULE_HOOK = [VISUALIZER] + +# These should have a "end_transform_callback(self, final_image" method +# (called after all modules have been called) +END_TRANSFORM_HOOK = [] +if DEBUG: + END_TRANSFORM_HOOK = [VISUALIZER] + +class Pipeline(): + def __init__(self, modules, num_img, image_size=(32,32)): + self.modules = modules + self.num_img = num_img + self.num_params_stored = 0 + self.image_size = image_size + + self.init_memory() + + def init_num_params_stored(self): + # just a dummy call to regenerate_parameters() to get the + # real number of params (only those which are stored) + self.num_params_stored = 0 + for m in self.modules: + self.num_params_stored += len(m.regenerate_parameters(0.0)) + + def init_memory(self): + self.init_num_params_stored() + + total = self.num_img + num_px = self.image_size[0] * self.image_size[1] + + self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8) + # +1 to store complexity + self.params = numpy.empty((total, self.num_params_stored+len(self.modules))) + self.res_labels = numpy.empty(total, dtype=numpy.int32) + + def run(self, img_iterator, complexity_iterator): + img_size = self.image_size + + should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 + should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 + + for img_no, (img, label) in enumerate(img_iterator): + sys.stdout.flush() + + global_idx = img_no + + img = img.reshape(img_size) + + param_idx = 0 + mod_idx = 0 + for mod in self.modules: + # This used to be done _per batch_, + # ie. out of the "for img" loop + complexity = complexity_iterator.next() + #better to do a complexity sampling for each transformations in order to have more variability + #otherwise a lot of images similar to the source are generated (i.e. when complexity is close to 0 (1/8 of the time)) + #we need to save the complexity of each transformations and the sum of these complexity is a good indicator of the overall + #complexity + self.params[global_idx, mod_idx] = complexity + mod_idx += 1 + + p = mod.regenerate_parameters(complexity) + self.params[global_idx, param_idx+len(self.modules):param_idx+len(p)+len(self.modules)] = p + param_idx += len(p) + + img = mod.transform_image(img) + + if should_hook_after_each: + for hook in AFTER_EACH_MODULE_HOOK: + hook.after_transform_callback(img) + + self.res_data[global_idx] = \ + img.reshape((img_size[0] * img_size[1],))*255 + self.res_labels[global_idx] = label + + if should_hook_at_the_end: + for hook in END_TRANSFORM_HOOK: + hook.end_transform_callback(img) + + def write_output(self, output_file_path, params_output_file_path, labels_output_file_path): + with open(output_file_path, 'wb') as f: + ft.write(f, self.res_data) + + numpy.save(params_output_file_path, self.params) + + with open(labels_output_file_path, 'wb') as f: + ft.write(f, self.res_labels) + + +############################################################################## +# COMPLEXITY ITERATORS +# They're called once every img, to get the complexity to use for that img +# they must be infinite (should never throw StopIteration when calling next()) + +# probability of generating 0 complexity, otherwise +# uniform over 0.0-max_complexity +def range_complexity_iterator(probability_zero, max_complexity): + assert max_complexity <= 1.0 + n = numpy.random.uniform(0.0, 1.0) + while True: + if n < probability_zero: + yield 0.0 + else: + yield numpy.random.uniform(0.0, max_complexity) + +############################################################################## +# DATA ITERATORS +# They can be used to interleave different data sources etc. + +''' +# Following code (DebugImages and iterator) is untested + +def load_image(filepath): + _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0] + img = Image.open(filepath) + img = numpy.asarray(img) + if len(img.shape) > 2: + img = (img * _RGB_TO_GRAYSCALE).sum(axis=2) + return (img / 255.0).astype('float') + +class DebugImages(): + def __init__(self, images_dir_path): + import glob, os.path + self.filelist = glob.glob(os.path.join(images_dir_path, "*.png")) + +def debug_images_iterator(debug_images): + for path in debug_images.filelist: + yield load_image(path) +''' + +class NistData(): + def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path): + self.train_data = open(nist_path, 'rb') + self.train_labels = open(label_path, 'rb') + self.dim = tuple(ft._read_header(self.train_data)[3]) + # in order to seek to the beginning of the file + self.train_data.close() + self.train_data = open(nist_path, 'rb') + self.ocr_data = open(ocr_path, 'rb') + self.ocr_labels = open(ocrlabel_path, 'rb') + +# cet iterator load tout en ram +def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img): + img = ft.read(nist.train_data) + labels = ft.read(nist.train_labels) + if prob_ocr: + ocr_img = ft.read(nist.ocr_data) + ocr_labels = ft.read(nist.ocr_labels) + ttf = ttf2jpg() + L = [chr(ord('0')+x) for x in range(10)] + [chr(ord('A')+x) for x in range(26)] + [chr(ord('a')+x) for x in range(26)] + + for i in xrange(num_img): + r = numpy.random.rand() + if r <= prob_font: + yield ttf.generate_image() + elif r <=prob_font + prob_captcha: + (arr, charac) = generateCaptcha(0,1) + yield arr.astype(numpy.float32)/255, L.index(charac[0]) + elif r <= prob_font + prob_captcha + prob_ocr: + j = numpy.random.randint(len(ocr_labels)) + yield ocr_img[j].astype(numpy.float32)/255, ocr_labels[j] + else: + j = numpy.random.randint(len(labels)) + yield img[j].astype(numpy.float32)/255, labels[j] + + +# Mostly for debugging, for the moment, just to see if we can +# reload the images and parameters. +def reload(output_file_path, params_output_file_path): + images_ft = open(output_file_path, 'rb') + images_ft_dim = tuple(ft._read_header(images_ft)[3]) + + print "Images dimensions: ", images_ft_dim + + params = numpy.load(params_output_file_path) + + print "Params dimensions: ", params.shape + print params + + +############################################################################## +# MAIN + + +# Might be called locally or through dbidispatch. In all cases it should be +# passed to the GIMP executable to be able to use GIMP filters. +# Ex: +def _main(): + #global DEFAULT_NIST_PATH, DEFAULT_LABEL_PATH, DEFAULT_OCR_PATH, DEFAULT_OCRLABEL_PATH + #global getopt, get_argv + + max_complexity = 0.5 # default + probability_zero = 0.1 # default + output_file_path = None + params_output_file_path = None + labels_output_file_path = None + nist_path = DEFAULT_NIST_PATH + label_path = DEFAULT_LABEL_PATH + ocr_path = DEFAULT_OCR_PATH + ocrlabel_path = DEFAULT_OCRLABEL_PATH + prob_font = 0.0 + prob_captcha = 0.0 + prob_ocr = 0.0 + stop_after = None + reload_mode = False + + for o, a in opts: + if o in ('-m', '--max-complexity'): + max_complexity = float(a) + assert max_complexity >= 0.0 and max_complexity <= 1.0 + elif o in ('-r', '--reload'): + reload_mode = True + elif o in ("-z", "--probability-zero"): + probability_zero = float(a) + assert probability_zero >= 0.0 and probability_zero <= 1.0 + elif o in ("-o", "--output-file"): + output_file_path = a + elif o in ('-p', "--params-output-file"): + params_output_file_path = a + elif o in ('-x', "--labels-output-file"): + labels_output_file_path = a + elif o in ('-s', "--stop-after"): + stop_after = int(a) + elif o in ('-f', "--data-file"): + nist_path = a + elif o in ('-l', "--label-file"): + label_path = a + elif o in ('-c', "--ocr-file"): + ocr_path = a + elif o in ('-d', "--ocrlabel-file"): + ocrlabel_path = a + elif o in ('-a', "--prob-font"): + prob_font = float(a) + elif o in ('-b', "--prob-captcha"): + prob_captcha = float(a) + elif o in ('-g', "--prob-ocr"): + prob_ocr = float(a) + elif o in ('-y', "--seed"): + pass + else: + assert False, "unhandled option" + + if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: + print "Must specify the three output files." + usage() + pdb.gimp_quit(0) + sys.exit(2) + + if reload_mode: + reload(output_file_path, params_output_file_path) + else: + if DEBUG_IMAGES_PATH: + ''' + # This code is yet untested + debug_images = DebugImages(DEBUG_IMAGES_PATH) + num_img = len(debug_images.filelist) + pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) + img_it = debug_images_iterator(debug_images) + ''' + else: + nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path) + num_img = 819200 # 800 Mb file + if stop_after: + num_img = stop_after + pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) + img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img) + + cpx_it = range_complexity_iterator(probability_zero, max_complexity) + pl.run(img_it, cpx_it) + pl.write_output(output_file_path, params_output_file_path, labels_output_file_path) + +_main() + +if DEBUG_X: + pylab.ioff() + pylab.show() + +pdb.gimp_quit(0) +