changeset 1281:5ae77ac21609

extended cifar10.tile_rasterized_examples to work for patches too
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 15 Sep 2010 17:45:04 -0400
parents 1eaa015e88c1
children f36f59e53c28
files pylearn/datasets/cifar10.py
diffstat 1 files changed, 7 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/cifar10.py	Wed Sep 15 17:44:26 2010 -0400
+++ b/pylearn/datasets/cifar10.py	Wed Sep 15 17:45:04 2010 -0400
@@ -89,18 +89,20 @@
 def first_1k(dtype='uint8', ntrain=1000, nvalid=200, ntest=200):
     return cifar10(dtype, ntrain, nvalid, ntest)
 
-def tile_rasterized_examples(X):
+def tile_rasterized_examples(X, img_shape=(32,32)):
     """Returns an ndarray that is ready to be passed to `image_tiling.save_tiled_raster_images`
 
     This function is for the `x` matrices in the cifar dataset, or for the weight matrices
     (filters) used to multiply them.
     """
+    ndim = img_shape[0]*img_shape[1]
+    assert ndim *3 == X.shape[1], (ndim, X.shape)
     X = X.astype('float32')
-    r = X[:,:1024]
-    g = X[:,1024:2048]
-    b = X[:,2048:]
+    r = X[:,:ndim]
+    g = X[:,ndim:ndim*2]
+    b = X[:,ndim*2:]
     from pylearn.io.image_tiling import tile_raster_images
-    rval = tile_raster_images((r,g,b,None), img_shape=(32,32))
+    rval = tile_raster_images((r,g,b,None), img_shape=img_shape)
     return rval