view data_generation/pipeline/pipeline.py @ 174:ff26436d42d6

Make data_generation.transformations importable and fixup test.py to not try some of the modules.
author Arnaud Bergeron <abergeron@gmail.com>
date Sat, 27 Feb 2010 12:18:26 -0500
parents 72eee6a9f43f
children 992ca8035a4d
line wrap: on
line source

#!/usr/bin/python
# coding: utf-8

from __future__ import with_statement

# This is intended to be run as a GIMP script
#from gimpfu import *

import sys, os, getopt
import numpy
import filetensor as ft
import random

# To debug locally, also call with -s 100 (to stop after ~100 images);
# otherwise all the needed memory is allocated up front (see
# Pipeline.init_memory()), which might take a long time and/or crash on a
# machine with little RAM.
DEBUG = False
DEBUG_X = False
if DEBUG:
    # Flip this to True by hand to also debug under X (pylab.show())
    DEBUG_X = False # Debug under X (pylab.show())

DEBUG_IMAGES_PATH = None
if DEBUG:
    # UNTESTED YET
    # To avoid loading NIST if you don't have it handy
    # (meant to be used with debug_images_iterator(), which currently only
    # exists as disabled code further down; see _main())
    # To use NIST, leave as = None
    DEBUG_IMAGES_PATH = None#'/home/francois/Desktop/debug_images'

# Directory where to dump images to visualize results
# (create it, otherwise it'll crash)
DEBUG_OUTPUT_DIR = 'debug_out'

# Default input files, used by _main() when the matching command line
# option (-f, -l, -c, -d) is not given.
DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
DEFAULT_OCR_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-shuffled.ft'
DEFAULT_OCRLABEL_PATH = '/data/lisa/data/ocr_breuel/filetensor/unlv-corrected-2010-02-01-labels-shuffled.ft'
# File holding the command line arguments, one per line (read by get_argv()).
# The environment variable must be set: the lookup raises KeyError at import
# time otherwise.
ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']

# PARSE COMMAND LINE ARGUMENTS
def get_argv():
    """Return the command line arguments stored in ARGS_FILE.

    The file holds one argument per line; trailing whitespace (including
    the newline) is stripped from every line.
    """
    with open(ARGS_FILE) as args_file:
        return [line.rstrip() for line in args_file]

def usage():
    """Print the command line help to standard output.

    Lists every option accepted by the getopt specification of this
    script, including -r/--reload and -s/--stop-after.
    """
    # A single parenthesised string prints identically with the Python 2
    # print statement and the Python 3 print function.
    print('''
Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...]
    -m, --max-complexity: max complexity to generate for an image
    -z, --probability-zero: probability of using complexity=0 for an image
    -o, --output-file: full path to file to use for output of images
    -p, --params-output-file: path to file to output params to
    -x, --labels-output-file: path to file to output labels to
    -f, --data-file: path to filetensor (.ft) data file (NIST)
    -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
    -c, --ocr-file: path to filetensor (.ft) data file (OCR)
    -d, --ocrlabel-file: path to filetensor (.ft) labels file (OCR labels)
    -a, --prob-font: probability of using a raw font image
    -b, --prob-captcha: probability of using a captcha image
    -g, --prob-ocr: probability of using an ocr image
    -y, --seed: the job seed
    -s, --stop-after: number of images to generate before stopping
    -r, --reload: print the dimensions of existing output files instead of generating
    ''')

# Parse the options at import time. `opts` is kept at module level and is
# read again by _main() further down.
try:
    opts, args = getopt.getopt(get_argv(), "rm:z:o:p:x:s:f:l:c:d:a:b:g:y:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", 
"stop-after=", "data-file=", "label-file=", "ocr-file=", "ocrlabel-file=", "prob-font=", "prob-captcha=", "prob-ocr=", "seed="])
except getopt.GetoptError, err:
        # print help information and exit:
        print str(err) # will print something like "option -a not recognized"
        usage()
        # NOTE(review): `pdb` is not imported in this file (the gimpfu import
        # above is commented out); presumably it is provided by the GIMP
        # Python-Fu environment the script runs in -- confirm.
        pdb.gimp_quit(0)
        sys.exit(2)

# Seed both the stdlib and the numpy generators with the job seed.
# This happens before the transformation modules are imported and
# instantiated below -- presumably so that any randomness they use at that
# point is reproducible as well (TODO confirm).
for o, a in opts:
    if o in ('-y','--seed'):
        random.seed(int(a))
        numpy.random.seed(int(a))

if DEBUG_X:
    import pylab
    pylab.ion()

from ift6266.data_generation.transformations import *

from PoivreSel import PoivreSel
from thick import Thick
from BruitGauss import BruitGauss
from DistorsionGauss import DistorsionGauss
from PermutPixel import PermutPixel
from gimp_script import GIMP1
from Rature import Rature
from contrast import Contrast
from local_elastic_distortions import LocalElasticDistorter
from slant import Slant
from Occlusion import Occlusion
from add_background_image import AddBackground
from affine_transform import AffineTransformation
from ttf2jpg import ttf2jpg
from Facade import generateCaptcha

if DEBUG:
    from visualizer import Visualizer
    # Either put the visualizer in the MODULE_INSTANCES list
    # after each module you want to visualize, or in the
    # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant)
    VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR,  on_screen=False)

###---------------------order of transformation module
# Pipeline.run() applies the modules to every image in this list order.
MODULE_INSTANCES = [Slant(),Thick(),AffineTransformation(),LocalElasticDistorter(),GIMP1(),Rature(),Occlusion(), PermutPixel(),DistorsionGauss(),AddBackground(), PoivreSel(), BruitGauss(), Contrast()]

# These should have an "after_transform_callback(self, image)" method
# (called after each call to transform_image in a module)
AFTER_EACH_MODULE_HOOK = []
if DEBUG:
    AFTER_EACH_MODULE_HOOK = [VISUALIZER]

# These should have an "end_transform_callback(self, final_image)" method
# (called once per image, after all modules have been called)
END_TRANSFORM_HOOK = []
if DEBUG:
    END_TRANSFORM_HOOK = [VISUALIZER]

class Pipeline():
    """Run a chain of transformation modules over a stream of images.

    Results are accumulated in numpy arrays allocated up front:

    * ``res_data``   -- one flattened uint8 image per row
    * ``res_labels`` -- one int32 label per image
    * ``params``     -- per image, one complexity value per module followed
      by the parameters returned by every module

    ``write_output()`` dumps the three arrays to disk.
    """

    def __init__(self, modules, num_img, image_size=(32,32)):
        """Store the configuration and allocate the result arrays.

        modules    -- objects providing regenerate_parameters(complexity)
                      and transform_image(image)
        num_img    -- number of rows allocated in the result arrays
        image_size -- shape every incoming image is reshaped to; the
                      product of its first two entries is the number of
                      pixels stored per image
        """
        self.image_size = image_size
        self.num_img = num_img
        self.modules = modules
        self.num_params_stored = 0

        self.init_memory()

    def init_num_params_stored(self):
        """Count how many parameters the modules return in total.

        Each module is asked once, with a dummy complexity of 0.0, to
        regenerate its parameters; only the length of the answer is used.
        """
        self.num_params_stored = sum(
            len(module.regenerate_parameters(0.0)) for module in self.modules)

    def init_memory(self):
        """Allocate (uninitialised) arrays for images, parameters and labels."""
        self.init_num_params_stored()

        rows = self.num_img
        pixels_per_image = self.image_size[0] * self.image_size[1]
        # one complexity column per module, then every stored parameter
        columns = self.num_params_stored + len(self.modules)

        self.res_data = numpy.empty((rows, pixels_per_image), dtype=numpy.uint8)
        self.params = numpy.empty((rows, columns))
        self.res_labels = numpy.empty(rows, dtype=numpy.int32)

    def run(self, img_iterator, complexity_iterator):
        """Transform every image of img_iterator and record the results.

        img_iterator        -- yields (image, label) pairs
        complexity_iterator -- object whose next() method returns the
                               complexity to use; it is called once per
                               module for every image

        The hooks listed in AFTER_EACH_MODULE_HOOK are called after each
        module, those in END_TRANSFORM_HOOK once per finished image.
        """
        shape = self.image_size
        flat_length = shape[0] * shape[1]
        first_param_column = len(self.modules)

        notify_each_module = len(AFTER_EACH_MODULE_HOOK) > 0
        notify_at_end = len(END_TRANSFORM_HOOK) > 0

        for row, (img, label) in enumerate(img_iterator):
            sys.stdout.flush()

            img = img.reshape(shape)

            column = first_param_column
            for slot, module in enumerate(self.modules):
                # A complexity is sampled for each transformation rather
                # than once per image: this gives more variability,
                # otherwise many images close to the source are generated
                # whenever the single complexity is near 0.  Every
                # complexity is saved, so their sum can serve as an
                # indicator of the overall complexity of the image.
                complexity = complexity_iterator.next()
                self.params[row, slot] = complexity

                drawn = module.regenerate_parameters(complexity)
                stop = column + len(drawn)
                self.params[row, column:stop] = drawn
                column = stop

                img = module.transform_image(img)

                if notify_each_module:
                    for listener in AFTER_EACH_MODULE_HOOK:
                        listener.after_transform_callback(img)

            # scaled by 255 and stored into the uint8 result row
            self.res_data[row] = img.reshape((flat_length,)) * 255
            self.res_labels[row] = label

            if notify_at_end:
                for listener in END_TRANSFORM_HOOK:
                    listener.end_transform_callback(img)

    def write_output(self, output_file_path, params_output_file_path, labels_output_file_path):
        """Write images and labels as filetensors and parameters via numpy.save."""
        with open(output_file_path, 'wb') as images_out:
            ft.write(images_out, self.res_data)

        numpy.save(params_output_file_path, self.params)

        with open(labels_output_file_path, 'wb') as labels_out:
            ft.write(labels_out, self.res_labels)

##############################################################################
# COMPLEXITY ITERATORS
# Pipeline.run() calls next() on them once per module for every image, to
# get the complexity to use for that transformation.
# They must be infinite (should never throw StopIteration when calling next())

# probability of generating 0 complexity, otherwise
# uniform over 0.0-max_complexity
def range_complexity_iterator(probability_zero, max_complexity):
    """Infinite generator of complexity values.

    Each value is 0.0 with probability `probability_zero`, otherwise it is
    drawn uniformly between 0.0 and `max_complexity`.

    probability_zero -- probability of yielding exactly 0.0
    max_complexity   -- upper bound of the uniform draw; must be <= 1.0
                        (checked with an assert when iteration starts)

    The values come from numpy.random, so seeding numpy.random makes the
    sequence reproducible.
    """
    assert max_complexity <= 1.0
    while True:
        # The zero/non-zero decision must be re-drawn for every value.
        # Drawing it once before the loop would freeze the decision for the
        # whole life of the generator (all zeros, or never a zero).
        if numpy.random.uniform(0.0, 1.0) < probability_zero:
            yield 0.0
        else:
            yield numpy.random.uniform(0.0, max_complexity)

##############################################################################
# DATA ITERATORS
# They can be used to interleave different data sources etc.

'''
# Following code (DebugImages and iterator) is untested

def load_image(filepath):
    _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0]
    img = Image.open(filepath)
    img = numpy.asarray(img)
    if len(img.shape) > 2:
        img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
    return (img / 255.0).astype('float')

class DebugImages():
    def __init__(self, images_dir_path):
        import glob, os.path
        self.filelist = glob.glob(os.path.join(images_dir_path, "*.png"))

def debug_images_iterator(debug_images):
    for path in debug_images.filelist:
        yield load_image(path)
'''

class NistData():
    """Open file handles on the NIST and OCR filetensor files.

    Attributes set by the constructor:
      train_data, train_labels -- NIST images / labels files ('rb')
      ocr_data, ocr_labels     -- OCR images / labels files ('rb')
      dim                      -- tuple built from the fourth field of the
                                  header of the NIST images file
    """

    def __init__(self, nist_path, label_path, ocr_path, ocrlabel_path):
        self.train_data = open(nist_path, 'rb')
        self.train_labels = open(label_path, 'rb')

        header = ft._read_header(self.train_data)
        self.dim = tuple(header[3])
        # reading the header moved the file position: go back to the
        # beginning so the file can later be read in full
        self.train_data.seek(0)

        self.ocr_data = open(ocr_path, 'rb')
        self.ocr_labels = open(ocrlabel_path, 'rb')

# this iterator loads everything in RAM
def nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img):
    """Yield `num_img` samples picked at random among several sources.

    For every sample a uniform number r is drawn with numpy.random.rand():

      r <= prob_font                            -> ttf2jpg().generate_image()
      r <= prob_font + prob_captcha             -> generated captcha
      r <= prob_font + prob_captcha + prob_ocr  -> random OCR image
      otherwise                                 -> random NIST image

    Captcha, OCR and NIST samples are (image, label) pairs whose image is
    converted to float32 and divided by 255; the font sample is whatever
    generate_image() returns.  The NIST files are always read entirely
    when iteration starts, the OCR files only when `prob_ocr` is non-zero.

    nist -- object exposing train_data, train_labels, ocr_data and
            ocr_labels, each readable by ft.read()
    """
    nist_images = ft.read(nist.train_data)
    nist_labels = ft.read(nist.train_labels)
    if prob_ocr:
        ocr_images = ft.read(nist.ocr_data)
        ocr_targets = ft.read(nist.ocr_labels)
    font_source = ttf2jpg()
    # class index of a character: digits, then upper case, then lower case
    classes = list('0123456789'
                   'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
                   'abcdefghijklmnopqrstuvwxyz')

    # cumulative thresholds delimiting the sources
    font_limit = prob_font
    captcha_limit = prob_font + prob_captcha
    ocr_limit = prob_font + prob_captcha + prob_ocr

    for _ in xrange(num_img):
        draw = numpy.random.rand()
        if draw <= font_limit:
            sample = font_source.generate_image()
        elif draw <= captcha_limit:
            (pixels, text) = generateCaptcha(0,1)
            sample = (pixels.astype(numpy.float32)/255, classes.index(text[0]))
        elif draw <= ocr_limit:
            pick = numpy.random.randint(len(ocr_targets))
            sample = (ocr_images[pick].astype(numpy.float32)/255, ocr_targets[pick])
        else:
            pick = numpy.random.randint(len(nist_labels))
            sample = (nist_images[pick].astype(numpy.float32)/255, nist_labels[pick])
        yield sample


# Mostly for debugging, for the moment, just to see if we can
# reload the images and parameters.
def reload(output_file_path, params_output_file_path):
    """Print the dimensions of previously generated output files.

    output_file_path        -- filetensor file of images; only its header
                               is read
    params_output_file_path -- parameters file written with numpy.save;
                               its shape and full content are printed

    Note: the name shadows the Python 2 builtin reload() inside this
    module; it is kept because _main() calls it under that name.
    """
    # The header is all we need; the context manager makes sure the handle
    # is closed again instead of being leaked.
    with open(output_file_path, 'rb') as images_ft:
        images_ft_dim = tuple(ft._read_header(images_ft)[3])

    # Single-argument print calls produce the same output under the
    # Python 2 print statement and the Python 3 print function.
    print("Images dimensions:  %s" % (images_ft_dim,))

    params = numpy.load(params_output_file_path)

    print("Params dimensions:  %s" % (params.shape,))
    print(params)
    

##############################################################################
# MAIN


# Might be called locally or through dbidispatch. In all cases it should be
# passed to the GIMP executable to be able to use GIMP filters.
def _main():
    """Interpret the parsed options and run (or reload) the pipeline.

    Reads the module-level `opts` produced by getopt at import time.  The
    three output files (-o, -p, -x) are mandatory; without them the usage
    text is printed and the process exits with status 2.

    With -r/--reload the existing output files are only inspected through
    reload().  Otherwise images are drawn from nist_supp_iterator(), run
    through every module of MODULE_INSTANCES and written to the output
    files.

    Raises NotImplementedError when DEBUG_IMAGES_PATH is set, because the
    debug image source only exists as disabled, untested code.
    """
    max_complexity = 0.5 # default
    probability_zero = 0.1 # default
    output_file_path = None
    params_output_file_path = None
    labels_output_file_path = None
    nist_path = DEFAULT_NIST_PATH
    label_path = DEFAULT_LABEL_PATH
    ocr_path = DEFAULT_OCR_PATH
    ocrlabel_path = DEFAULT_OCRLABEL_PATH
    prob_font = 0.0
    prob_captcha = 0.0
    prob_ocr = 0.0
    stop_after = None
    reload_mode = False

    for o, a in opts:
        if o in ('-m', '--max-complexity'):
            max_complexity = float(a)
            assert max_complexity >= 0.0 and max_complexity <= 1.0
        elif o in ('-r', '--reload'):
            reload_mode = True
        elif o in ("-z", "--probability-zero"):
            probability_zero = float(a)
            assert probability_zero >= 0.0 and probability_zero <= 1.0
        elif o in ("-o", "--output-file"):
            output_file_path = a
        elif o in ('-p', "--params-output-file"):
            params_output_file_path = a
        elif o in ('-x', "--labels-output-file"):
            labels_output_file_path = a
        elif o in ('-s', "--stop-after"):
            stop_after = int(a)
        elif o in ('-f', "--data-file"):
            nist_path = a
        elif o in ('-l', "--label-file"):
            label_path = a
        elif o in ('-c', "--ocr-file"):
            ocr_path = a
        elif o in ('-d', "--ocrlabel-file"):
            ocrlabel_path = a
        elif o in ('-a', "--prob-font"):
            prob_font = float(a)
        elif o in ('-b', "--prob-captcha"):
            prob_captcha = float(a)
        elif o in ('-g', "--prob-ocr"):
            prob_ocr = float(a)
        elif o in ('-y', "--seed"):
            # the seed has already been applied right after option parsing
            pass
        else:
            assert False, "unhandled option"

    if (output_file_path is None
            or params_output_file_path is None
            or labels_output_file_path is None):
        print("Must specify the three output files.")
        usage()
        # NOTE(review): `pdb` is not imported in this file; presumably it is
        # provided by the GIMP Python-Fu environment -- confirm.
        pdb.gimp_quit(0)
        sys.exit(2)

    if reload_mode:
        reload(output_file_path, params_output_file_path)
        return

    if DEBUG_IMAGES_PATH:
        # The debug image source (DebugImages / debug_images_iterator) is
        # only present as disabled, untested code.  Fail with a clear
        # message rather than with a NameError on the missing pipeline.
        raise NotImplementedError(
            "DEBUG_IMAGES_PATH is set but the debug image source is not implemented")

    nist = NistData(nist_path, label_path, ocr_path, ocrlabel_path)
    num_img = 819200 # 800 Mb file
    if stop_after:
        num_img = stop_after
    pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
    img_it = nist_supp_iterator(nist, prob_font, prob_captcha, prob_ocr, num_img)

    cpx_it = range_complexity_iterator(probability_zero, max_complexity)
    pl.run(img_it, cpx_it)
    pl.write_output(output_file_path, params_output_file_path, labels_output_file_path)

# The pipeline is run as a side effect of loading this file.
_main()

if DEBUG_X:
    # Leave interactive mode and show the figures (pylab was imported above
    # under the same DEBUG_X condition).
    pylab.ioff()
    pylab.show()

# NOTE(review): `pdb` is not imported in this file; presumably it is provided
# by the GIMP Python-Fu environment the script runs in -- confirm.
pdb.gimp_quit(0)