changeset 971:507159eea97e

image_patches - return centered data by default
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 23 Aug 2010 15:54:54 -0400
parents 930b92f88e61
children 0b392d1401c5
files pylearn/dataset_ops/image_patches.py
diffstat 1 files changed, 13 insertions(+), 4 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/image_patches.py	Mon Aug 23 15:54:19 2010 -0400
+++ b/pylearn/dataset_ops/image_patches.py	Mon Aug 23 15:54:54 2010 -0400
@@ -9,15 +9,24 @@
 from .memo import memo
 
 @memo
-def get_dataset(N,R,C,dtype):
+def get_dataset(N,R,C,dtype,center,unitvar):  # center/unitvar: optional per-column normalisation flags
     seed=98234
     rng = numpy.random.RandomState(seed)
     img_stack = olshausen_field_1996_whitened_images()
     patch_stack = extract_random_patches(img_stack, N,R,C,rng)
-    return patch_stack.astype(dtype).reshape((N,(R*C)))
+    rval = patch_stack.astype(dtype).reshape((N,(R*C)))
+    # Statistics below are taken over axis 0, i.e. across the N patches, one value per pixel column.
+    if center:
+        rval -= rval.mean(axis=0)
+    if unitvar:
+        rval /= numpy.maximum(rval.std(axis=0),1e-8)  # elementwise floor; numpy.max would read 1e-8 as `axis`
+    # rval is the (N, R*C) rasterized patch matrix, normalised in place as requested.
+    return rval
 
 def image_patches(s_idx, dims,
-        split='train', dtype=theano.config.floatX, rasterized=False):
+        split='train', dtype=theano.config.floatX, rasterized=False,
+        center=True,
+        unitvar=True):
     N,R,C=dims
 
     if split != 'train':
@@ -26,7 +35,7 @@
     if not rasterized:
         raise NotImplementedError()
 
-    op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,)), single_shape=(R*C,))
+    op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,center,unitvar)), single_shape=(R*C,))
     x = op(s_idx%N)
     if x.ndim == 1:
         if not rasterized: