Mercurial > pylearn
changeset 1281:5ae77ac21609
extended cifar10.tile_rasterized_examples to work for patches too
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 15 Sep 2010 17:45:04 -0400 |
parents | 1eaa015e88c1 |
children | f36f59e53c28 |
files | pylearn/datasets/cifar10.py |
diffstat | 1 files changed, 7 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/cifar10.py Wed Sep 15 17:44:26 2010 -0400 +++ b/pylearn/datasets/cifar10.py Wed Sep 15 17:45:04 2010 -0400 @@ -89,18 +89,20 @@ def first_1k(dtype='uint8', ntrain=1000, nvalid=200, ntest=200): return cifar10(dtype, ntrain, nvalid, ntest) -def tile_rasterized_examples(X): +def tile_rasterized_examples(X, img_shape=(32,32)): """Returns an ndarray that is ready to be passed to `image_tiling.save_tiled_raster_images` This function is for the `x` matrices in the cifar dataset, or for the weight matrices (filters) used to multiply them. """ + ndim = img_shape[0]*img_shape[1] + assert ndim *3 == X.shape[1], (ndim, X.shape) X = X.astype('float32') - r = X[:,:1024] - g = X[:,1024:2048] - b = X[:,2048:] + r = X[:,:ndim] + g = X[:,ndim:ndim*2] + b = X[:,ndim*2:] from pylearn.io.image_tiling import tile_raster_images - rval = tile_raster_images((r,g,b,None), img_shape=(32,32)) + rval = tile_raster_images((r,g,b,None), img_shape=img_shape) return rval