Mercurial > pylearn
view pylearn/dataset_ops/shapeset1.py @ 931:1c62fa857cab
forcing int32 label dtype in shapeset1
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 15 Apr 2010 10:50:10 -0400 |
parents | 0f33afbf517e |
children |
line wrap: on
line source
"""Theano Ops to load/access the Shapeset1 dataset.

The per-split loader functions below are module-level functions (not
lambdas or closures) so that the Ops referencing them can be pickled.
"""

import theano
import numpy

from .protocol import TensorFnDataset
from ..datasets.shapeset1 import head_train, head_valid, head_test


def _load_img(cache, head_fn, dtype):
    """Return the images produced by `head_fn()` as an array of `dtype`.

    Results are memoized in `cache` under the canonical dtype name (e.g.
    'float64'); the labels read alongside are stored as int32 under 'lbl'.

    For integer dtypes the raw values are multiplied by 255 before the cast
    (presumably the raw images are intensities in [0, 1] -- TODO confirm
    against datasets.shapeset1).
    """
    # Canonicalize so e.g. 'float' and 'float64' (or a numpy.dtype object)
    # share a single cache entry.
    np_dtype = numpy.dtype(dtype)
    key = str(np_dtype)
    if key not in cache:
        x, y = head_fn()
        x = numpy.asarray(x)
        if np_dtype.kind in 'iu':
            # Out-of-place on purpose: never mutate the array handed back by
            # head_fn, which may be shared with other callers.
            x = x * 255
        cache[key] = numpy.asarray(x, dtype=np_dtype)
        cache['lbl'] = numpy.asarray(y, dtype='int32')
    return cache[key]


def _load_lbl(cache, head_fn):
    """Return the labels produced by `head_fn()` as an int32 array.

    The result is memoized in `cache['lbl']`.
    """
    if 'lbl' not in cache:
        x, y = head_fn()
        x = numpy.asarray(x)
        # The images were read anyway (the set isn't that big), so keep them.
        # Key by the dtype *name* so _load_img can actually find the entry
        # (a numpy.dtype key never matches a str lookup), and skip integer
        # dtypes because _load_img stores those rescaled by 255.
        if x.dtype.kind not in 'iu':
            cache.setdefault(str(x.dtype), x)
        cache['lbl'] = numpy.asarray(y, dtype='int32')
    return cache['lbl']


_train_cache = {}


def train_img(dtype):
    """Return the training images as an array of `dtype` (cached)."""
    return _load_img(_train_cache, head_train, dtype)


def train_lbl():
    """Return the training labels as an int32 array (cached)."""
    return _load_lbl(_train_cache, head_train)


_valid_cache = {}


def valid_img(dtype):
    """Return the validation images as an array of `dtype` (cached)."""
    return _load_img(_valid_cache, head_valid, dtype)


def valid_lbl():
    """Return the validation labels as an int32 array (cached)."""
    return _load_lbl(_valid_cache, head_valid)


_test_cache = {}


def test_img(dtype):
    """Return the test images as an array of `dtype` (cached)."""
    return _load_img(_test_cache, head_test, dtype)


def test_lbl():
    """Return the test labels as an int32 array (cached)."""
    return _load_lbl(_test_cache, head_test)


_split_fns = dict(
        train=(train_img, train_lbl),
        valid=(valid_img, valid_lbl),
        test=(test_img, test_lbl))


def shapeset1(s_idx, split, dtype='float64', rasterized=False):
    """Build symbolic (image, label) variables for Shapeset1 examples.

    :param s_idx: symbolic index selecting the example(s); presumably a scalar
        (one example) or a vector (a batch) -- TODO confirm against
        TensorFnDataset.
    :param split: one of 'train', 'valid', 'test' (KeyError otherwise).
    :param dtype: dtype of the returned images; integer dtypes are scaled by
        255 before the cast.
    :param rasterized: return examples as flat 1024-vectors (True) or as
        32x32 matrices (False).

    :returns: tuple (x, y) of symbolic images and int32 labels.
    :raises ValueError: if the image variable is neither 1-D nor 2-D.
    """
    x_fn, y_fn = _split_fns[split]
    x = TensorFnDataset(dtype=dtype, bcast=(False,),
                        fn=(x_fn, (dtype,)),
                        single_shape=(1024,))(s_idx)
    y = TensorFnDataset(dtype='int32', bcast=(), fn=y_fn)(s_idx)
    if x.ndim == 1:
        if not rasterized:
            x = x.reshape((32, 32))
    elif x.ndim == 2:
        if not rasterized:
            x = x.reshape((x.shape[0], 32, 32))
    else:
        # Explicit raise rather than `assert`, which is stripped under -O.
        raise ValueError(
            'expected a 1-D or 2-D image variable, got ndim=%d' % x.ndim)
    return x, y


# Presumably the number of distinct label values in Shapeset1.
nclasses = 10


def glviewer(split='train'):
    """Open an interactive GlViewer browsing the uint8 images of `split`."""
    from glviewer import GlViewer
    i = theano.tensor.iscalar()
    f = theano.function([i], shapeset1(i, split, dtype='uint8',
                                       rasterized=False)[0])
    GlViewer(f).main()