Mercurial > pylearn
changeset 874:76f71e10f5ef
added dataset_ops.shapeset1
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 16 Nov 2009 13:20:51 -0500 |
parents | 223caaea433a |
children | 2a8a7ce78c12 |
files | pylearn/dataset_ops/shapeset1.py |
diffstat | 1 files changed, 86 insertions(+), 0 deletions(-) [+] |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/dataset_ops/shapeset1.py Mon Nov 16 13:20:51 2009 -0500 @@ -0,0 +1,86 @@ +"""A Theano Op to load/access Shapeset1 +""" +from ..datasets.shapeset1 import head_train, head_valid, head_test + +#make global functions so Op can be pickled. +_train_cache = {} +def train_img(dtype): + if dtype not in _train_cache: + x, y = head_train() + _train_cache[dtype] = numpy.asarray(x, dtype=dtype) + _train_cache['lbl'] = y + return _train_cache[dtype] +def train_lbl(): + if 'lbl' not in _train_cache: + x, y = head_train() + # cache x in some format now that it's read (it isn't that big). + _train_cache[x.dtype] = x + _train_cache['lbl'] = y + return _train_cache['lbl'] +_valid_cache = {} +def valid_img(dtype): + if dtype not in _valid_cache: + x, y = head_valid() + _valid_cache[dtype] = numpy.asarray(x, dtype=dtype) + _valid_cache['lbl'] = y + return _valid_cache[dtype] +def valid_lbl(): + if 'lbl' not in _valid_cache: + x, y = head_valid() + # cache x in some format now that it's read (it isn't that big). + _valid_cache[x.dtype] = x + _valid_cache['lbl'] = y + return _valid_cache['lbl'] +_valid_cache = {} +def test_img(dtype): + if dtype not in _test_cache: + x, y = head_test() + _test_cache[dtype] = numpy.asarray(x, dtype=dtype) + _test_cache['lbl'] = y + return _test_cache[dtype] +def test_lbl(): + if 'lbl' not in _test_cache: + x, y = head_test() + # cache x in some format now that it's read (it isn't that big). + _test_cache[x.dtype] = x + _test_cache['lbl'] = y + return _test_cache['lbl'] + +_split_fns = dict( + train=(train_img, train_lbl), + valid=(valid_img, valid_lbl), + test=(test_img, test_lbl)) + +def shapeset1(s_idx, split, dtype='float64', rasterized=False): + """ + :param s_idx: + + :param split: + + :param dtype: + + :param rasterized: return examples as vectors (True) or 28x28 matrices (False) + + """ + + x_fn, y_fn = _split_fns[split] + + x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), single_shape=(1024,))(idx) + y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(idx) + if x.ndim == 1: + if not rasterized: + x = x.reshape((32,32)) + elif x.ndim == 2: + if not rasterized: + x = x.reshape((x.shape[0], 32,32)) + else: + assert False, 'what happened?' + return x, y +nclasses = 10 + +def glviewer(split='train'): + from glviewer import GlViewer + i = theano.tensor.iscalar() + f = theano.function([i], shapeset1(i, split, dtype='uint8', rasterized=False)[0]) + GlViewer(f).main() +