changeset 61:cc4be6b25b8e

Data iterator alternating between NIST/font/captcha, removed the use of batches, keep track of labels (Not fully done yet)
author boulanni <nicolas_boulanger@hotmail.com>
date Mon, 08 Feb 2010 23:45:17 -0500
parents d508f5a8acd0
children bab98bb47616
files transformations/pipeline.py
diffstat 1 files changed, 89 insertions(+), 66 deletions(-) [+]
line wrap: on
line diff
--- a/transformations/pipeline.py	Mon Feb 08 14:55:22 2010 -0500
+++ b/transformations/pipeline.py	Mon Feb 08 23:45:17 2010 -0500
@@ -4,14 +4,14 @@
 from __future__ import with_statement
 
 # This is intended to be run as a GIMP script
-from gimpfu import *
+#from gimpfu import *
 
 import sys, os, getopt
 import numpy
 import filetensor as ft
 import random
 
-# To debug locally, also call with -s 1 (to stop after 1 batch ~= 100)
+# To debug locally, also call with -s 100 (to stop after ~100)
 # (otherwise we allocate all needed memory, might be loonnng and/or crash
 # if, lucky like me, you have an age-old laptop creaking from everywhere)
 DEBUG = True
@@ -31,8 +31,8 @@
 # (create it, otherwise it'll crash)
 DEBUG_OUTPUT_DIR = 'debug_out'
 
-BATCH_SIZE = 100
-DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
+DEFAULT_NIST_PATH = '/data/lisa/data/ift6266h10/train_data.ft'
+DEFAULT_LABEL_PATH = '/data/lisa/data/ift6266h10/train_labels.ft'
 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
 
 if DEBUG_X:
@@ -72,10 +72,9 @@
     END_TRANSFORM_HOOK = [VISUALIZER]
 
 class Pipeline():
-    def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
+    def __init__(self, modules, num_img, image_size=(32,32)):
         self.modules = modules
-        self.num_batches = num_batches
-        self.batch_size = batch_size
+        self.num_img = num_img
         self.num_params_stored = 0
         self.image_size = image_size
 
@@ -91,66 +90,65 @@
     def init_memory(self):
         self.init_num_params_stored()
 
-        total = self.num_batches * self.batch_size
+        total = self.num_img
         num_px = self.image_size[0] * self.image_size[1]
 
-        self.res_data = numpy.empty((total, num_px))
+        self.res_data = numpy.empty((total, num_px), dtype=numpy.uint8)
         # +1 to store complexity
         self.params = numpy.empty((total, self.num_params_stored+1))
+        self.res_labels = numpy.empty(total, dtype=numpy.int32)
 
-    def run(self, batch_iterator, complexity_iterator):
+    def run(self, img_iterator, complexity_iterator):
         img_size = self.image_size
 
         should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
         should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0
 
-        for batch_no, batch in enumerate(batch_iterator):
+        for img_no, (img, label) in enumerate(img_iterator):
+            sys.stdout.flush()
             complexity = complexity_iterator.next()
-            if DEBUG:
-                print "Complexity:", complexity
 
-            assert len(batch) == self.batch_size
+            global_idx = img_no
 
-            for img_no, img in enumerate(batch):
-                sys.stdout.flush()
-                global_idx = batch_no*self.batch_size + img_no
-
-                img = img.reshape(img_size)
+            img = img.reshape(img_size)
 
-                param_idx = 1
-                # store complexity along with other params
-                self.params[global_idx, 0] = complexity
-                for mod in self.modules:
-                    # This used to be done _per batch_,
-                    # ie. out of the "for img" loop                   
-                    p = mod.regenerate_parameters(complexity)
-                    self.params[global_idx, param_idx:param_idx+len(p)] = p
-                    param_idx += len(p)
+            param_idx = 1
+            # store complexity along with other params
+            self.params[global_idx, 0] = complexity
+            for mod in self.modules:
+                # This used to be done _per batch_,
+                # ie. out of the "for img" loop                   
+                p = mod.regenerate_parameters(complexity)
+                self.params[global_idx, param_idx:param_idx+len(p)] = p
+                param_idx += len(p)
 
-                    img = mod.transform_image(img)
+                img = mod.transform_image(img)
 
-                    if should_hook_after_each:
-                        for hook in AFTER_EACH_MODULE_HOOK:
-                            hook.after_transform_callback(img)
-
-                self.res_data[global_idx] = \
-                        img.reshape((img_size[0] * img_size[1],))*255
+                if should_hook_after_each:
+                    for hook in AFTER_EACH_MODULE_HOOK:
+                        hook.after_transform_callback(img)
 
+            self.res_data[global_idx] = \
+                    img.reshape((img_size[0] * img_size[1],))*255
+            self.res_labels[global_idx] = label
 
-                if should_hook_at_the_end:
-                    for hook in END_TRANSFORM_HOOK:
-                        hook.end_transform_callback(img)
+            if should_hook_at_the_end:
+                for hook in END_TRANSFORM_HOOK:
+                    hook.end_transform_callback(img)
 
-    def write_output(self, output_file_path, params_output_file_path):
+    def write_output(self, output_file_path, params_output_file_path, labels_output_file_path):
         with open(output_file_path, 'wb') as f:
             ft.write(f, self.res_data)
 
         numpy.save(params_output_file_path, self.params)
+
+        with open(labels_output_file_path, 'wb') as f:
+            ft.write(f, self.res_labels)
                 
 
 ##############################################################################
 # COMPLEXITY ITERATORS
-# They're called once every batch, to get the complexity to use for that batch
+# They're called once every img, to get the complexity to use for that img
 # they must be infinite (should never throw StopIteration when calling next())
 
 # probability of generating 0 complexity, otherwise
@@ -190,19 +188,25 @@
 '''
 
 class NistData():
-    def __init__(self, ):
-        nist_path = DEFAULT_NIST_PATH
+    def __init__(self, nist_path, label_path):
         self.train_data = open(nist_path, 'rb')
+        self.train_labels = open(label_path, 'rb')
         self.dim = tuple(ft._read_header(self.train_data)[3])
 
-def just_nist_iterator(nist, batch_size, stop_after=None):
-    for i in xrange(0, nist.dim[0], batch_size):
-        if not stop_after is None and i >= stop_after:
-            break
-
-        nist.train_data.seek(0)
-        yield ft.read(nist.train_data, slice(i, i+batch_size)).astype(numpy.float32)/255
-
+def nist_supp_iterator(nist, prob_font, prob_captcha, num_img):
+    subtensor = slice(0, num_img)
+    img = ft.read(nist.train_data, subtensor).astype(numpy.float32)/255
+    labels = ft.read(nist.train_labels, subtensor)
+    
+    for i in xrange(num_img):
+        r = numpy.random.rand()
+        if r<= prob_font:
+            pass #get font
+        elif r<= prob_font + prob_captcha:
+            pass #get captcha
+        else:
+            j = numpy.random.randint(num_img)
+            yield img[j], labels[j]
 
 
 # Mostly for debugging, for the moment, just to see if we can
@@ -225,10 +229,15 @@
 def usage():
     print '''
 Usage: run_pipeline.sh [-m ...] [-z ...] [-o ...] [-p ...]
-    -m, --max-complexity: max complexity to generate for a batch
-    -z, --probability-zero: probability of using complexity=0 for a batch
+    -m, --max-complexity: max complexity to generate for an image
+    -z, --probability-zero: probability of using complexity=0 for an image
     -o, --output-file: full path to file to use for output of images
     -p, --params-output-file: path to file to output params to
+    -r, --labels-output-file: path to file to output labels to
+    -f, --data-file: path to filetensor (.ft) data file (NIST)
+    -l, --label-file: path to filetensor (.ft) labels file (NIST labels)
+    -a, --prob-font: probability of using a raw font image
+    -b, --prob-captcha: probability of using a captcha image
     '''
 
 # See run_pipeline.py
@@ -245,11 +254,16 @@
     probability_zero = 0.1 # default
     output_file_path = None
     params_output_file_path = None
+    labels_output_file_path = None
+    nist_path = DEFAULT_NIST_PATH
+    label_path = DEFAULT_LABEL_PATH
+    prob_font = 0.0
+    prob_captcha = 0.0
     stop_after = None
     reload_mode = False
 
     try:
-        opts, args = getopt.getopt(get_argv(), "rm:z:o:p:s:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="])
+        opts, args = getopt.getopt(get_argv(), "rm:z:o:p:r:s:f:l:a:b:", ["reload","max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "labels-output-file=", "stop-after=", "data-file=", "label-file=", "prob-font=", "prob-captcha="])
     except getopt.GetoptError, err:
         # print help information and exit:
         print str(err) # will print something like "option -a not recognized"
@@ -269,13 +283,23 @@
             output_file_path = a
         elif o in ('-p', "--params-output-file"):
             params_output_file_path = a
+        elif o in ('-r', "--labels-output-file"):
+            labels_output_file_path = a
         elif o in ('-s', "--stop-after"):
             stop_after = int(a)
+        elif o in ('-f', "--data-file"):
+            nist_path = a
+        elif o in ('-l', "--label-file"):
+            label_path = a
+        elif o in ('-a', "--prob-font"):
+            prob_font = float(a)
+        elif o in ('-b', "--prob-captcha"):
+            prob_captcha = float(a)
         else:
             assert False, "unhandled option"
 
-    if output_file_path == None or params_output_file_path == None:
-        print "Must specify both output files."
+    if output_file_path == None or params_output_file_path == None or labels_output_file_path == None:
+        print "Must specify the three output files."
         print
         usage()
         sys.exit(2)
@@ -287,22 +311,21 @@
             '''
             # This code is yet untested
             debug_images = DebugImages(DEBUG_IMAGES_PATH)
-            num_batches = 1
-            batch_size = len(debug_images.filelist)
-            pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
-            batch_it = debug_images_iterator(debug_images)
+            num_img = len(debug_images.filelist)
+            pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
+            img_it = debug_images_iterator(debug_images)
             '''
         else:
-            nist = NistData()
-            num_batches = nist.dim[0]/BATCH_SIZE
+            nist = NistData(nist_path, label_path)
+            num_img = nist.dim[0]
             if stop_after:
-                num_batches = stop_after
-            pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
-            batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
+                num_img = stop_after
+            pl = Pipeline(modules=MODULE_INSTANCES, num_img=num_img, image_size=(32,32))
+            img_it = nist_supp_iterator(nist, prob_font, prob_captcha, num_img)
 
         cpx_it = range_complexity_iterator(probability_zero, max_complexity)
-        pl.run(batch_it, cpx_it)
-        pl.write_output(output_file_path, params_output_file_path)
+        pl.run(img_it, cpx_it)
+        pl.write_output(output_file_path, params_output_file_path, labels_output_file_path)
 
 _main()