view pylearn/dataset_ops/image_patches.py @ 1510:07b48bd449cd

Make a dataset op use the new path system.
author Frederic Bastien <nouiz@nouiz.org>
date Mon, 12 Sep 2011 11:47:00 -0400
parents 976539956475
children 9ffe5d6faee3
line wrap: on
line source

import os, numpy
import theano

from pylearn.datasets.image_patches import  (
        olshausen_field_1996_whitened_images,
        extract_random_patches)

from .protocol import TensorFnDataset # protocol.py __init__.py
from .memo import memo

import scipy.io
from pylearn.io import image_tiling
from pylearn.datasets.config import get_filepath_in_roots

@memo
def get_dataset(N,R,C,dtype,center,unitvar):
    """Return an (N, R*C) matrix of random patches from the Olshausen images.

    Patches are drawn by `extract_random_patches` from
    `olshausen_field_1996_whitened_images()` using a RandomState with a fixed
    seed, and each patch is flattened into one row.

    :param N: number of patches (rows of the result).
    :param R, C: patch shape; each row has R*C entries.
    :param dtype: dtype of the returned matrix.
    :param center: if True, subtract the per-column mean.
    :param unitvar: if True, divide each column by its standard deviation,
        floored at 1e-8 so constant columns do not cause a division by zero.
    :returns: numpy array of shape (N, R*C) and dtype `dtype`.
    """
    seed = 98234
    rng = numpy.random.RandomState(seed)
    img_stack = olshausen_field_1996_whitened_images()
    patch_stack = extract_random_patches(img_stack, N, R, C, rng)
    # astype() copies, so the in-place ops below never touch patch_stack.
    rval = patch_stack.astype(dtype).reshape((N, R * C))

    if center:
        rval -= rval.mean(axis=0)
    if unitvar:
        # numpy.max(a, 1e-8) would treat 1e-8 as an *axis* argument; the
        # intent is an element-wise floor on the per-column std.
        rval /= numpy.maximum(rval.std(axis=0), 1e-8)

    return rval

def image_patches(s_idx, dims,
        split='train', dtype=theano.config.floatX, rasterized=False,
        center=True,
        unitvar=True,
        fn=get_dataset):
    """Return a symbolic variable holding the image patches selected by `s_idx`.

    :param s_idx: symbolic index (scalar or vector). It is taken modulo N, so
        it wraps around the dataset.
    :param dims: tuple (N, R, C): number of patches and patch shape.
    :param split: only 'train' is supported.
    :param dtype: dtype of the returned data.
    :param rasterized: if True, each patch is a flat vector of length R*C.
        If False, each patch is reshaped to (R, C).
    :param center, unitvar: forwarded to `fn` (see `get_dataset`).
    :param fn: function called (through TensorFnDataset) with
        (N, R, C, dtype, center, unitvar) to build the (N, R*C) patch matrix.
    :raises NotImplementedError: if `split` is not 'train'.
    :raises ValueError: if the dataset op yields a variable that is neither
        1-D nor 2-D.
    """
    N, R, C = dims

    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')

    op = TensorFnDataset(dtype, bcast=(False,),
            fn=(fn, (N, R, C, dtype, center, unitvar)),
            single_shape=(R * C,))
    x = op(s_idx % N)

    # A scalar index yields one flat patch (1-D); a vector of indices yields
    # a batch of flat patches (2-D).
    if x.ndim == 1:
        image_shape = (R, C)
    elif x.ndim == 2:
        image_shape = (x.shape[0], R, C)
    else:
        raise ValueError('unexpected ndim %d from TensorFnDataset' % x.ndim)

    if not rasterized:
        # Inverse of the (N, R*C) flattening done when building the dataset.
        x = x.reshape(image_shape)

    return x



@memo
def ranzato_hinton_2010(path=None):
    """Load 'training_colorpatches_16x16_demo.mat' with scipy.io.loadmat.

    :param path: location of the .mat file. If None, the relative path
        image_patches/mcRBM/training_colorpatches_16x16_demo.mat is resolved
        with `get_filepath_in_roots`.
    :returns: the dict returned by `scipy.io.loadmat`.
    """
    mat_file = (path if path is not None
                else get_filepath_in_roots(os.path.join(
                    'image_patches', 'mcRBM',
                    'training_colorpatches_16x16_demo.mat')))
    return scipy.io.loadmat(mat_file)
def ranzato_hinton_2010_whitened_patches(path=None):
    """Return the 'whitendata' entry of the mcRBM demo data as float32.

    This is the PCA-projected data, 10240 x 105.

    :param path: forwarded to `ranzato_hinton_2010`.
    """
    whitened = ranzato_hinton_2010(path)['whitendata']
    return whitened.astype('float32')

def undo_pca_filters_of_ranzato_hinton_2010(X, path=None):
    """Project filters back through 'invpcatransf' and split into channels.

    :param X: matrix of filters, one per row, in the PCA basis.
    :param path: forwarded to `ranzato_hinton_2010`.
    :returns: tuple (R, G, B, None). R, G and B are the column slices
        [:256], [256:512] and [512:] of ``X . invpcatransf.T``.

    Return value can be passed to `image_tiling.tile_raster_images`.
    """
    inv_pca = ranzato_hinton_2010(path)['invpcatransf']
    pixels = numpy.dot(X, inv_pca.T)
    red = pixels[:, :256]
    green = pixels[:, 256:512]
    blue = pixels[:, 512:]
    return red, green, blue, None

def save_filters_of_ranzato_hinton_2010(X, fname, min_dynamic_range=1e-3, data_path=None):
    """Tile PCA-space filters `X` as 16x16 images and save them to `fname`.

    :param X: filters, one per row, in the PCA basis.
    :param fname: output file passed to `image_tiling.save_tiled_raster_images`.
    :param min_dynamic_range: forwarded to `image_tiling.tile_raster_images`.
    :param data_path: .mat path forwarded to the PCA-inversion helper.
    """
    channels = undo_pca_filters_of_ranzato_hinton_2010(X, path=data_path)
    tiled = image_tiling.tile_raster_images(channels,
            img_shape=(16, 16),
            min_dynamic_range=min_dynamic_range)
    image_tiling.save_tiled_raster_images(tiled, fname)

def ranzato_hinton_2010_op(s_idx,
        split='train',
        dtype=theano.config.floatX, rasterized=True,
        center=True,
        unitvar=True,
        fn=ranzato_hinton_2010_whitened_patches):
    """Return a symbolic variable with the whitened mcRBM patches at `s_idx`.

    :param s_idx: symbolic index (scalar or vector). It is taken modulo 10240.
    :param split: only 'train' is supported.
    :param dtype: only 'float32' is supported.
    :param rasterized: only True is supported (the data is PCA-sphered).
    :param center, unitvar: accepted for signature compatibility with
        `image_patches` but not used by this function.
    :param fn: data source handed to TensorFnDataset. Each example is a
        105-long vector.
    :raises NotImplementedError: for any unsupported split/rasterized/dtype.
    """
    n_examples = 10240

    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')
    if not rasterized:
        # The data only exists as PCA-sphered vectors, so there is no image
        # layout to reshape into.
        # TODO: add a param to enable/disable 'PCA', and if disabled, then
        # consider rasterizing or not
        raise NotImplementedError('only pca data is provided')
    if dtype != 'float32':
        raise NotImplementedError('dtype not float32')

    dataset_op = TensorFnDataset(dtype, bcast=(False,), fn=fn,
            single_shape=(105,))
    return dataset_op(s_idx % n_examples)