changeset 1352:cc3e3e596500

dataset_ops/tinyimages - added an img_shape optional flag to save_filters fns
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 03 Nov 2010 12:49:12 -0400
parents 6402b3309ece
children 2024c5618466
files pylearn/dataset_ops/tinyimages.py
diffstat 1 files changed, 9 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/tinyimages.py	Thu Oct 28 16:15:47 2010 -0400
+++ b/pylearn/dataset_ops/tinyimages.py	Wed Nov 03 12:49:12 2010 -0400
@@ -67,7 +67,8 @@
     X[:,:,:,2]=centre(X[:,:,:,2])
     return X
 
-def save_filters(X, fname, min_dynamic_range=1e-8, data_path=None):
+def save_filters_orig(X, fname, min_dynamic_range=1e-8, data_path=None, img_shape=(8,8),
+        tile_shape=None):
     """
     Save filters X (encoded as whitened images) in the original image space.
     """
@@ -76,8 +77,9 @@
 
     _img = image_tiling.tile_raster_images(
             pylearn.preprocessing.pca.pca_whiten_inverse(pca, X),
-            img_shape=(8,8),
-            min_dynamic_range=1e-6)
+            img_shape=img_shape,
+            min_dynamic_range=1e-6,
+            tile_shape=tile_shape)
     image_tiling.save_tiled_raster_images(_img, fname)
 
 def extract_patches(n_imgs=1000*100, n_patches_per_image=10, patch_shape=(8,8), rng=numpy.random.RandomState(234)):
@@ -159,7 +161,7 @@
         i += b
     return rval
 
-def main(n_imgs=1000, n_patches_per_image=10, max_components=128, seed=234):
+def main(n_imgs=1000, n_patches_per_image=10, max_components=128, seed=234, patch_shape=(8,8)):
     if 0: #do this to render the dataset to the screen
         sys.exit(glviewer())
 
@@ -177,7 +179,7 @@
     else:
         print 'extracting raw patches'
         raw_patches = extract_patches(rng=rng, n_imgs=n_imgs,
-                n_patches_per_image=n_patches_per_image)
+                n_patches_per_image=n_patches_per_image, patch_shape=patch_shape)
         rng.shuffle(raw_patches)
         print 'saving raw patches to', _raw_patch_file
         numpy.save(open(_raw_patch_file, 'wb'), raw_patches)
@@ -247,14 +249,14 @@
     return x
 
 
-def save_filters(X, fname, tile_shape=None):
+def save_filters(X, fname, tile_shape=None, img_shape=(8,8)):
     dct = load_pca_dct()
     eigs = dct['eig_vals'], dct['eig_vecs']
     mean = dct['mean']
     rasterized = pylearn.preprocessing.pca.pca_whiten_inverse(eigs, X)+mean
     _img = image_tiling.tile_raster_images(
             (rasterized[:,::3], rasterized[:,1::3], rasterized[:,2::3], None),
-            img_shape=(8,8),
+            img_shape=img_shape,
             min_dynamic_range=1e-6,
             tile_shape=tile_shape)
     image_tiling.save_tiled_raster_images(_img, fname)