changeset 1456:272879b84d30

added io/image_tiling:tile_slices_to_image which is a better version of tile_raster_*
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 04 Apr 2011 19:03:48 -0400
parents 93e5ce7ccd6d
children 9d941cd77479
files pylearn/io/image_tiling.py
diffstat 1 files changed, 53 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/io/image_tiling.py	Mon Apr 04 19:01:12 2011 -0400
+++ b/pylearn/io/image_tiling.py	Mon Apr 04 19:03:48 2011 -0400
@@ -40,9 +40,13 @@
     :rtype: a 2-d array with same dtype as X.
 
     """
+    # This is premature when tile_slices_to_image is not documented at all yet,
+    # but ultimately true:
+    #print >> sys.stderr, "WARN: tile_raster_images sucks, use tile_slices_to_image"
     if len(img_shape)==3 and img_shape[2]==3:
         # make this save an rgb image
-
+        if scale_rows_to_unit_interval:
+            print >> sys.stderr, "WARN: tile_raster_images' scaling routine messes up colour - try tile_slices_to_image"
         return tile_raster_images(
                 (X[:,0::3], X[:,1::3], X[:,2::3], None),
                 img_shape=img_shape[:2],
@@ -69,6 +73,8 @@
         in zip(img_shape, tile_shape, tile_spacing)]
 
     if isinstance(X, tuple):
+        if scale_rows_to_unit_interval:
+            raise NotImplementedError()
         assert len(X) == 4
         if output_pixel_vals:
             out_array = numpy.zeros((out_shape[0], out_shape[1], 4), dtype='uint8')
@@ -141,3 +147,49 @@
     img.save(filename)
     return img
 
+def tile_slices_to_image_uint8(X, tile_shape=None):
+    if str(X.dtype) != 'uint8':
+        raise TypeError(X)
+    if tile_shape is None:
+        #how many tile rows and cols
+        (TR, TC) = most_square_shape(X.shape[0])
+    H, W = X.shape[1], X.shape[2]
+
+    Hs = H+1 #spacing between tiles
+    Ws = W+1 #spacing between tiles
+
+    trows, tcols= most_square_shape(X.shape[0])
+    outrows = trows * Hs - 1
+    outcols = tcols * Ws - 1
+    out = numpy.zeros((outrows, outcols,3), dtype='uint8')
+    tr_stride= 1+X.shape[1]
+    for tr in range(trows):
+        for tc in range(tcols):
+            Xrc = X[tr*tcols+tc]
+            if Xrc.ndim==2: # if no color channel make it broadcast
+                Xrc=Xrc[:,:,None]
+            #print Xrc.shape
+            #print out[tr*Hs:tr*Hs+H,tc*Ws:tc*Ws+W].shape
+            out[tr*Hs:tr*Hs+H,tc*Ws:tc*Ws+W] = Xrc
+    img = Image.fromarray(out, 'RGB')
+    return img
+
+def tile_slices_to_image(X,
+        tile_shape=None,
+        scale_each=True,
+        min_dynamic_range=1e-4):
+    #always returns an RGB image
+    def scale_0_255(x):
+        xmin = x.min()
+        xmax = x.max()
+        return numpy.asarray(
+                255 * (x - xmin) / max(xmax - xmin, min_dynamic_range),
+                dtype='uint8')
+
+    if scale_each:
+        x = X.copy()
+        for i, Xi in enumerate(X):
+            X[i] = scale_0_255(Xi)
+    else:
+        X = scale_0_255(X)
+    return tile_slices_to_image_uint8(X, tile_shape=tile_shape)