changeset 963:06f21a964bd8

datasets - added olshausen_field data loaders, and an image_patches dataset_op
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 20 Aug 2010 09:30:34 -0400
parents 0fee974dca1d
children 6a778bca0dec
files pylearn/dataset_ops/image_patches.py pylearn/datasets/image_patches.py
diffstat 2 files changed, 112 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dataset_ops/image_patches.py	Fri Aug 20 09:30:34 2010 -0400
@@ -0,0 +1,41 @@
+import os, numpy
+import theano
+
+from pylearn.datasets.image_patches import  (
+        olshausen_field_1996_whitened_images,
+        extract_random_patches)
+
+from .protocol import TensorFnDataset # protocol.py __init__.py
+from .memo import memo
+
+@memo
+def get_dataset(N,R,C,dtype):
+    """Return N random RxC patches of the whitened Olshausen & Field images.
+
+    :param N: number of patches to extract
+    :param R: number of rows in each patch
+    :param C: number of cols in each patch
+    :param dtype: dtype of the returned array
+    :returns: ndarray of shape (N, R*C); each row is one patch flattened in row-major order
+
+    The random state is seeded with a fixed value, so the same arguments always produce the
+    same patches.  (The `memo` decorator presumably also caches the result per argument
+    tuple -- see .memo to confirm.)
+    """
+    seed = 98234
+    rng = numpy.random.RandomState(seed)
+    # The loader returns the stack as (rows, cols, n_images) == (512, 512, 10), whereas
+    # extract_random_patches expects the image index on the first axis.  Without the
+    # transpose every "image" would be a (512, 10) slice mixing columns with image index.
+    img_stack = olshausen_field_1996_whitened_images().transpose((2, 0, 1))
+    patch_stack = extract_random_patches(img_stack, N, R, C, rng)
+    return patch_stack.astype(dtype).reshape((N, R * C))
+
+def image_patches(s_idx, dims,
+        split='train', dtype=theano.config.floatX, rasterized=False):
+    """Return a symbolic variable holding the image patch(es) selected by `s_idx`.
+
+    :param s_idx: index (or indices) of the patches to fetch; it is wrapped modulo N before
+        being handed to the dataset op (presumably a Theano integer scalar or vector --
+        TODO confirm against TensorFnDataset)
+    :param dims: a (N, R, C) triple: number of patches, rows per patch, cols per patch
+    :param split: only 'train' is supported
+    :param dtype: dtype of the patches (defaults to theano.config.floatX)
+    :param rasterized: must currently be True; each patch is then a vector of length R*C
+    :returns: the output of the TensorFnDataset op applied to `s_idx % N`
+    :raises NotImplementedError: if `split` is not 'train', or if `rasterized` is false
+        (note that this includes the default value of `rasterized`)
+    """
+    N,R,C=dims
+
+    if split != 'train':
+        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')
+
+    if not rasterized:
+        raise NotImplementedError()
+
+    op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,)), single_shape=(R*C,))
+    x = op(s_idx%N)
+    # The two `not rasterized` branches below are unreachable while the guard above is in
+    # place.  They reshape to the requested (R, C) rather than a hard-coded (20, 20), so that
+    # lifting the guard gives patches of the right shape for any `dims`.
+    if x.ndim == 1:
+        if not rasterized:
+            x = x.reshape((R,C))
+    elif x.ndim == 2:
+        if not rasterized:
+            x = x.reshape((x.shape[0], R,C))
+    else:
+        assert False, 'what happened?'
+
+    return x
+
--- a/pylearn/datasets/image_patches.py	Fri Aug 20 09:29:30 2010 -0400
+++ b/pylearn/datasets/image_patches.py	Fri Aug 20 09:30:34 2010 -0400
@@ -5,6 +5,8 @@
 import os
 import numpy
 
+import scipy.io  #for loadmat
+
 from .config import data_root
 from .dataset import Dataset
 
@@ -14,6 +16,13 @@
          '12by12_whiten_01': ('natural_images_patches_whiten_12_by_12_0_1.amat',(12,12))}
 
 def load_dataset(ntrain=70000, nvalid=15000, ntest=15000, variant='20by20_whiten_01'):
+    # This implementation loads files which are way too big (storing bytes in double), and
+    # stored in ascii!??, and which appear to be un-documented variants on the original
+    # raw/whitened
+    # data.
+    # I'd like to deprecate this function.
+    # -JB Aug 2010
+    print >> sys.stderr, "WARNING: pylearn.datasets.image_patches.load_dataset is badly documented and does some weird stuff... could someone who uses this function do something about it?"
     
     ndata = 100000
 
@@ -40,3 +49,65 @@
     rval.img_shape = paths[variant][1]
 
     return rval
+
+
+#TODO: a little function to render a tiling of example images to an image using PIL
+
+#TODO: a pca_load_dataset function that loads the data, as projected onto principal components
+
+def olshausen_field_1996_whitened_images(path=None):
+    """Load the 10 whitened Olshausen & Field (1996) images as one float32 array.
+
+    :param path: location of the IMAGES.mat file; when None, the file is looked up under
+        data_root() in image_patches/olshausen/original/
+    :returns: a float32 ndarray of shape (512,512,10); the image index is the LAST axis
+
+    The whitening was done by the paper authors (band-pass whitening, I think).
+    """
+    mat_path = path if path is not None else os.path.join(
+            data_root(), 'image_patches', 'olshausen', 'original', 'IMAGES.mat')
+    whitened = scipy.io.loadmat(mat_path)['IMAGES']
+    # sanity check: the distributed file holds ten 512x512 images
+    assert whitened.shape == (512,512,10)
+    return whitened.astype('float32')
+
+def olshausen_field_1996_raw_images(path=None):
+    """Load the 10 raw Olshausen & Field (1996) images as one float32 array.
+
+    :param path: location of the IMAGES_RAW.mat file; when None, the file is looked up under
+        data_root() in image_patches/olshausen/original/
+    :returns: a float32 ndarray of shape (512,512,10); the image index is the LAST axis
+    """
+    mat_path = path if path is not None else os.path.join(
+            data_root(), 'image_patches', 'olshausen', 'original', 'IMAGES_RAW.mat')
+    raw = scipy.io.loadmat(mat_path)['IMAGES_RAW']
+    # sanity check: the distributed file holds ten 512x512 images
+    assert raw.shape == (512,512,10)
+    return raw.astype('float32')
+
+def extract_random_patches(img_stack, N, R,C, rng):
+    """Return N randomly located RxC sub-images drawn from the img_stack
+
+    :param img_stack: a 3D ndarray (n_images, rows, cols) or a list of 2D images.
+    :param N: number of patches to extract
+    :param R: number of rows in patch
+    :param C: number of cols in patch
+    :param rng: numpy RandomState
+    :returns: an ndarray of shape (N, R, C) with the dtype of img_stack[0]
+    :raises ValueError: if an image selected for sampling is smaller than (R, C)
+
+    The source image of each patch is chosen from the img_stack with uniform probability, and
+    the patch position is then chosen with uniform probability among the positions that fit
+    inside that image.  Patches from a larger image in the stack therefore would be sampled
+    less frequently than patches from a smaller image in the stack.
+
+    To use ZCA whitening, extract patches from the raw data, and pass it to
+    preprocessing.pca.zca_whitening.
+    """
+    rval = numpy.empty((N,R,C), dtype=img_stack[0].dtype)
+
+    # Draw the source image of every patch up front, in a single call.
+    img_idxlist = rng.randint(len(img_stack), size=N)
+
+    for n, img_idx in enumerate(img_idxlist):
+        img_n = img_stack[img_idx]
+        n_rows, n_cols = img_n.shape[0], img_n.shape[1]
+        if n_rows < R or n_cols < C:
+            # Fail with an explicit message instead of the opaque error that randint
+            # raises for a non-positive upper bound.
+            raise ValueError('cannot extract a %ix%i patch from image %i of shape %ix%i'
+                    % (R, C, img_idx, n_rows, n_cols))
+        # +1 because the last valid offset is (image size - patch size), inclusive
+        offset_R = rng.randint(n_rows - R + 1)
+        offset_C = rng.randint(n_cols - C + 1)
+        rval[n] = img_n[offset_R:offset_R+R, offset_C:offset_C+C]
+
+    return rval
+