view transformations/pipeline.py @ 56:d9d836d3c625

Change in affine_transform to handle float images
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Sun, 07 Feb 2010 23:09:56 -0500
parents c89defea1e65
children cc4be6b25b8e
line wrap: on
line source

#!/usr/bin/python
# coding: utf-8

from __future__ import with_statement

# This is intended to be run as a GIMP script
from gimpfu import *

import sys, os, getopt
import numpy
import filetensor as ft
import random

# Debug switches and file-level configuration for the pipeline script.
#
# To debug locally, also call with -s 1 (to stop after 1 batch ~= 100)
# (otherwise we allocate all needed memory, which might take long and/or
# crash on a machine with little RAM).
DEBUG = True
# DEBUG_X enables interactive pylab display; note that both assignments below
# currently set it to False, so it has to be edited by hand to turn it on.
DEBUG_X = False
if DEBUG:
    DEBUG_X = False # Debug under X (pylab.show())

DEBUG_IMAGES_PATH = None
if DEBUG:
    # UNTESTED YET
    # To avoid loading NIST if you don't have it handy
    # (use with debug_images_iterator(), see main())
    # To use NIST, leave as = None
    DEBUG_IMAGES_PATH = None#'/home/francois/Desktop/debug_images'

# Directory where to dump images to visualize results
# (create it beforehand, otherwise the script will crash)
DEBUG_OUTPUT_DIR = 'debug_out'

# Number of images per batch.
BATCH_SIZE = 100
# Default location of the NIST training data file.
DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
# Path of the file holding the command-line arguments, one per line (see
# get_argv()). Raises KeyError at import time if the environment variable
# PIPELINE_ARGS_TMPFILE is not set.
ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']

if DEBUG_X:
    import pylab
    pylab.ion()

#from add_background_image import AddBackground
#from affine_transform import AffineTransformation
from PoivreSel import PoivreSel
from thick import Thick
#from BruitGauss import BruitGauss
#from gimp_script import GIMPTransformation
#from Rature import Rature
from contrast import Contrast
from local_elastic_distortions import LocalElasticDistorter
from slant import Slant

# Debug visualizer, transformation modules and hook lists used by Pipeline.
if DEBUG:
    from visualizer import Visualizer
    # Either put the visualizer in the MODULE_INSTANCES list after each
    # module you want to visualize, or in the AFTER_EACH_MODULE_HOOK list
    # (but not both, it's redundant)
    VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR,  on_screen=False)

# Transformation modules applied, in order, to every image.
MODULE_INSTANCES = [LocalElasticDistorter()]

# These should have an "after_transform_callback(self, image)" method
# (called after each call to transform_image in a module)
AFTER_EACH_MODULE_HOOK = []
if DEBUG:
    AFTER_EACH_MODULE_HOOK = [VISUALIZER]

# These should have an "end_transform_callback(self, final_image)" method
# (called after all modules have been called)
END_TRANSFORM_HOOK = []
if DEBUG:
    END_TRANSFORM_HOOK = [VISUALIZER]

class Pipeline():
    def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
        self.modules = modules
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.num_params_stored = 0
        self.image_size = image_size

        self.init_memory()

    def init_num_params_stored(self):
        # just a dummy call to regenerate_parameters() to get the
        # real number of params (only those which are stored)
        self.num_params_stored = 0
        for m in self.modules:
            self.num_params_stored += len(m.regenerate_parameters(0.0))

    def init_memory(self):
        self.init_num_params_stored()

        total = self.num_batches * self.batch_size
        num_px = self.image_size[0] * self.image_size[1]

        self.res_data = numpy.empty((total, num_px))
        # +1 to store complexity
        self.params = numpy.empty((total, self.num_params_stored+1))

    def run(self, batch_iterator, complexity_iterator):
        img_size = self.image_size

        should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
        should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0

        for batch_no, batch in enumerate(batch_iterator):
            complexity = complexity_iterator.next()
            if DEBUG:
                print "Complexity:", complexity

            assert len(batch) == self.batch_size

            for img_no, img in enumerate(batch):
                sys.stdout.flush()
                global_idx = batch_no*self.batch_size + img_no

                img = img.reshape(img_size)

                param_idx = 1
                # store complexity along with other params
                self.params[global_idx, 0] = complexity
                for mod in self.modules:
                    # This used to be done _per batch_,
                    # ie. out of the "for img" loop                   
                    p = mod.regenerate_parameters(complexity)
                    self.params[global_idx, param_idx:param_idx+len(p)] = p
                    param_idx += len(p)

                    img = mod.transform_image(img)

                    if should_hook_after_each:
                        for hook in AFTER_EACH_MODULE_HOOK:
                            hook.after_transform_callback(img)

                self.res_data[global_idx] = \
                        img.reshape((img_size[0] * img_size[1],))*255


                if should_hook_at_the_end:
                    for hook in END_TRANSFORM_HOOK:
                        hook.end_transform_callback(img)

    def write_output(self, output_file_path, params_output_file_path):
        with open(output_file_path, 'wb') as f:
            ft.write(f, self.res_data)

        numpy.save(params_output_file_path, self.params)
                

##############################################################################
# COMPLEXITY ITERATORS
# They're called once every batch, to get the complexity to use for that batch
# they must be infinite (should never throw StopIteration when calling next())

# probability of generating 0 complexity, otherwise
# uniform over 0.0-max_complexity
def range_complexity_iterator(probability_zero, max_complexity):
    """Infinite generator of per-batch complexity values.

    With probability `probability_zero` the yielded complexity is exactly
    0.0; otherwise it is drawn uniformly from [0.0, max_complexity).

    probability_zero: chance, in [0, 1], of yielding 0.0 for a batch.
    max_complexity: upper bound of the uniform draw; must be <= 1.0
        (asserted when the first value is requested).
    """
    assert max_complexity <= 1.0
    while True:
        # Draw a fresh number for every batch. Drawing it once before the
        # loop would freeze the decision, making the iterator yield either
        # always 0.0 or never 0.0.
        if numpy.random.uniform(0.0, 1.0) < probability_zero:
            yield 0.0
        else:
            yield numpy.random.uniform(0.0, max_complexity)

##############################################################################
# DATA ITERATORS
# They can be used to interleave different data sources etc.

'''
# Following code (DebugImages and iterator) is untested

def load_image(filepath):
    _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0]
    img = Image.open(filepath)
    img = numpy.asarray(img)
    if len(img.shape) > 2:
        img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
    return (img / 255.0).astype('float')

class DebugImages():
    def __init__(self, images_dir_path):
        import glob, os.path
        self.filelist = glob.glob(os.path.join(images_dir_path, "*.png"))

def debug_images_iterator(debug_images):
    for path in debug_images.filelist:
        yield load_image(path)
'''

class NistData():
    """Open handle on a NIST filetensor file plus its dimensions.

    Attributes:
      train_data: the data file, opened in binary read mode and left open
          so that iterators can read batches from it.
      dim: tuple built from element 3 of the filetensor header
          (element 0 is used elsewhere in this file as the image count).
    """

    def __init__(self, nist_path=None):
        """Open the data file and read its header.

        nist_path: path of the filetensor file; defaults to
            DEFAULT_NIST_PATH when None.
        """
        if nist_path is None:
            nist_path = DEFAULT_NIST_PATH
        self.train_data = open(nist_path, 'rb')
        try:
            self.dim = tuple(ft._read_header(self.train_data)[3])
        except:
            # Do not leak the file handle when the header cannot be parsed;
            # the original error is re-raised untouched.
            self.train_data.close()
            raise

def just_nist_iterator(nist, batch_size, stop_after=None):
    """Yield consecutive batches of images read from `nist`.

    Each batch is read with ft.read() from nist.train_data, converted to
    float32 and divided by 255.

    nist: object exposing `train_data` (open file) and `dim` (dim[0] is the
        number of images), e.g. a NistData instance.
    batch_size: number of images requested per batch.
    stop_after: maximum number of *batches* to yield, or None for all.

    NOTE: when dim[0] is not a multiple of batch_size the slice requested
    for the last batch extends past the end of the data.
    """
    for batch_no, start in enumerate(xrange(0, nist.dim[0], batch_size)):
        # stop_after counts batches (the caller sizes its buffers as
        # stop_after * batch_size), so compare it with the batch number and
        # not with the image offset.
        if stop_after is not None and batch_no >= stop_after:
            break

        # rewind before each read, as ft.read is given the whole file
        nist.train_data.seek(0)
        batch = ft.read(nist.train_data, slice(start, start+batch_size))
        yield batch.astype(numpy.float32)/255



# Mostly for debugging, for the moment, just to see if we can
# reload the images and parameters.
def reload(output_file_path, params_output_file_path):
    images_ft = open(output_file_path, 'rb')
    images_ft_dim = tuple(ft._read_header(images_ft)[3])

    print "Images dimensions: ", images_ft_dim

    params = numpy.load(params_output_file_path)

    print "Params dimensions: ", params.shape
    print params
    

##############################################################################
# MAIN

def usage():
    print '''
Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...]
    -m, --max-complexity: max complexity to generate for a batch
    -z, --probability-zero: probability of using complexity=0 for a batch
    -o, --output-file: full path to file to use for output of images
    -p, --params-output-file: path to file to output params to
    '''

# See run_pipeline.py
def get_argv():
    """Return the arguments stored in ARGS_FILE, one per line.

    Trailing whitespace (including the newline) is stripped from every line.
    """
    argv = []
    with open(ARGS_FILE) as args_file:
        for raw_line in args_file.readlines():
            argv.append(raw_line.rstrip())
    return argv

# Might be called locally or through dbidispatch. In all cases it should be
# passed to the GIMP executable to be able to use GIMP filters.
# Ex: 
def _main():
    max_complexity = 0.5 # default
    probability_zero = 0.1 # default
    output_file_path = None
    params_output_file_path = None
    stop_after = None
    reload_mode = False

    try:
        opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="])
    except getopt.GetoptError, err:
        # print help information and exit:
        print str(err) # will print something like "option -a not recognized"
        usage()
        sys.exit(2)

    for o, a in opts:
        if o in ('-m', '--max-complexity'):
            max_complexity = float(a)
            assert max_complexity >= 0.0 and max_complexity <= 1.0
        elif o in ('-r', '--reload'):
            reload_mode = True
        elif o in ("-z", "--probability-zero"):
            probability_zero = float(a)
            assert probability_zero >= 0.0 and probability_zero <= 1.0
        elif o in ("-o", "--output-file"):
            output_file_path = a
        elif o in ('-p', "--params-output-file"):
            params_output_file_path = a
        elif o in ('-s', "--stop-after"):
            stop_after = int(a)
        else:
            assert False, "unhandled option"

    if output_file_path == None or params_output_file_path == None:
        print "Must specify both output files."
        print
        usage()
        sys.exit(2)

    if reload_mode:
        reload(output_file_path, params_output_file_path)
    else:
        if DEBUG_IMAGES_PATH:
            '''
            # This code is yet untested
            debug_images = DebugImages(DEBUG_IMAGES_PATH)
            num_batches = 1
            batch_size = len(debug_images.filelist)
            pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
            batch_it = debug_images_iterator(debug_images)
            '''
        else:
            nist = NistData()
            num_batches = nist.dim[0]/BATCH_SIZE
            if stop_after:
                num_batches = stop_after
            pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
            batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)

        cpx_it = range_complexity_iterator(probability_zero, max_complexity)
        pl.run(batch_it, cpx_it)
        pl.write_output(output_file_path, params_output_file_path)

# Script entry point: the pipeline runs as soon as this file is executed
# (there is no __name__ == "__main__" guard).
_main()

if DEBUG_X:
    # Leave interactive mode and block on the figure windows.
    pylab.ioff()
    pylab.show()

# Terminate the hosting GIMP process with status 0.
# NOTE(review): `pdb` is not defined in this file; presumably it is GIMP's
# procedural database provided by `from gimpfu import *` -- confirm.
pdb.gimp_quit(0)