changeset 878:faa9f880d0d2

fixes to dataset_ops.shapeset1
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 16 Nov 2009 13:43:53 -0500
parents 533b9a0a14f3
children 0f33afbf517e
files pylearn/dataset_ops/shapeset1.py
diffstat 1 files changed, 13 insertions(+), 2 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/shapeset1.py	Mon Nov 16 13:37:52 2009 -0500
+++ b/pylearn/dataset_ops/shapeset1.py	Mon Nov 16 13:43:53 2009 -0500
@@ -1,5 +1,9 @@
 """A Theano Op to load/access Shapeset1
 """
+import theano, numpy
+
+from .protocol import TensorFnDataset
+
 from ..datasets.shapeset1 import head_train, head_valid, head_test
 
 #make global functions so Op can be pickled.
@@ -7,6 +11,8 @@
 def train_img(dtype):
     if dtype not in _train_cache:
         x, y = head_train()
+        if dtype.startswith('uint') or dtype.startswith('int'):
+            x *= 255
         _train_cache[dtype] = numpy.asarray(x, dtype=dtype)
         _train_cache['lbl'] = y
     return _train_cache[dtype]
@@ -21,6 +27,8 @@
 def valid_img(dtype):
     if dtype not in _valid_cache:
         x, y = head_valid()
+        if dtype.startswith('uint') or dtype.startswith('int'):
+            x *= 255
         _valid_cache[dtype] = numpy.asarray(x, dtype=dtype)
         _valid_cache['lbl'] = y
     return _valid_cache[dtype]
@@ -35,6 +43,8 @@
 def test_img(dtype):
     if dtype not in _test_cache:
         x, y = head_test()
+        if dtype.startswith('uint') or dtype.startswith('int'):
+            x *= 255
         _test_cache[dtype] = numpy.asarray(x, dtype=dtype)
         _test_cache['lbl'] = y
     return _test_cache[dtype]
@@ -65,8 +75,9 @@
 
     x_fn, y_fn = _split_fns[split]
 
-    x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)), single_shape=(1024,))(idx)
-    y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(idx)
+    x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)),
+            single_shape=(1024,))(s_idx)
+    y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(s_idx)
     if x.ndim == 1:
         if not rasterized:
             x = x.reshape((32,32))