Mercurial > pylearn
changeset 933:ca9fc8cae5b5
forcing int32 label dtype in shapeset1
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 15 Apr 2010 10:50:10 -0400 |
parents | b2a60af9cc28 |
children | e0b960ee57f5 |
files | pylearn/dataset_ops/shapeset1.py |
diffstat | 1 files changed, 7 insertions(+), 7 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/dataset_ops/shapeset1.py Thu Apr 15 10:49:46 2010 -0400 +++ b/pylearn/dataset_ops/shapeset1.py Thu Apr 15 10:50:10 2010 -0400 @@ -14,14 +14,14 @@ if dtype.startswith('uint') or dtype.startswith('int'): x *= 255 _train_cache[dtype] = numpy.asarray(x, dtype=dtype) - _train_cache['lbl'] = y + _train_cache['lbl'] = numpy.asarray(y, dtype='int32') return _train_cache[dtype] def train_lbl(): if 'lbl' not in _train_cache: x, y = head_train() # cache x in some format now that it's read (it isn't that big). _train_cache[x.dtype] = x - _train_cache['lbl'] = y + _train_cache['lbl'] = numpy.asarray(y, dtype='int32') return _train_cache['lbl'] _valid_cache = {} def valid_img(dtype): @@ -30,14 +30,14 @@ if dtype.startswith('uint') or dtype.startswith('int'): x *= 255 _valid_cache[dtype] = numpy.asarray(x, dtype=dtype) - _valid_cache['lbl'] = y + _valid_cache['lbl'] = numpy.asarray(y, dtype='int32') return _valid_cache[dtype] def valid_lbl(): if 'lbl' not in _valid_cache: x, y = head_valid() # cache x in some format now that it's read (it isn't that big). _valid_cache[x.dtype] = x - _valid_cache['lbl'] = y + _valid_cache['lbl'] = numpy.asarray(y, dtype='int32') return _valid_cache['lbl'] _test_cache = {} def test_img(dtype): @@ -46,14 +46,14 @@ if dtype.startswith('uint') or dtype.startswith('int'): x *= 255 _test_cache[dtype] = numpy.asarray(x, dtype=dtype) - _test_cache['lbl'] = y + _test_cache['lbl'] = numpy.asarray(y, dtype='int32') return _test_cache[dtype] def test_lbl(): if 'lbl' not in _test_cache: x, y = head_test() # cache x in some format now that it's read (it isn't that big). _test_cache[x.dtype] = x - _test_cache['lbl'] = y + _test_cache['lbl'] = numpy.asarray(y, dtype='int32') return _test_cache['lbl'] _split_fns = dict( @@ -77,7 +77,7 @@ x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), single_shape=(1024,))(s_idx) - y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(s_idx) + y = TensorFnDataset(dtype='int32', bcast=(), fn=y_fn)(s_idx) if x.ndim == 1: if not rasterized: x = x.reshape((32,32))