view pylearn/dataset_ops/image_patches.py @ 971:507159eea97e

image_patches - return centered data by default
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 23 Aug 2010 15:54:54 -0400
parents 06f21a964bd8
children 8ba8b08e0442
line wrap: on
line source

import os, numpy
import theano

from pylearn.datasets.image_patches import  (
        olshausen_field_1996_whitened_images,
        extract_random_patches)

from .protocol import TensorFnDataset # protocol.py __init__.py
from .memo import memo

@memo
def get_dataset(N,R,C,dtype,center,unitvar):
    """Build an (N, R*C) matrix of randomly extracted image patches.

    Patches are drawn from the Olshausen & Field (1996) whitened images with
    a fixed-seed RandomState, so the random stream handed to the extractor is
    the same on every call.  The function is wrapped by ``memo``
    (presumably a result cache keyed on the arguments -- confirm in memo.py).

    :param N: number of patches to extract
    :param R: number of rows in each patch
    :param C: number of columns in each patch
    :param dtype: dtype the patch stack is cast to
    :param center: if true, subtract the per-column mean (mean over patches)
    :param unitvar: if true, divide each column by its standard deviation
        (over patches), floored at 1e-8 to avoid dividing by zero

    :returns: array of shape (N, R*C), one rasterized patch per row
    """
    seed=98234
    rng = numpy.random.RandomState(seed)
    img_stack = olshausen_field_1996_whitened_images()
    patch_stack = extract_random_patches(img_stack, N,R,C,rng)
    # astype() returns a fresh copy, so the in-place ops below do not touch
    # the array returned by extract_random_patches.
    rval = patch_stack.astype(dtype).reshape((N,(R*C)))

    if center:
        rval -= rval.mean(axis=0)
    if unitvar:
        # numpy.maximum is the element-wise max; numpy.max(a, 1e-8) would
        # interpret 1e-8 as the `axis` argument rather than as a floor.
        rval /= numpy.maximum(rval.std(axis=0), 1e-8)

    return rval

def image_patches(s_idx, dims,
        split='train', dtype=theano.config.floatX, rasterized=False,
        center=True,
        unitvar=True):
    """Return the result of a TensorFnDataset op indexing random image patches.

    The underlying data comes from ``get_dataset(N, R, C, dtype, center,
    unitvar)``; each example is a rasterized patch of length R*C.

    :param s_idx: index (or indices) of the requested patches; taken modulo N
    :param dims: a triple ``(N, R, C)`` -- number of patches, patch rows,
        patch columns
    :param split: only ``'train'`` is supported
    :param dtype: dtype of the returned data
    :param rasterized: must be true; note that the default (False) is
        currently not implemented and raises
    :param center: forwarded to ``get_dataset``
    :param unitvar: forwarded to ``get_dataset``

    :returns: the output of the dataset op applied to ``s_idx % N``
        (expected to have ndim 1 or 2)

    :raises NotImplementedError: if ``split != 'train'`` or ``rasterized`` is
        false
    """
    N,R,C=dims

    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')

    if not rasterized:
        # Only the flat (R*C,) layout is available.  The previous body also
        # carried reshape branches for this case, but they were unreachable
        # behind this raise and hard-coded a 20x20 shape regardless of R, C.
        raise NotImplementedError('non-rasterized image patches are not supported; pass rasterized=True')

    op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,center,unitvar)), single_shape=(R*C,))
    x = op(s_idx%N)

    # Sanity check on the op output: a single example or a batch of examples.
    assert x.ndim in (1, 2), 'what happened?'

    return x