# HG changeset patch # User James Bergstra # Date 1258397033 18000 # Node ID faa9f880d0d24728c61fa348b8dd8a83e2747618 # Parent 533b9a0a14f36fb24563b8c81a184a1600aac8d9 fixes to dataset_ops.shapeset1 diff -r 533b9a0a14f3 -r faa9f880d0d2 pylearn/dataset_ops/shapeset1.py --- a/pylearn/dataset_ops/shapeset1.py Mon Nov 16 13:37:52 2009 -0500 +++ b/pylearn/dataset_ops/shapeset1.py Mon Nov 16 13:43:53 2009 -0500 @@ -1,5 +1,9 @@ """A Theano Op to load/access Shapeset1 """ +import theano, numpy + +from .protocol import TensorFnDataset + from ..datasets.shapeset1 import head_train, head_valid, head_test #make global functions so Op can be pickled. @@ -7,6 +11,8 @@ def train_img(dtype): if dtype not in _train_cache: x, y = head_train() + if dtype.startswith('uint') or dtype.startswith('int'): + x *= 255 _train_cache[dtype] = numpy.asarray(x, dtype=dtype) _train_cache['lbl'] = y return _train_cache[dtype] @@ -21,6 +27,8 @@ def valid_img(dtype): if dtype not in _valid_cache: x, y = head_valid() + if dtype.startswith('uint') or dtype.startswith('int'): + x *= 255 _valid_cache[dtype] = numpy.asarray(x, dtype=dtype) _valid_cache['lbl'] = y return _valid_cache[dtype] @@ -35,6 +43,8 @@ def test_img(dtype): if dtype not in _test_cache: x, y = head_test() + if dtype.startswith('uint') or dtype.startswith('int'): + x *= 255 _test_cache[dtype] = numpy.asarray(x, dtype=dtype) _test_cache['lbl'] = y return _test_cache[dtype] @@ -65,8 +75,9 @@ x_fn, y_fn = _split_fns[split] - x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), single_shape=(1024,))(idx) - y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(idx) + x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), + single_shape=(1024,))(s_idx) + y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(s_idx) if x.ndim == 1: if not rasterized: x = x.reshape((32,32))