Mercurial > pylearn
view pylearn/dataset_ops/image_patches.py @ 971:507159eea97e
image_patches - return centered data by default
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 23 Aug 2010 15:54:54 -0400 |
parents | 06f21a964bd8 |
children | 8ba8b08e0442 |
line wrap: on
line source
# Dataset op serving random image patches.  The patches are cut by
# extract_random_patches() out of olshausen_field_1996_whitened_images().
import os
import numpy

import theano

from pylearn.datasets.image_patches import (
        olshausen_field_1996_whitened_images,
        extract_random_patches)

from .protocol import TensorFnDataset  # protocol.py __init__.py
from .memo import memo


@memo
def get_dataset(N, R, C, dtype, center, unitvar):
    """Return an (N, R*C) array of rasterized random image patches.

    The patches come from `extract_random_patches`, driven by a
    `numpy.random.RandomState` that is always seeded with the same constant.
    The function is wrapped in `memo` (presumably so repeated calls with the
    same arguments reuse the result -- see the `.memo` module to confirm).

    :param N: number of rows of the returned matrix (one row per patch).
    :param R: forwarded to `extract_random_patches`; R*C is the row length
        (presumably R is the patch height in pixels).
    :param C: forwarded to `extract_random_patches`; R*C is the row length
        (presumably C is the patch width in pixels).
    :param dtype: dtype the patch stack is cast to before any normalization.
    :param center: if true, subtract from every column its mean over the N
        rows.
    :param unitvar: if true, divide every column by its standard deviation
        over the N rows; the divisor is floored at 1e-8.
    :returns: array of shape (N, R*C) and type `dtype`.
    """
    seed = 98234
    rng = numpy.random.RandomState(seed)
    img_stack = olshausen_field_1996_whitened_images()
    patch_stack = extract_random_patches(img_stack, N, R, C, rng)
    # astype() hands back a fresh array, so the in-place updates below do not
    # touch the array returned by extract_random_patches().
    rval = patch_stack.astype(dtype).reshape((N, (R * C)))

    if center:
        rval -= rval.mean(axis=0)
    if unitvar:
        # Elementwise floor on the per-column std so constant columns do not
        # divide by zero.  This must be numpy.maximum: numpy.max would take
        # 1e-8 as its `axis` argument and raise.
        rval /= numpy.maximum(rval.std(axis=0), 1e-8)
    return rval


def image_patches(s_idx, dims, split='train', dtype=theano.config.floatX,
        rasterized=False, center=True, unitvar=True):
    """Return a symbolic variable for the patches selected by `s_idx`.

    :param s_idx: index (or indices) of the wanted patches; it is taken
        modulo N before being handed to the dataset op.
    :param dims: a triple (N, R, C) forwarded to `get_dataset`.
    :param split: only 'train' is supported.
    :param dtype: dtype of the dataset (defaults to `theano.config.floatX`).
    :param rasterized: must currently be true -- each patch is then a vector
        of length R*C.  The un-rasterized form is not enabled.
    :param center: forwarded to `get_dataset`.
    :param unitvar: forwarded to `get_dataset`.
    :returns: the output of the `TensorFnDataset` op applied to `s_idx % N`.
    :raises NotImplementedError: if `split` is not 'train', or if
        `rasterized` is false (which is the default).
    :raises AssertionError: if the op output has neither 1 nor 2 dimensions.
    """
    N, R, C = dims

    if split != 'train':
        raise NotImplementedError('train/test/valid splits for randomly sampled image patches?')

    if not rasterized:
        raise NotImplementedError()

    op = TensorFnDataset(dtype, bcast=(False,),
            fn=(get_dataset, (N, R, C, dtype, center, unitvar)),
            single_shape=(R * C,))
    x = op(s_idx % N)

    # NOTE: the `not rasterized` branches below cannot run while the guard
    # above is in place.  They reshape to the requested (R, C) rather than a
    # fixed patch size, so they are correct if the guard is ever lifted.
    if x.ndim == 1:
        if not rasterized:
            x = x.reshape((R, C))
    elif x.ndim == 2:
        if not rasterized:
            x = x.reshape((x.shape[0], R, C))
    else:
        # Explicit raise instead of `assert False`, so the check survives -O.
        raise AssertionError('what happened?')

    return x