Mercurial > pylearn
changeset 971:507159eea97e
image_patches - return centered data by default
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 23 Aug 2010 15:54:54 -0400 |
parents | 930b92f88e61 |
children | 0b392d1401c5 |
files | pylearn/dataset_ops/image_patches.py |
diffstat | 1 files changed, 13 insertions(+), 4 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/dataset_ops/image_patches.py Mon Aug 23 15:54:19 2010 -0400 +++ b/pylearn/dataset_ops/image_patches.py Mon Aug 23 15:54:54 2010 -0400 @@ -9,15 +9,24 @@ from .memo import memo @memo -def get_dataset(N,R,C,dtype): +def get_dataset(N,R,C,dtype,center,unitvar): seed=98234 rng = numpy.random.RandomState(seed) img_stack = olshausen_field_1996_whitened_images() patch_stack = extract_random_patches(img_stack, N,R,C,rng) - return patch_stack.astype(dtype).reshape((N,(R*C))) + rval = patch_stack.astype(dtype).reshape((N,(R*C))) + + if center: + rval -= rval.mean(axis=0) + if unitvar: + rval /= numpy.max(rval.std(axis=0),1e-8) + + return rval def image_patches(s_idx, dims, - split='train', dtype=theano.config.floatX, rasterized=False): + split='train', dtype=theano.config.floatX, rasterized=False, + center=True, + unitvar=True): N,R,C=dims if split != 'train': @@ -26,7 +35,7 @@ if not rasterized: raise NotImplementedError() - op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,)), single_shape=(R*C,)) + op = TensorFnDataset(dtype, bcast=(False,), fn=(get_dataset, (N,R,C,dtype,center,unitvar)), single_shape=(R*C,)) x = op(s_idx%N) if x.ndim == 1: if not rasterized: