changeset 933:ca9fc8cae5b5

forcing int32 label dtype in shapeset1
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 15 Apr 2010 10:50:10 -0400
parents b2a60af9cc28
children e0b960ee57f5
files pylearn/dataset_ops/shapeset1.py
diffstat 1 files changed, 7 insertions(+), 7 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/shapeset1.py	Thu Apr 15 10:49:46 2010 -0400
+++ b/pylearn/dataset_ops/shapeset1.py	Thu Apr 15 10:50:10 2010 -0400
@@ -14,14 +14,14 @@
         if dtype.startswith('uint') or dtype.startswith('int'):
             x *= 255
         _train_cache[dtype] = numpy.asarray(x, dtype=dtype)
-        _train_cache['lbl'] = y
+        _train_cache['lbl'] = numpy.asarray(y, dtype='int32')
     return _train_cache[dtype]
 def train_lbl():
     if 'lbl' not in _train_cache:
         x, y = head_train()
         # cache x in some format now that it's read (it isn't that big).
         _train_cache[x.dtype] = x 
-        _train_cache['lbl'] = y
+        _train_cache['lbl'] = numpy.asarray(y, dtype='int32')
     return _train_cache['lbl']
 _valid_cache = {}
 def valid_img(dtype):
@@ -30,14 +30,14 @@
         if dtype.startswith('uint') or dtype.startswith('int'):
             x *= 255
         _valid_cache[dtype] = numpy.asarray(x, dtype=dtype)
-        _valid_cache['lbl'] = y
+        _valid_cache['lbl'] = numpy.asarray(y, dtype='int32')
     return _valid_cache[dtype]
 def valid_lbl():
     if 'lbl' not in _valid_cache:
         x, y = head_valid()
         # cache x in some format now that it's read (it isn't that big).
         _valid_cache[x.dtype] = x 
-        _valid_cache['lbl'] = y
+        _valid_cache['lbl'] = numpy.asarray(y, dtype='int32')
     return _valid_cache['lbl']
 _test_cache = {}
 def test_img(dtype):
@@ -46,14 +46,14 @@
         if dtype.startswith('uint') or dtype.startswith('int'):
             x *= 255
         _test_cache[dtype] = numpy.asarray(x, dtype=dtype)
-        _test_cache['lbl'] = y
+        _test_cache['lbl'] = numpy.asarray(y, dtype='int32')
     return _test_cache[dtype]
 def test_lbl():
     if 'lbl' not in _test_cache:
         x, y = head_test()
         # cache x in some format now that it's read (it isn't that big).
         _test_cache[x.dtype] = x 
-        _test_cache['lbl'] = y
+        _test_cache['lbl'] = numpy.asarray(y, dtype='int32')
     return _test_cache['lbl']
 
 _split_fns = dict(
@@ -77,7 +77,7 @@
 
     x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)),
             single_shape=(1024,))(s_idx)
-    y = TensorFnDataset(dtype='int64', bcast=(), fn=y_fn)(s_idx)
+    y = TensorFnDataset(dtype='int32', bcast=(), fn=y_fn)(s_idx)
     if x.ndim == 1:
         if not rasterized:
             x = x.reshape((32,32))