changeset 874:76f71e10f5ef

added dataset_ops.shapeset1
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 16 Nov 2009 13:20:51 -0500
parents 223caaea433a
children 2a8a7ce78c12
files pylearn/dataset_ops/shapeset1.py
diffstat 1 files changed, 86 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dataset_ops/shapeset1.py	Mon Nov 16 13:20:51 2009 -0500
@@ -0,0 +1,86 @@
+"""A Theano Op to load/access Shapeset1
+"""
+from ..datasets.shapeset1 import head_train, head_valid, head_test
+
+#make global functions so Op can be pickled.
+_train_cache = {}
+def train_img(dtype):
+    if dtype not in _train_cache:
+        x, y = head_train()
+        _train_cache[dtype] = numpy.asarray(x, dtype=dtype)
+        _train_cache['lbl'] = y
+    return _train_cache[dtype]
+def train_lbl():
+    if 'lbl' not in _train_cache:
+        x, y = head_train()
+        # cache x in some format now that it's read (it isn't that big).
+        _train_cache[x.dtype] = x 
+        _train_cache['lbl'] = y
+    return _train_cache['lbl']
+_valid_cache = {}
+def valid_img(dtype):
+    if dtype not in _valid_cache:
+        x, y = head_valid()
+        _valid_cache[dtype] = numpy.asarray(x, dtype=dtype)
+        _valid_cache['lbl'] = y
+    return _valid_cache[dtype]
+def valid_lbl():
+    if 'lbl' not in _valid_cache:
+        x, y = head_valid()
+        # cache x in some format now that it's read (it isn't that big).
+        _valid_cache[x.dtype] = x 
+        _valid_cache['lbl'] = y
+    return _valid_cache['lbl']
+_valid_cache = {}
+def test_img(dtype):
+    if dtype not in _test_cache:
+        x, y = head_test()
+        _test_cache[dtype] = numpy.asarray(x, dtype=dtype)
+        _test_cache['lbl'] = y
+    return _test_cache[dtype]
+def test_lbl():
+    if 'lbl' not in _test_cache:
+        x, y = head_test()
+        # cache x in some format now that it's read (it isn't that big).
+        _test_cache[x.dtype] = x 
+        _test_cache['lbl'] = y
+    return _test_cache['lbl']
+
+_split_fns = dict(
+        train=(train_img, train_lbl),
+        valid=(valid_img, valid_lbl),
+        test=(test_img, test_lbl))
+
+def shapeset1(s_idx, split, dtype='float64', rasterized=False):
+    """
+    :param s_idx:
+
+    :param split:
+
+    :param dtype:
+
+    :param rasterized: return examples as vectors (True) or 28x28 matrices (False)
+
+    """
+
+    x_fn, y_fn = _split_fns[split]
+
+    x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), single_shape=(1024,))(idx)
+    y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(idx)
+    if x.ndim == 1:
+        if not rasterized:
+            x = x.reshape((32,32))
+    elif x.ndim == 2:
+        if not rasterized:
+            x = x.reshape((x.shape[0], 32,32))
+    else:
+        assert False, 'what happened?'
+    return x, y
+nclasses = 10
+
+def glviewer(split='train'):
+    from glviewer import GlViewer
+    i = theano.tensor.iscalar()
+    f = theano.function([i], shapeset1(i, split, dtype='uint8', rasterized=False)[0])
+    GlViewer(f).main()
+