# HG changeset patch # User James Bergstra # Date 1284587104 14400 # Node ID 5ae77ac21609e1ca713d8c96e4950fb0684b8cdd # Parent 1eaa015e88c188a2302a93fa39615151744cb84c extended cifar10.tile_rasterized_examples to work for patches too diff -r 1eaa015e88c1 -r 5ae77ac21609 pylearn/datasets/cifar10.py --- a/pylearn/datasets/cifar10.py Wed Sep 15 17:44:26 2010 -0400 +++ b/pylearn/datasets/cifar10.py Wed Sep 15 17:45:04 2010 -0400 @@ -89,18 +89,20 @@ def first_1k(dtype='uint8', ntrain=1000, nvalid=200, ntest=200): return cifar10(dtype, ntrain, nvalid, ntest) -def tile_rasterized_examples(X): +def tile_rasterized_examples(X, img_shape=(32,32)): """Returns an ndarray that is ready to be passed to `image_tiling.save_tiled_raster_images` This function is for the `x` matrices in the cifar dataset, or for the weight matrices (filters) used to multiply them. """ + ndim = img_shape[0]*img_shape[1] + assert ndim *3 == X.shape[1], (ndim, X.shape) X = X.astype('float32') - r = X[:,:1024] - g = X[:,1024:2048] - b = X[:,2048:] + r = X[:,:ndim] + g = X[:,ndim:ndim*2] + b = X[:,ndim*2:] from pylearn.io.image_tiling import tile_raster_images - rval = tile_raster_images((r,g,b,None), img_shape=(32,32)) + rval = tile_raster_images((r,g,b,None), img_shape=img_shape) return rval