changeset 1288:a165f2666643

cifar10 - added support for "all" split
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 29 Sep 2010 18:35:40 -0400
parents 4fa2a32e8fde
children 092cd4cd2009
files pylearn/dataset_ops/cifar10.py
diffstat 1 files changed, 31 insertions(+), 12 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/cifar10.py	Wed Sep 29 18:34:47 2010 -0400
+++ b/pylearn/dataset_ops/cifar10.py	Wed Sep 29 18:35:40 2010 -0400
@@ -21,11 +21,20 @@
     dict = cPickle.load(fo)
     fo.close()
     data, labels = numpy.asarray(dict['data'], dtype=dtype), numpy.asarray(dict['labels'], dtype='int32')
-    if dtype in ('float32', 'float64'):
+    if str(dtype) in ('float32', 'float64'):
         data /= 255
     return data, labels
 
 @memo
+def all_data_labels(dtype='uint8'):
+    train_batch_data, train_batch_labels = zip(*[ _unpickle( os.path.join(data_root(), 'cifar10', 
+        'cifar-10-batches-py', 'data_batch_%i'%i), dtype) for i in range(1,6)])
+    test_batch_data, test_batch_labels = test_data_labels(dtype)
+    data = numpy.vstack(list(train_batch_data)+[test_batch_data])
+    labels = numpy.hstack(list(train_batch_labels)+[test_batch_labels])
+    return data, labels
+
+@memo
 def train_data_labels(dtype='uint8'):
     batch_data, batch_labels = zip(*[ _unpickle( os.path.join(data_root(), 'cifar10', 
         'cifar-10-batches-py', 'data_batch_%i'%i), dtype) for i in range(1,6)])
@@ -40,6 +49,7 @@
 def forget():
     train_data_labels.forget()
     test_data_labels.forget()
+    all_data_labels.forget()
 
 
 # functions for TensorFnDataset
@@ -56,12 +66,20 @@
     return test_data_labels(dtype)[0]
 def test_labels():
     return test_data_labels()[1]
+def all_data(dtype):
+    if dtype!='uint8':
+        raise ValueError()
+    return all_data_labels()[0]
+def all_labels():
+    return all_data_labels()[1]
 
 
 def cifar10(s_idx, split, dtype='float64', rasterized=False, color='grey',
         split_options = {'train':(train_data, train_labels),
                 'valid': (valid_data, valid_labels),
-                'test': (test_data, test_labels)}
+                'test': (test_data, test_labels),
+                'all': (all_data, all_labels),
+                }
             ):
     """ 
     Returns a pair (img, label) of theano expressions for cifar-10 samples
@@ -95,14 +113,15 @@
     x = x_op(s_idx)
     y = y_op(s_idx)
 
-    # Y = 0.3R + 0.59G + 0.11B from
-    # http://gimp-savvy.com/BOOK/index.html?node54.html
-    rgb_dtype = 'float32'
-    if dtype == 'float64':
-        rgb_dtype = dtype
-    r = numpy.asarray(.3, dtype=rgb_dtype)
-    g = numpy.asarray(.59, dtype=rgb_dtype)
-    b = numpy.asarray(.11, dtype=rgb_dtype)
+    if color=='grey':
+        # Y = 0.3R + 0.59G + 0.11B from
+        # http://gimp-savvy.com/BOOK/index.html?node54.html
+        rgb_dtype = 'float32'
+        if dtype == 'float64':
+            rgb_dtype = dtype
+        r = numpy.asarray(.3, dtype=rgb_dtype)
+        g = numpy.asarray(.59, dtype=rgb_dtype)
+        b = numpy.asarray(.11, dtype=rgb_dtype)
 
     if x.ndim == 1:
         if rasterized:
@@ -148,8 +167,8 @@
                     x = theano.tensor.cast(x, 'uint8')
                 x.reshape((N, 32, 32))
             elif color=='rgb':
-                # the strides aren't what you'd expect between channels,
-                # but theano is all about weird strides
+                # note: the strides aren't what you'd expect between channels,
+                # but a copy of the data would correct that.
                 x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1)
             else:
                 raise NotImplemented('color', color)