view transformations/pipeline.py @ 41:fdb0e0870fb4

Many changes to pipeline.py to generalize it, plus a start on visualization, and created a wrapper (run_pipeline.py) to call it through GIMP (a sketch of the wrapper idea follows below). - Changes to pipeline.py - Wrapped the pipeline loop in a class - Isolated the concerns of iterating over batches and over complexities into separate iterators - This allows complicated orderings of batches (multiple sources) and of complexities - regenerate_parameters() is now called for each image - Command-line arguments parsed with getopt(), so more options can be added later - run_pipeline.py - The goal is to allow passing arguments. There is no easy way (none found, at least) to pass them on the command line when calling GIMP in batch mode, so this is a hack. - The ultimate goal is to launch jobs on the clusters with dbidispatch, specifying the options (different for each job) on the command line.
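A minimal sketch of what such a wrapper could look like (an illustration only, not the actual run_pipeline.py from this changeset; the GIMP batch flags and the relative script path are assumptions):

#!/usr/bin/python
# run_pipeline.py (sketch): write the command-line arguments to a temp file,
# publish its path through the PIPELINE_ARGS_TMPFILE environment variable,
# then launch GIMP in batch mode so that pipeline.py can read them back.
import sys, os, tempfile

fd, args_path = tempfile.mkstemp()
try:
    os.write(fd, '\n'.join(sys.argv[1:]))
    os.close(fd)
    os.environ['PIPELINE_ARGS_TMPFILE'] = args_path
    # hypothetical invocation; the exact flags depend on the GIMP version
    os.system('gimp -i --batch-interpreter python-fu-eval '
              '--batch "execfile(\'pipeline.py\')"')
finally:
    os.remove(args_path)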
author fsavard
date Wed, 03 Feb 2010 17:08:27 -0500
parents f6b6c74bb82f
children fabf910467b2
line wrap: on
line source

#!/usr/bin/python
# coding: utf-8

from __future__ import with_statement

import sys, os, getopt
import numpy
import filetensor as ft
import random

# This is intended to be run as a GIMP script
from gimpfu import *

DEBUG = True
BATCH_SIZE = 100
DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']

if DEBUG:
    import pylab
    pylab.ion()

#from add_background_image import AddBackground
#from affine_transform import AffineTransformation
#from PoivreSel import PoivreSel
from thick import Thick
#from BruitGauss import BruitGauss
#from gimp_script import GIMPTransformation
#from Rature import Rature
#from contrast import Contrast
from local_elastic_distortions import LocalElasticDistorter
from slant import Slant

MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), Slant()]
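
# For reference, a minimal sketch of the interface the pipeline expects from
# its transformation modules: a hypothetical no-op module, inferred from how
# Thick, LocalElasticDistorter and Slant are used in Pipeline.run() below.
class IdentityModule():
    def regenerate_parameters(self, complexity):
        # called once per image; returns the (possibly empty) sequence of
        # parameters that the pipeline stores for this image
        return []

    def transform_image(self, image):
        # takes and returns a 2D numpy array of shape image_size
        return image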

class Pipeline(object):
    def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
        self.modules = modules
        self.num_batches = num_batches
        self.batch_size = batch_size
        self.num_params_stored = 0
        self.image_size = image_size

        self.init_memory()

    def init_num_params_stored(self):
        # just a dummy call to regenerate_parameters() to get the
        # real number of params (only those which are stored)
        self.num_params_stored = 0
        for m in self.modules:
            self.num_params_stored += len(m.regenerate_parameters(0.0))

    def init_memory(self):
        self.init_num_params_stored()

        # room for num_batches batches, plus one extra batch of rows
        total = (self.num_batches + 1) * self.batch_size
        num_px = self.image_size[0] * self.image_size[1]

        self.res_data = numpy.empty((total, num_px))
        self.params = numpy.empty((total, self.num_params_stored))

    def run(self, batch_iterator, complexity_iterator):
        img_size = self.image_size

        for batch_no, batch in enumerate(batch_iterator):
            complexity = complexity_iterator.next()

            assert len(batch) == self.batch_size

            for img_no, img in enumerate(batch):
                sys.stdout.flush()
                global_idx = batch_no*self.batch_size + img_no

                img = img.reshape(img_size)

                param_idx = 0
                for mod in self.modules:
                    # This used to be done _per batch_,
                    # i.e. outside the "for img" loop
                    p = mod.regenerate_parameters(complexity)
                    self.params[global_idx, param_idx:param_idx+len(p)] = p
                    param_idx += len(p)

                    img = mod.transform_image(img)

                self.res_data[global_idx] = \
                        img.reshape((img_size[0] * img_size[1],))*255

                if DEBUG:
                    pylab.imshow(img)
                    pylab.draw()

    def write_output(self, output_file_path, params_output_file_path):
        with open(output_file_path, 'wb') as f:
            ft.write(f, self.res_data)

        numpy.save(params_output_file_path, self.params)
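
# Sketch of how the two files written above could be read back (assumptions:
# ft.read() with no subtensor argument returns the whole tensor, and
# numpy.save() appended a '.npy' suffix to the params path):
def read_output(output_file_path, params_output_file_path):
    with open(output_file_path, 'rb') as f:
        data = ft.read(f)
    params = numpy.load(params_output_file_path + '.npy')
    return data, params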
                

##############################################################################
# COMPLEXITY ITERATORS
# They're called once per batch to get the complexity to use for that batch.
# They must be infinite (they should never raise StopIteration from next()).

# With probability probability_zero, yield complexity 0.0; otherwise
# yield a complexity drawn uniformly over [0.0, max_complexity].
def range_complexity_iterator(probability_zero, max_complexity):
    assert max_complexity <= 1.0
    while True:
        # the coin flip must be resampled on every iteration; drawing it
        # once before the loop would lock the iterator into a single branch
        if numpy.random.uniform(0.0, 1.0) < probability_zero:
            yield 0.0
        else:
            yield numpy.random.uniform(0.0, max_complexity)
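
# Another iterator satisfying the same contract, shown for illustration
# (hypothetical, not used by main() below): always yield the same complexity.
def constant_complexity_iterator(complexity):
    while True:
        yield complexity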

##############################################################################
# DATA ITERATORS
# They can be used to interleave different data sources etc.

class NistData():
    def __init__(self):
        nist_path = DEFAULT_NIST_PATH
        self.train_data = open(nist_path, 'rb')
        self.dim = tuple(ft._read_header(self.train_data)[3])

def just_nist_iterator(nist, batch_size, stop_after=None):
    # stop_after, when given, is a number of batches (see main())
    for batch_no, i in enumerate(xrange(0, nist.dim[0], batch_size)):
        # rewind first: ft.read() parses the header from the current position
        nist.train_data.seek(0)
        yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255

        if stop_after is not None and batch_no + 1 >= stop_after:
            break
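
# Sketch of how two data sources could be interleaved, as mentioned above
# (hypothetical; assumes both iterators yield batches of the same size):
def interleaved_iterator(iter_a, iter_b):
    while True:
        yield iter_a.next()
        yield iter_b.next()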

##############################################################################
# MAIN

def usage():
    print '''
Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] [-s ...]
    -m, --max-complexity: max complexity to generate for a batch
    -z, --probability-zero: probability of using complexity=0 for a batch
    -o, --output-file: full path to file to use for output of images
    -p, --params-output-file: path to file to output params to
    -s, --stop-after: number of batches to process before stopping (default: all)
    '''

# The arguments are passed through a temp file written by run_pipeline.py,
# one argument per line (see the sketch above).
def get_argv():
    with open(ARGS_FILE) as f:
        args = [l.rstrip() for l in f.readlines()]
    return args
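
# For illustration, the temp file would contain one argument per line, e.g.
# (hypothetical values and paths):
#   -m
#   0.7
#   -o
#   /tmp/complex_data.ft
#   -p
#   /tmp/complex_params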

# Might be called locally or through dbidispatch. In both cases the script
# must be run through the GIMP executable so the GIMP filters are available.
def main():
    max_complexity = 0.5 # default
    probability_zero = 0.1 # default
    output_file_path = None
    params_output_file_path = None
    stop_after = None

    try:
        opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:",
                ["max-complexity=", "probability-zero=", "output-file=",
                 "params-output-file=", "stop-after="])
    except getopt.GetoptError, err:
        # print help information and exit:
        print str(err) # will print something like "option -a not recognized"
        usage()
        sys.exit(2)
    for o, a in opts:
        if o in ('-m', '--max-complexity'):
            max_complexity = float(a)
            assert 0.0 <= max_complexity <= 1.0
        elif o in ("-z", "--probability-zero"):
            probability_zero = float(a)
            assert 0.0 <= probability_zero <= 1.0
        elif o in ("-o", "--output-file"):
            output_file_path = a
        elif o in ('-p', "--params-output-file"):
            params_output_file_path = a
        elif o in ('-s', "--stop-after"):
            stop_after = int(a)
        else:
            assert False, "unhandled option"

    if output_file_path is None or params_output_file_path is None:
        print "Must specify both output files."
        print
        usage()
        sys.exit(2)

    nist = NistData()
    num_batches = nist.dim[0]/BATCH_SIZE
    if stop_after:
        num_batches = stop_after
    pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
    cpx_it = range_complexity_iterator(probability_zero, max_complexity)
    batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)

    pl.run(batch_it, cpx_it)
    pl.write_output(output_file_path, params_output_file_path)

main()

# show the last image before quitting (only reachable in DEBUG mode,
# since pylab is only imported when DEBUG is set)
if DEBUG:
    pylab.ioff()
    pylab.show()

# gimp_quit() terminates the GIMP process, so nothing after it would run
pdb.gimp_quit(0)