changeset 49:8ce089f30463

Oublier d'add deux fichiers pour dernier commit.
author fsavard
date Thu, 04 Feb 2010 13:40:44 -0500
parents fabf910467b2
children ff59670cd1f9
files transformations/image_tiling.py transformations/visualizer.py
diffstat 2 files changed, 159 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/transformations/image_tiling.py	Thu Feb 04 13:40:44 2010 -0500
@@ -0,0 +1,86 @@
+"""
+Illustrate filters (or data) in a grid of small image-shaped tiles.
+
+Note: taken from the pylearn codebase on Feb 4, 2010 (fsavard)
+"""
+
+import numpy
+from PIL import Image
+
+def scale_to_unit_interval(ndar,eps=1e-8):
+    ndar = ndar.copy()
+    ndar -= ndar.min()
+    ndar *= 1.0 / (ndar.max()+eps)
+    return ndar
+
+def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0,0),
+        scale_rows_to_unit_interval=True, 
+        output_pixel_vals=True
+        ):
+    """
+    Transform an array with one flattened image per row, into an array in which images are
+    reshaped and layed out like tiles on a floor.
+
+    This function is useful for visualizing datasets whose rows are images, and also columns of
+    matrices for transforming those rows (such as the first layer of a neural net).
+
+    :type X: a 2-D ndarray or a tuple of 4 channels, elements of which can be 2-D ndarrays or None
+    :param X: a 2-D array in which every row is a flattened image.
+    :type img_shape: tuple; (height, width)
+    :param img_shape: the original shape of each image
+    :type tile_shape: tuple; (rows, cols)
+    :param tile_shape: the number of images to tile (rows, cols)
+
+    :returns: array suitable for viewing as an image.  (See:`PIL.Image.fromarray`.)
+    :rtype: a 2-d array with same dtype as X.
+
+    """
+    assert len(img_shape) == 2
+    assert len(tile_shape) == 2
+    assert len(tile_spacing) == 2
+
+    out_shape = [(ishp + tsp) * tshp - tsp for ishp, tshp, tsp 
+        in zip(img_shape, tile_shape, tile_spacing)]
+
+    if isinstance(X, tuple):
+        assert len(X) == 4
+        if output_pixel_vals:
+            out_array = numpy.zeros((out_shape[0], out_shape[1], 4), dtype='uint8')
+        else:
+            out_array = numpy.zeros((out_shape[0], out_shape[1], 4), dtype=X.dtype)
+
+        #colors default to 0, alpha defaults to 1 (opaque)
+        if output_pixel_vals:
+            channel_defaults = [0,0,0,255]
+        else:
+            channel_defaults = [0.,0.,0.,1.]
+
+        for i in xrange(4):
+            if X[i] is None:
+                out_array[:,:,i] = numpy.zeros(out_shape,
+                        dtype='uint8' if output_pixel_vals else out_array.dtype
+                        )+channel_defaults[i]
+            else:
+                out_array[:,:,i] = tile_raster_images(X[i], img_shape, tile_shape, tile_spacing, scale_rows_to_unit_interval, output_pixel_vals)
+        return out_array
+
+    else:
+        H, W = img_shape
+        Hs, Ws = tile_spacing
+
+        out_array = numpy.zeros(out_shape, dtype='uint8' if output_pixel_vals else X.dtype)
+        for tile_row in xrange(tile_shape[0]):
+            for tile_col in xrange(tile_shape[1]):
+                if tile_row * tile_shape[1] + tile_col < X.shape[0]:
+                    if scale_rows_to_unit_interval:
+                        this_img = scale_to_unit_interval(X[tile_row * tile_shape[1] + tile_col].reshape(img_shape))
+                    else:
+                        this_img = X[tile_row * tile_shape[1] + tile_col].reshape(img_shape)
+                    out_array[
+                        tile_row * (H+Hs):tile_row*(H+Hs)+H,
+                        tile_col * (W+Ws):tile_col*(W+Ws)+W
+                        ] \
+                        = this_img * (255 if output_pixel_vals else 1)
+        return out_array
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/transformations/visualizer.py	Thu Feb 04 13:40:44 2010 -0500
@@ -0,0 +1,73 @@
+#!/usr/bin/python
+
+import numpy
+import Image
+from image_tiling import tile_raster_images
+import pylab
+import time
+
+class Visualizer():
+    def __init__(self, num_columns=10, image_size=(32,32), to_dir=None, on_screen=False):
+        self.list = []
+        self.image_size = image_size
+        self.num_columns = num_columns
+
+        self.on_screen = on_screen
+        self.to_dir = to_dir
+
+        self.cur_grid_image = None
+
+        self.cur_index = 0
+
+    def visualize_stop_and_flush(self):
+        self.make_grid_image()
+
+        if self.on_screen:
+            self.visualize()
+        if self.to_dir:
+            self.dump_to_disk()
+
+        self.stop_and_wait()
+        self.flush()
+
+        self.cur_index += 1
+
+    def make_grid_image(self):
+        num_rows = len(self.list) / self.num_columns
+        if len(self.list) % self.num_columns != 0:
+            num_rows += 1
+        grid_shape = (num_rows, self.num_columns)
+        self.cur_grid_image = tile_raster_images(numpy.array(self.list), self.image_size, grid_shape, tile_spacing=(5,5), output_pixel_vals=False)
+
+    def visualize(self):
+        pylab.imshow(self.cur_grid_image)
+        pylab.draw()
+
+    def dump_to_disk(self):
+        gi = Image.fromarray((self.cur_grid_image * 255).astype('uint8'), "L")
+        gi.save(self.to_dir + "/grid_" + str(self.cur_index) + ".png")
+        
+    def stop_and_wait(self):
+        # can't raw_input under gimp, so sleep)
+        print "New image generated, sleeping 5 secs"
+        time.sleep(5)
+
+    def flush(self):
+        self.list = []
+    
+    def get_parameters_names(self):
+        return []
+
+    def regenerate_parameters(self):
+        return []
+
+    def after_transform_callback(self, image):
+        self.transform_image(image)
+
+    def end_transform_callback(self, final_image):
+        self.visualize_stop_and_flush()
+
+    def transform_image(self, image):
+        sz = self.image_size
+        self.list.append(image.copy().reshape((sz[0] * sz[1])))
+