Mercurial > pylearn
changeset 878:faa9f880d0d2
fixes to dataset_ops.shapeset1
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 16 Nov 2009 13:43:53 -0500 |
parents | 533b9a0a14f3 |
children | 0f33afbf517e |
files | pylearn/dataset_ops/shapeset1.py |
diffstat | 1 files changed, 13 insertions(+), 2 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/dataset_ops/shapeset1.py Mon Nov 16 13:37:52 2009 -0500 +++ b/pylearn/dataset_ops/shapeset1.py Mon Nov 16 13:43:53 2009 -0500 @@ -1,5 +1,9 @@ """A Theano Op to load/access Shapeset1 """ +import theano, numpy + +from .protocol import TensorFnDataset + from ..datasets.shapeset1 import head_train, head_valid, head_test #make global functions so Op can be pickled. @@ -7,6 +11,8 @@ def train_img(dtype): if dtype not in _train_cache: x, y = head_train() + if dtype.startswith('uint') or dtype.startswith('int'): + x *= 255 _train_cache[dtype] = numpy.asarray(x, dtype=dtype) _train_cache['lbl'] = y return _train_cache[dtype] @@ -21,6 +27,8 @@ def valid_img(dtype): if dtype not in _valid_cache: x, y = head_valid() + if dtype.startswith('uint') or dtype.startswith('int'): + x *= 255 _valid_cache[dtype] = numpy.asarray(x, dtype=dtype) _valid_cache['lbl'] = y return _valid_cache[dtype] @@ -35,6 +43,8 @@ def test_img(dtype): if dtype not in _test_cache: x, y = head_test() + if dtype.startswith('uint') or dtype.startswith('int'): + x *= 255 _test_cache[dtype] = numpy.asarray(x, dtype=dtype) _test_cache['lbl'] = y return _test_cache[dtype] @@ -65,8 +75,9 @@ x_fn, y_fn = _split_fns[split] - x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), single_shape=(1024,))(idx) - y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(idx) + x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), + single_shape=(1024,))(s_idx) + y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(s_idx) if x.ndim == 1: if not rasterized: x = x.reshape((32,32))