Mercurial > pylearn
view pylearn/dataset_ops/image_patches.py @ 963:06f21a964bd8
datasets - added olshausen_field data loaders, and an image_patches
dataset_op
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 20 Aug 2010 09:30:34 -0400 |
parents | |
children | 507159eea97e |
line wrap: on
line source
import os
import numpy
import theano

from pylearn.datasets.image_patches import (
    olshausen_field_1996_whitened_images,
    extract_random_patches)

from .protocol import TensorFnDataset  # protocol.py __init__.py
from .memo import memo


@memo
def get_dataset(N, R, C, dtype):
    """Return N random R-by-C image patches, one flattened patch per row.

    The patches are drawn by `extract_random_patches` from the image stack
    returned by `olshausen_field_1996_whitened_images`, using a
    `numpy.random.RandomState` built from a fixed seed.

    :param N: number of patches to extract
    :param R: number of rows in each patch
    :param C: number of columns in each patch
    :param dtype: dtype the patch stack is cast to before it is returned

    :returns: array of shape (N, R*C) with the requested dtype

    NOTE(review): `memo` presumably caches the result per argument tuple, and
    the fixed seed presumably makes repeated calls return the same patches --
    confirm against `.memo` and `extract_random_patches`.
    """
    seed = 98234
    rng = numpy.random.RandomState(seed)
    img_stack = olshausen_field_1996_whitened_images()
    patch_stack = extract_random_patches(img_stack, N, R, C, rng)
    return patch_stack.astype(dtype).reshape((N, (R * C)))


def image_patches(s_idx, dims,
        split='train', dtype=theano.config.floatX, rasterized=False):
    """Return a symbolic variable for the image patch(es) selected by `s_idx`.

    :param s_idx: index (or indices) of the wanted patches; it is taken
        modulo N before being handed to the dataset op, so out-of-range
        values wrap around.  Presumably a Theano integer scalar or vector --
        confirm against `TensorFnDataset`.
    :param dims: a triple (N, R, C): number of patches, patch rows, patch
        columns
    :param split: only 'train' is supported
    :param dtype: dtype of the returned patches
    :param rasterized: True to get each patch flattened to length R*C.
        False (the default!) is not supported yet.

    :returns: the output of the `TensorFnDataset` op applied to `s_idx % N`

    :raises NotImplementedError: if `split` is not 'train', or if
        `rasterized` is False
    :raises AssertionError: if the op output is neither 1- nor 2-dimensional
    """
    N, R, C = dims

    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')

    if not rasterized:
        # Deliberately disabled by the original author; the reshape code
        # further down documents the intended behaviour once this is lifted.
        raise NotImplementedError(
                'image_patches currently supports only rasterized=True')

    op = TensorFnDataset(dtype,
            bcast=(False,),
            fn=(get_dataset, (N, R, C, dtype,)),
            single_shape=(R * C,))
    x = op(s_idx % N)

    # Explicit raise instead of `assert False`: asserts are stripped under
    # `python -O`, which would let an unexpected rank slip through silently.
    if x.ndim not in (1, 2):
        raise AssertionError('what happened?')

    if not rasterized:
        # Unreachable while the guard above rejects rasterized=False.
        # Un-rasterize with the requested patch shape (R, C); this used to be
        # hard-coded to (20, 20), which is wrong for any other patch size.
        if x.ndim == 1:
            x = x.reshape((R, C))
        else:
            x = x.reshape((x.shape[0], R, C))

    return x