Mercurial > ift6266
changeset 61:cc4be6b25b8e
Data iterator alternating between NIST/font/captcha, removed the use of batches, keep track of labels (Not fully done yet)
author | boulanni <nicolas_boulanger@hotmail.com> |
---|---|
date | Mon, 08 Feb 2010 23:45:17 -0500 |
parents | d508f5a8acd0 |
children | bab98bb47616 |
files | transformations/pipeline.py |
diffstat | 1 files changed, 89 insertions(+), 66 deletions(-) [+] |
line wrap: on
line diff
--- a/transformations/pipeline.py Mon Feb 08 14:55:22 2010 -0500 +++ b/transformations/pipeline.py Mon Feb 08 23:45:17 2010 -0500 @@ -4,14 +4,14 @@ from __future__ import with_statement # This is intended to be run as a GIMP script -from gimpfu import * +#from gimpfu import * import sys, os, getopt import numpy import filetensor as ft import random -# To debug locally, also call with -s 1 (to stop after 1 batch ~= 100) +# To debug locally, also call with -s 100 (to stop after ~100) # (otherwise we allocate all needed memory, might be loonnng and/or crash # if, lucky like me, you have an age-old laptop creaking from everywhere) DEBUG = True @@ -31,8 +31,8 @@ # (create it, otherwise it'll crash) DEBUG_OUTPUT_DIR = 'debug_out' -BATCH_SIZE = 100 -DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft' +DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft' +DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft' ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE'] if DEBUG_X: @@ -72,10 +72,9 @@ END_TRANSFORM_HOOK = [VISUALIZER] class Pipeline(): - def __init__(self, modules, num_batches, batch_size, image_size=(32,32)): + def __init__(self, modules, num_img, image_size=(32,32)): self.modules = modules - self.num_batches = num_batches - self.batch_size = batch_size + self.num_img = num_img self.num_params_stored = 0 self.image_size = image_size @@ -91,66 +90,65 @@ def init_memory(self): self.init_num_params_stored() - total = self.num_batches * self.batch_size + total = self.num_img num_px = self.image_size[0] * self.image_size[1] - self.res_data = numpy.empty((total, num_px)) + self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8) # +1 to store complexity self.params = numpy.empty((total, self.num_params_stored+1)) + self.res_labels = numpy.empty(total, dtype=numpy.int32) - def run(self, batch_iterator, complexity_iterator): + def run(self, img_iterator, complexity_iterator): img_size = self.image_size should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0 should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0 - for batch_no, batch in enumerate(batch_iterator): + for img_no, (img, label) in enumerate(img_iterator): + sys.stdout.flush() complexity = complexity_iterator.next() - if DEBUG: - print "Complexity:", complexity - assert len(batch) == self.batch_size + global_idx = img_no - for img_no, img in enumerate(batch): - sys.stdout.flush() - global_idx = batch_no*self.batch_size + img_no - - img = img.reshape(img_size) + img = img.reshape(img_size) - param_idx = 1 - # store complexity along with other params - self.params[global_idx, 0] = complexity - for mod in self.modules: - # This used to be done _per batch_, - # ie. out of the "for img" loop - p = mod.regenerate_parameters(complexity) - self.params[global_idx, param_idx:param_idx+len(p)] = p - param_idx += len(p) + param_idx = 1 + # store complexity along with other params + self.params[global_idx, 0] = complexity + for mod in self.modules: + # This used to be done _per batch_, + # ie. out of the "for img" loop + p = mod.regenerate_parameters(complexity) + self.params[global_idx, param_idx:param_idx+len(p)] = p + param_idx += len(p) - img = mod.transform_image(img) + img = mod.transform_image(img) - if should_hook_after_each: - for hook in AFTER_EACH_MODULE_HOOK: - hook.after_transform_callback(img) - - self.res_data[global_idx] = \ - img.reshape((img_size[0] * img_size[1],))*255 + if should_hook_after_each: + for hook in AFTER_EACH_MODULE_HOOK: + hook.after_transform_callback(img) + self.res_data[global_idx] = \ + img.reshape((img_size[0] * img_size[1],))*255 + self.res_labels[global_idx] = label - if should_hook_at_the_end: - for hook in END_TRANSFORM_HOOK: - hook.end_transform_callback(img) + if should_hook_at_the_end: + for hook in END_TRANSFORM_HOOK: + hook.end_transform_callback(img) - def write_output(self, output_file_path, params_output_file_path): + def write_output(self, output_file_path, params_output_file_path, labels_output_file_path): with open(output_file_path, 'wb') as f: ft.write(f, self.res_data) numpy.save(params_output_file_path, self.params) + + with open(labels_output_file_path, 'wb') as f: + ft.write(f, self.res_labels) ############################################################################## # COMPLEXITY ITERATORS -# They're called once every batch, to get the complexity to use for that batch +# They're called once every img, to get the complexity to use for that img # they must be infinite (should never throw StopIteration when calling next()) # probability of generating 0 complexity, otherwise @@ -190,19 +188,25 @@ ''' class NistData(): - def __init__(self, ): - nist_path = DEFAULT_NIST_PATH + def __init__(self, nist_path, label_path): self.train_data = open(nist_path, 'rb') + self.train_labels = open(label_path, 'rb') self.dim = tuple(ft._read_header(self.train_data)[3]) -def just_nist_iterator(nist, batch_size, stop_after=None): - for i in xrange(0, nist.dim[0], batch_size): - if not stop_after is None and i >= stop_after: - break - - nist.train_data.seek(0) - yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255 - +def nist_supp_iterator(nist, prob_font, prob_captcha, num_img): + subtensor = slice(0, num_img) + img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255 + labels = ft.read(nist.train_labels, subtensor) + + for i in xrange(num_img): + r = numpy.random.rand() + if r<= prob_font: + pass #get font + elif r<= prob_font + prob_captcha: + pass #get captcha + else: + j = numpy.random.randint(num_img) + yield img[j], labels[j] # Mostly for debugging, for the moment, just to see if we can @@ -225,10 +229,15 @@ def usage(): print ''' Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...] - -m, --max-complexity: max complexity to generate for a batch - -z, --probability-zero: probability of using complexity=0 for a batch + -m, --max-complexity: max complexity to generate for an image + -z, --probability-zero: probability of using complexity=0 for an image -o, --output-file: full path to file to use for output of images -p, --params-output-file: path to file to output params to + -r, --labels-output-file: path to file to output labels to + -f, --data-file: path to filetensor (.ft) data file (NIST) + -l, --label-file: path to filetensor (.ft) labels file (NIST labels) + -a, --prob-font: probability of using a raw font image + -b, --prob-captcha: probability of using a captcha image ''' # See run_pipeline.py @@ -245,11 +254,16 @@ probability_zero = 0.1 # default output_file_path = None params_output_file_path = None + labels_output_file_path = None + nist_path = DEFAULT_NIST_PATH + label_path = DEFAULT_LABEL_PATH + prob_font = 0.0 + prob_captcha = 0.0 stop_after = None reload_mode = False try: - opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="]) + opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="]) except getopt.GetoptError, err: # print help information and exit: print str(err) # will print something like "option -a not recognized" @@ -269,13 +283,23 @@ output_file_path = a elif o in ('-p', "--params-output-file"): params_output_file_path = a + elif o in ('-r', "--labels-output-file"): + labels_output_file_path = a elif o in ('-s', "--stop-after"): stop_after = int(a) + elif o in ('-f', "--data-file"): + nist_path = a + elif o in ('-l', "--label-file"): + label_path = a + elif o in ('-a', "--prob-font"): + prob_font = float(a) + elif o in ('-b', "--prob-captcha"): + prob_captcha = float(a) else: assert False, "unhandled option" - if output_file_path == None or params_output_file_path == None: - print "Must specify both output files." + if output_file_path == None or params_output_file_path == None or labels_output_file_path == None: + print "Must specify the three output files." print usage() sys.exit(2) @@ -287,22 +311,21 @@ ''' # This code is yet untested debug_images = DebugImages(DEBUG_IMAGES_PATH) - num_batches = 1 - batch_size = len(debug_images.filelist) - pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) - batch_it = debug_images_iterator(debug_images) + num_img = len(debug_images.filelist) + pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) + img_it = debug_images_iterator(debug_images) ''' else: - nist = NistData() - num_batches = nist.dim[0]/BATCH_SIZE + nist = NistData(nist_path, label_path) + num_img = nist.dim[0] if stop_after: - num_batches = stop_after - pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32)) - batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after) + num_img = stop_after + pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32)) + img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img) cpx_it = range_complexity_iterator(probability_zero, max_complexity) - pl.run(batch_it, cpx_it) - pl.write_output(output_file_path, params_output_file_path) + pl.run(img_it, cpx_it) + pl.write_output(output_file_path, params_output_file_path, labels_output_file_path) _main()