Mercurial > ift6266
view transformations/pipeline.py @ 45:f8a92292b299
merge de 4 fevrier
author | SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca> |
---|---|
date | Thu, 04 Feb 2010 10:27:58 -0500 |
parents | fdb0e0870fb4 |
children | fabf910467b2 |
line wrap: on
line source
#!/usr/bin/python # coding: utf-8 from __future__ import with_statement import sys, os, getopt import numpy import filetensor as ft import random # This is intended to be run as a GIMP script from gimpfu import * DEBUG = True BATCH_SIZE = 100 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft' ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] if DEBUG: import pylab pylab.ion() #from add_background_image import AddBackground #from affine_transform import AffineTransformation #from PoivreSel import PoivreSel from thick import Thick #from BruitGauss import BruitGauss #from gimp_script import GIMPTransformation #from Rature import Rature #from contrast Contrast from local_elastic_distortions import LocalElasticDistorter from slant import Slant MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), Slant()] class Pipeline(): def __init__(self, modules, num_batches, batch_size, image_size=(32,32)): self.modules = modules self.num_batches = num_batches self.batch_size = batch_size self.num_params_stored = 0 self.image_size = image_size self.init_memory() def init_num_params_stored(self): # just a dummy call to regenerate_parameters() to get the # real number of params (only those which are stored) self.num_params_stored = 0 for m in self.modules: self.num_params_stored += len(m.regenerate_parameters(0.0)) def init_memory(self): self.init_num_params_stored() total = (self.num_batches + 1) * self.batch_size num_px = self.image_size[0] * self.image_size[1] self.res_data = numpy.empty((total, num_px)) self.params = numpy.empty((total, self.num_params_stored)) def run(self, batch_iterator, complexity_iterator): img_size = self.image_size for batch_no, batch in enumerate(batch_iterator): complexity = complexity_iterator.next() assert len(batch) == self.batch_size for img_no, img in enumerate(batch): sys.stdout.flush() global_idx = batch_no*self.batch_size + img_no img = img.reshape(img_size) param_idx = 0 for mod in self.modules: # This used to be done _per batch_, # ie. out of the "for img" loop p = mod.regenerate_parameters(complexity) self.params[global_idx, param_idx:param_idx+len(p)] = p param_idx += len(p) img = mod.transform_image(img) self.res_data[global_idx] = \ img.reshape((img_size[0] * img_size[1],))*255 pylab.imshow(img) pylab.draw() def write_output(self, output_file_path, params_output_file_path): with open(output_file_path, 'wb') as f: ft.write(f, self.res_data) numpy.save(params_output_file_path, self.params) ############################################################################## # COMPLEXITY ITERATORS # They're called once every batch, to get the complexity to use for that batch # they must be infinite (should never throw StopIteration when calling next()) # probability of generating 0 complexity, otherwise # uniform over 0.0-max_complexity def range_complexity_iterator(probability_zero, max_complexity): assert max_complexity <= 1.0 n = numpy.random.uniform(0.0, 1.0) while True: if n < probability_zero: yield 0.0 else: yield numpy.random.uniform(0.0, max_complexity) ############################################################################## # DATA ITERATORS # They can be used to interleave different data sources etc. class NistData(): def __init__(self, ): nist_path = DEFAULT_NIST_PATH self.train_data = open(nist_path, 'rb') self.dim = tuple(ft._read_header(self.train_data)[3]) def just_nist_iterator(nist, batch_size, stop_after=None): for i in xrange(0, nist.dim[0], batch_size): nist.train_data.seek(0) yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 if not stop_after is None and i >= stop_after: break ############################################################################## # MAIN def usage(): print ''' Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] -m, --max-complexity: max complexity to generate for a batch -z, --probability-zero: probability of using complexity=0 for a batch -o, --output-file: full path to file to use for output of images -p, --params-output-file: path to file to output params to ''' # See run_pipeline.py def get_argv(): with open(ARGS_FILE) as f: args = [l.rstrip() for l in f.readlines()] return args # Might be called locally or through dbidispatch. In all cases it should be # passed to the GIMP executable to be able to use GIMP filters. # Ex: def main(): max_complexity = 0.5 # default probability_zero = 0.1 # default output_file_path = None params_output_file_path = None stop_after = None try: opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) except getopt.GetoptError, err: # print help information and exit: print str(err) # will print something like "option -a not recognized" usage() sys.exit(2) output = None verbose = False for o, a in opts: if o in ('-m', '--max-complexity'): max_complexity = float(a) assert max_complexity >= 0.0 and max_complexity <= 1.0 elif o in ("-z", "--probability-zero"): probability_zero = float(a) assert probability_zero >= 0.0 and probability_zero <= 1.0 elif o in ("-o", "--output-file"): output_file_path = a elif o in ('-p', "--params-output-file"): params_output_file_path = a elif o in ('-s', "--stop-after"): stop_after = int(a) else: assert False, "unhandled option" if output_file_path == None or params_output_file_path == None: print "Must specify both output files." print usage() sys.exit(2) nist = NistData() num_batches = nist.dim[0]/BATCH_SIZE if stop_after: num_batches = stop_after pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) cpx_it = range_complexity_iterator(probability_zero, max_complexity) batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) pl.run(batch_it, cpx_it) pl.write_output(output_file_path, params_output_file_path) main() pdb.gimp_quit(0) pylab.ioff() pylab.show()