view pylearn/dataset_ops/image_patches.py @ 963:06f21a964bd8

datasets - added olshausen_field data loaders, and an image_patches dataset_op
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 20 Aug 2010 09:30:34 -0400
parents
children 507159eea97e
line wrap: on
line source

import os, numpy
import theano

from pylearn.datasets.image_patches import  (
        olshausen_field_1996_whitened_images,
        extract_random_patches)

from .protocol import TensorFnDataset # protocol.py __init__.py
from .memo import memo

@memo
def get_dataset(N,R,C,dtype):
    """Build the rasterized patch matrix served by `image_patches`.

    `N` patches of `R` rows by `C` columns are requested from
    `extract_random_patches`, using the Olshausen & Field (1996) whitened
    images as the source and a random generator with a fixed seed.

    :param N: number of patches to extract
    :param R: number of rows in each patch
    :param C: number of columns in each patch
    :param dtype: dtype the patches are cast to

    :returns: array of dtype `dtype` with shape (N, R*C), one flattened patch
        per row.

    NOTE(review): the `memo` decorator presumably caches the result per
    argument tuple -- confirm in memo.py.
    """
    # The seed is a constant, so the generator always starts in the same state.
    patch_rng = numpy.random.RandomState(98234)
    source_images = olshausen_field_1996_whitened_images()
    patches = extract_random_patches(source_images, N, R, C, patch_rng)

    # Cast first, then flatten each RxC patch into a single row.
    flat_shape = (N, R * C)
    typed_patches = patches.astype(dtype)
    return typed_patches.reshape(flat_shape)

def image_patches(s_idx, dims,
        split='train', dtype=theano.config.floatX, rasterized=False):
    """Return a dataset-op variable indexing a set of random image patches.

    The patches come from `get_dataset(N, R, C, dtype)`, which samples them
    from the Olshausen & Field (1996) whitened images.

    :param s_idx: index (or indices) of the patch(es) to fetch.  It is taken
        modulo N before the lookup.  Presumably a Theano integer scalar or
        vector -- confirm against TensorFnDataset.
    :param dims: a triple (N, R, C): number of patches, rows per patch and
        columns per patch.
    :param split: only 'train' is supported.
    :param dtype: dtype of the returned patches (defaults to
        theano.config.floatX).
    :param rasterized: only True is supported: every patch is returned
        flattened to length R*C.  Note that the default value (False) is
        currently not implemented and raises.

    :returns: the output of the TensorFnDataset op applied to `s_idx % N`,
        with ndim 1 or 2 (presumably a single patch or a batch of patches,
        respectively -- confirm against TensorFnDataset).

    :raises NotImplementedError: if `split` is not 'train', or if
        `rasterized` is False.
    """
    N,R,C=dims

    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')

    if not rasterized:
        # TODO: when this is implemented, reshape the op output to (R, C)
        # when x.ndim == 1 and to (x.shape[0], R, C) when x.ndim == 2.  The
        # patch shape must come from `dims`; it is not always 20x20.
        raise NotImplementedError('rasterized=False is not supported yet')

    op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,)), single_shape=(R*C,))
    # Wrap the index so that any integer addresses one of the N patches.
    x = op(s_idx%N)

    # Sanity check on the op output: anything other than ndim 1 or 2 is
    # unexpected.
    assert x.ndim in (1, 2), 'what happened?'

    return x