changeset 1521:6397233f3ccd

autopep8
author Frederic Bastien <nouiz@nouiz.org>
date Wed, 31 Oct 2012 16:12:57 -0400
parents 61134776e33c
children 5972fab3cfd2
files pylearn/dataset_ops/image_patches.py
diffstat 1 files changed, 27 insertions(+), 18 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/image_patches.py	Wed Oct 31 14:36:55 2012 -0400
+++ b/pylearn/dataset_ops/image_patches.py	Wed Oct 31 16:12:57 2012 -0400
@@ -1,60 +1,64 @@
-import os, numpy
+import os
+import numpy
 import theano
 
 from pylearn.datasets.image_patches import  (
         olshausen_field_1996_whitened_images,
         extract_random_patches)
 
-from .protocol import TensorFnDataset # protocol.py __init__.py
+from .protocol import TensorFnDataset  # protocol.py __init__.py
 from .memo import memo
 
 import scipy.io
 from pylearn.io import image_tiling
 from pylearn.datasets.config import get_filepath_in_roots
 
+
 @memo
-def get_dataset(N,R,C,dtype,center,unitvar):
-    seed=98234
+def get_dataset(N, R, C, dtype, center, unitvar):
+    seed = 98234
     rng = numpy.random.RandomState(seed)
     img_stack = olshausen_field_1996_whitened_images()
-    patch_stack = extract_random_patches(img_stack, N,R,C,rng)
-    rval = patch_stack.astype(dtype).reshape((N,(R*C)))
+    patch_stack = extract_random_patches(img_stack, N, R, C, rng)
+    rval = patch_stack.astype(dtype).reshape((N, (R * C)))
 
     if center:
         rval -= rval.mean(axis=0)
     if unitvar:
-        rval /= numpy.max(rval.std(axis=0),1e-8)
+        rval /= numpy.max(rval.std(axis=0), 1e-8)
 
     return rval
 
+
 def image_patches(s_idx, dims,
         split='train', dtype=theano.config.floatX, rasterized=False,
         center=True,
         unitvar=True,
         fn=get_dataset):
-    N,R,C=dims
+    N, R, C = dims
 
     if split != 'train':
-        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')
+        raise NotImplementedError(
+            'train/test/valid splits for randomly sampled image patches?')
 
     if not rasterized:
         raise NotImplementedError()
 
-    op = TensorFnDataset(dtype, bcast=(False,), fn=(fn, (N,R,C,dtype,center,unitvar)), single_shape=(R*C,))
-    x = op(s_idx%N)
+    op = TensorFnDataset(dtype, bcast=(False, ), fn=(fn, (N, R, C, dtype,
+        center, unitvar)), single_shape=(R * C, ))
+    x = op(s_idx % N)
     if x.ndim == 1:
         if not rasterized:
-            x = x.reshape((20,20))
+            x = x.reshape((20, 20))
     elif x.ndim == 2:
         if not rasterized:
-            x = x.reshape((x.shape[0], 20,20))
+            x = x.reshape((x.shape[0], 20, 20))
     else:
         assert False, 'what happened?'
 
     return x
 
 
-
 @memo
 def ranzato_hinton_2010(path=None):
     if path is None:
@@ -62,12 +66,15 @@
                 'training_colorpatches_16x16_demo.mat'))
     dct = scipy.io.loadmat(path)
     return dct
+
+
 def ranzato_hinton_2010_whitened_patches(path=None):
     """Return the pca of the data, which is 10240 x 105
     """
     dct = ranzato_hinton_2010(path)
     return dct['whitendata'].astype('float32')
 
+
 def undo_pca_filters_of_ranzato_hinton_2010(X, path=None):
     """Return tuple (R,G,B,None) of matrices for matrix `X` of filters (one per row)
 
@@ -75,15 +82,16 @@
     """
     dct = ranzato_hinton_2010(path)
     X = numpy.dot(X, dct['invpcatransf'].T)
-    return (X[:,:256], X[:,256:512], X[:,512:], None)
+    return (X[:, :256], X[:, 256:512], X[:, 512:], None)
 
 def save_filters_of_ranzato_hinton_2010(X, fname, min_dynamic_range=1e-3, data_path=None):
     _img = image_tiling.tile_raster_images(
             undo_pca_filters_of_ranzato_hinton_2010(X, path=data_path),
-            img_shape=(16,16),
+            img_shape=(16, 16),
             min_dynamic_range=min_dynamic_range)
     image_tiling.save_tiled_raster_images(_img, fname)
 
+
 def ranzato_hinton_2010_op(s_idx,
         split='train',
         dtype=theano.config.floatX, rasterized=True,
@@ -93,7 +101,8 @@
     N = 10240
 
     if split != 'train':
-        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')
+        raise NotImplementedError(
+            'train/test/valid splits for randomly sampled image patches?')
 
     if not rasterized:
         # the data is provided as PCA-sphered, so rasterizing does not make sense
@@ -108,5 +117,5 @@
             bcast=(False,),
             fn=fn,
             single_shape=(105,))
-    x = op(s_idx%N)
+    x = op(s_idx % N)
     return x