diff transformations/pipeline.py @ 48:fabf910467b2

Ajouté des hooks pour visualisation à différentes étapes. On peut dumper la grille d'images pour chaque image transformée ou visualiser live avec pylab.imshow() (pas encore essayé cette façon... j'ai un problème avec GIMP+python sur mon laptop).
author fsavard
date Thu, 04 Feb 2010 13:39:46 -0500
parents fdb0e0870fb4
children ff59670cd1f9
line wrap: on
line diff
--- a/transformations/pipeline.py	Thu Feb 04 10:32:07 2010 -0500
+++ b/transformations/pipeline.py	Thu Feb 04 13:39:46 2010 -0500
@@ -3,35 +3,71 @@
 
 from __future__ import with_statement
 
+# This is intended to be run as a GIMP script
+from gimpfu import *
+
 import sys, os, getopt
 import numpy
 import filetensor as ft
 import random
 
-# This is intended to be run as a GIMP script
-from gimpfu import *
+# To debug locally, also call with -s 1 (to stop after 1 batch ~= 100)
+# (otherwise we allocate all needed memory, might be loonnng and/or crash
+# if, lucky like me, you have an age-old laptop creaking from everywhere)
+DEBUG = True
+DEBUG_X = False # Debug under X (pylab.show())
 
-DEBUG = True
+DEBUG_IMAGES_PATH = None
+if DEBUG:
+    # UNTESTED YET
+    # To avoid loading NIST if you don't have it handy
+    # (use with debug_images_iterator(), see main())
+    # To use NIST, leave as = None
+    DEBUG_IMAGES_PATH = None#'/home/francois/Desktop/debug_images'
+
+# Directory where to dump images to visualize results
+# (create it, otherwise it'll crash)
+DEBUG_OUTPUT_DIR = 'debug_out'
+
 BATCH_SIZE = 100
 DEFAULT_NIST_PATH = '/data/lisa/data/nist/by_class/all/all_train_data.ft'
 ARGS_FILE = os.environ['PIPELINE_ARGS_TMPFILE']
 
-if DEBUG:
+if DEBUG_X:
     import pylab
     pylab.ion()
 
 #from add_background_image import AddBackground
 #from affine_transform import AffineTransformation
-#from PoivreSel import PoivreSel
+from PoivreSel import PoivreSel
 from thick import Thick
 #from BruitGauss import BruitGauss
 #from gimp_script import GIMPTransformation
 #from Rature import Rature
-#from contrast Contrast
+from contrast import Contrast
 from local_elastic_distortions import LocalElasticDistorter
 from slant import Slant
 
-MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), Slant()]
+if DEBUG:
+    from visualizer import Visualizer
+    # Either put the visualizer as in the MODULES_INSTANCES list
+    # after each module you want to visualize, or in the
+    # AFTER_EACH_MODULE_HOOK list (but not both, it's redundant)
+    VISUALIZER = Visualizer(to_dir=DEBUG_OUTPUT_DIR,  on_screen=False)
+
+MODULE_INSTANCES = [Thick(), LocalElasticDistorter(), PoivreSel(), Contrast()]
+
+# These should have a "after_transform_callback(self, image)" method
+# (called after each call to transform_image in a module)
+AFTER_EACH_MODULE_HOOK = []
+if DEBUG:
+    AFTER_EACH_MODULE_HOOK = [VISUALIZER]
+
+# These should have a "end_transform_callback(self, final_image" method
+# (called after all modules have been called)
+END_TRANSFORM_HOOK = []
+if DEBUG:
+    END_TRANSFORM_HOOK = [VISUALIZER]
 
 class Pipeline():
     def __init__(self, modules, num_batches, batch_size, image_size=(32,32)):
@@ -62,8 +98,13 @@
     def run(self, batch_iterator, complexity_iterator):
         img_size = self.image_size
 
+        should_hook_after_each = len(AFTER_EACH_MODULE_HOOK) != 0
+        should_hook_at_the_end = len(END_TRANSFORM_HOOK) != 0
+
         for batch_no, batch in enumerate(batch_iterator):
             complexity = complexity_iterator.next()
+            if DEBUG:
+                print "Complexity:", complexity
 
             assert len(batch) == self.batch_size
 
@@ -83,11 +124,17 @@
 
                     img = mod.transform_image(img)
 
+                    if should_hook_after_each:
+                        for hook in AFTER_EACH_MODULE_HOOK:
+                            hook.after_transform_callback(img)
+
                 self.res_data[global_idx] = \
                         img.reshape((img_size[0] * img_size[1],))*255
 
-                pylab.imshow(img)
-                pylab.draw()
+
+                if should_hook_at_the_end:
+                    for hook in END_TRANSFORM_HOOK:
+                        hook.end_transform_callback(img)
 
     def write_output(self, output_file_path, params_output_file_path):
         with open(output_file_path, 'wb') as f:
@@ -116,6 +163,27 @@
 # DATA ITERATORS
 # They can be used to interleave different data sources etc.
 
+'''
+# Following code (DebugImages and iterator) is untested
+
+def load_image(filepath):
+    _RGB_TO_GRAYSCALE = [0.3, 0.59, 0.11, 0.0]
+    img = Image.open(filepath)
+    img = numpy.asarray(img)
+    if len(img.shape) > 2:
+        img = (img * _RGB_TO_GRAYSCALE).sum(axis=2)
+    return (img / 255.0).astype('float')
+
+class DebugImages():
+    def __init__(self, images_dir_path):
+        import glob, os.path
+        self.filelist = glob.glob(os.path.join(images_dir_path, "*.png"))
+
+def debug_images_iterator(debug_images):
+    for path in debug_images.filelist:
+        yield load_image(path)
+'''
+
 class NistData():
     def __init__(self, ):
         nist_path = DEFAULT_NIST_PATH
@@ -151,13 +219,16 @@
 # Might be called locally or through dbidispatch. In all cases it should be
 # passed to the GIMP executable to be able to use GIMP filters.
 # Ex: 
-def main():
+def _main():
     max_complexity = 0.5 # default
     probability_zero = 0.1 # default
     output_file_path = None
     params_output_file_path = None
     stop_after = None
 
+    import sys
+    print "python version: ", sys.version
+
     try:
         opts, args = getopt.getopt(get_argv(), "m:z:o:p:s:", ["max-complexity=", "probability-zero=", "output-file=", "params-output-file=", "stop-after="])
     except getopt.GetoptError, err:
@@ -189,20 +260,32 @@
         usage()
         sys.exit(2)
 
-    nist = NistData()
-    num_batches = nist.dim[0]/BATCH_SIZE
-    if stop_after:
-        num_batches = stop_after
-    pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
+    if DEBUG_IMAGES_PATH:
+        '''
+        # This code is yet untested
+        debug_images = DebugImages(DEBUG_IMAGES_PATH)
+        num_batches = 1
+        batch_size = len(debug_images.filelist)
+        pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
+        batch_it = debug_images_iterator(debug_images)
+        '''
+    else:
+        nist = NistData()
+        num_batches = nist.dim[0]/BATCH_SIZE
+        if stop_after:
+            num_batches = stop_after
+        pl = Pipeline(modules=MODULE_INSTANCES, num_batches=num_batches, batch_size=BATCH_SIZE, image_size=(32,32))
+        batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
+
     cpx_it = range_complexity_iterator(probability_zero, max_complexity)
-    batch_it = just_nist_iterator(nist, BATCH_SIZE, stop_after)
-
     pl.run(batch_it, cpx_it)
     pl.write_output(output_file_path, params_output_file_path)
 
-main()
+_main()
+
+if DEBUG_X:
+    pylab.ioff()
+    pylab.show()
 
 pdb.gimp_quit(0)
-pylab.ioff()
-pylab.show()