changeset 838:4f7e0edee7d0

adding cifar10 dataset
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 22 Oct 2009 18:51:20 -0400
parents 28ceb345ab78
children 2418dad01307
files pylearn/dataset_ops/cifar10.py pylearn/dataset_ops/tests/test_cifar10.py
diffstat 2 files changed, 342 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dataset_ops/cifar10.py	Thu Oct 22 18:51:20 2009 -0400
@@ -0,0 +1,221 @@
+"""
+CIFAR-10 dataset of labeled small colour images.
+
+For details see either:
+
+  - http://www.cs.toronto.edu/~kriz/cifar.html, or
+
+  - /data/lisa/data/cifar10/cifar-10-batches-py/readme.html
+
+"""
+import cPickle, os, sys, numpy
+from pylearn.datasets.config import data_root
+import theano
+
+from protocol import TensorFnDataset # protocol.py __init__.py
+from .memo import memo
+
+def _unpickle(filename, dtype):
+    # Load one pickled batch file: data is cast to `dtype` (float dtypes are rescaled by 1/255), labels to int32.
+    fo = open(filename, 'rb')
+    batch = cPickle.load(fo)  # named `batch` so the builtin `dict` is not shadowed
+    fo.close()
+    data, labels = numpy.asarray(batch['data'], dtype=dtype), numpy.asarray(batch['labels'], dtype='int32')
+    if dtype in ('float32', 'float64'):
+        data /= 255
+    return data, labels
+
+@memo
+def train_data_labels(dtype='uint8'):  # -> (data, labels) joined over data_batch_1 .. data_batch_5; NOTE(review): @memo presumably caches per dtype -- confirm in memo.py
+    batch_data, batch_labels = zip(*[ _unpickle( os.path.join(data_root(), 'cifar10', 
+        'cifar-10-batches-py', 'data_batch_%i'%i), dtype) for i in range(1,6)])
+    data = numpy.vstack(batch_data)  # the five batches are concatenated along the first axis
+    labels = numpy.hstack(batch_labels)  # label vectors are concatenated in the same batch order
+    return data, labels
+@memo
+def test_data_labels(dtype='uint8'):  # -> (data, labels) read from the single 'test_batch' file
+    return _unpickle(os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py', 'test_batch'),
+            dtype)
+
+
+# functions for TensorFnDataset
+
+def train_data(dtype):  # images of the 'train' split: the first 40000 rows of the stacked training batches
+    return train_data_labels(dtype)[0][:40000]
+def train_labels():  # labels of the 'train' split; loads through train_data_labels() with its default dtype
+    return train_data_labels()[1][:40000]
+def valid_data(dtype):  # images of the 'valid' split: every row from index 40000 on in the stacked training batches
+    return train_data_labels(dtype)[0][40000:]
+def valid_labels():  # labels of the 'valid' split; loads through train_data_labels() with its default dtype
+    return train_data_labels()[1][40000:]
+def test_data(dtype):  # images of the 'test' split (the whole test batch)
+    return test_data_labels(dtype)[0]
+def test_labels():  # labels of the 'test' split; loads through test_data_labels() with its default dtype
+    return test_data_labels()[1]
+
+
+def cifar10(s_idx, split, dtype='float64', rasterized=False, color='grey'):
+    """
+    Build symbolic variables for the CIFAR-10 images and labels selected by `s_idx`.
+
+    :param s_idx: the indexes (a scalar selects one example, a vector selects a batch)
+    :param split: which subset to read: 'train', 'valid' or 'test'
+    :param dtype: dtype of the images; 'float32' and 'float64' images are scaled by 1/255
+    :param rasterized: return examples as vectors (True) or 32x32 matrices (False)
+    :param color: control how to deal with the color in the images
+      - grey   greyscale (with luminance weighting)
+      - rgb    add a trailing dimension of length 3 with rgb colour channels
+
+    :returns: the pair (x, y) of image and label variables
+    :raises ValueError: if `split` or `color` is not one of the options listed above
+    """
+
+    split_options = {'train':(train_data, train_labels),
+            'valid': (valid_data, valid_labels),
+            'test': (test_data, test_labels)}
+
+    if split not in split_options:
+        raise ValueError('invalid split option', (split, split_options.keys()))
+
+    color_options = ('grey', 'rgb')
+    if color not in color_options:
+        raise ValueError('invalid color option', (color, color_options))
+
+    x_fn, y_fn = split_options[split]
+
+    x_op = TensorFnDataset(dtype, (False,), (x_fn, (dtype,)), (3072,))
+    y_op = TensorFnDataset(dtype, (), y_fn)  # NOTE(review): _unpickle loads labels as int32 -- confirm `dtype` is intended here
+
+    x = x_op(s_idx)
+    y = y_op(s_idx)
+
+    # Y = 0.3R + 0.59G + 0.11B from
+    # http://gimp-savvy.com/BOOK/index.html?node54.html
+    rgb_dtype = 'float32'
+    if dtype == 'float64':
+        rgb_dtype = dtype
+    r = numpy.asarray(.3, dtype=rgb_dtype)
+    g = numpy.asarray(.59, dtype=rgb_dtype)
+    b = numpy.asarray(.11, dtype=rgb_dtype)
+
+    if x.ndim == 1:
+        if rasterized:
+            if color=='grey':
+                x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((3,32*32)).T
+            else:
+                raise NotImplementedError('color', color)
+        else:
+            if color=='grey':
+                x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+                x = x.reshape((32,32))
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((3,32,32)).dimshuffle(1, 2, 0)
+            else:
+                raise NotImplementedError('color', color)
+    elif x.ndim == 2:
+        N = x.shape[0] # symbolic
+        if rasterized:
+            if color=='grey':
+                x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((N, 3,32*32)).dimshuffle(0, 2, 1)
+            else:
+                raise NotImplementedError('color', color)
+        else:
+            if color=='grey':
+                x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+                x = x.reshape((N, 32, 32))  # the reshape result must be kept, otherwise x stays (N, 1024)
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1)
+            else:
+                raise NotImplementedError('color', color)
+    else:
+        raise ValueError('x has too many dimensions', x.ndim)
+
+    return x, y
+
+nclasses = 10  # presumably the number of distinct label values in CIFAR-10 -- TODO confirm against the dataset readme
+
+
+if 0:
+    def datarow_to_greyscale_28by28(row, max_scale=1.0):  # disabled by the enclosing `if 0:`; flat 3072 row -> flat 784 grey crop
+        assert row.shape == (3072,)
+        rgb = row.reshape((3, 1024))
+        grey = numpy.mean(rgb, axis=0) * max_scale / 255.0  # unweighted mean over the first axis, then multiplied by max_scale / 255
+        assert grey.shape == (1024,)
+        grey_arr = grey.reshape((32,32))
+        middle = grey_arr[1:29,1:29]  # 28x28 window starting at row 1, column 1
+        middle_flat = middle.reshape((784,))
+        return middle_flat
+
+
+    def batch_iter(b_idx, max_scale=1.0):  # disabled by the enclosing `if 0:`; yields (784-long grey row, label) per example
+        if b_idx == 'test':
+            data, labels = data_batch_test()  # NOTE(review): data_batch_test is not defined anywhere in this file
+        else:
+            data, labels = data_batches[b_idx]()  # NOTE(review): data_batches is not defined anywhere in this file
+        assert len(data) == 10000
+        assert len(labels) == 10000
+
+        for i in xrange(len(labels)):
+            yield datarow_to_greyscale_28by28(data[i], max_scale=max_scale), labels[i]
+
+    def train_iter(scale=1.0):  # endless generator: cycles over batch indexes 0..3 again and again
+        while True:
+            for b_idx in xrange(4):
+                for d, l in batch_iter(b_idx, max_scale=scale):
+                    yield d, l
+    def valid_iter(scale=1.0):  # one pass over batch index 4
+        for d, l in batch_iter(4, max_scale=scale):
+            yield d, l
+    def test_iter(scale=1.0):  # one pass over the batch selected by the 'test' key
+        for d, l in batch_iter('test', max_scale=scale):
+            yield d, l
+
+
+
+    # the following function is patterned after the MNIST.mnist function
+    class GreyScale(theano.Op):  # disabled by the enclosing `if 0:`; Op that averages a 3-d or 4-d tensor over its last axis
+        def __eq__(self, other): 
+            return type(self) == type(other)
+
+        def __hash__(self):
+            return hash(type(self))
+
+        def make_node(self, x):
+            x_ = theano.tensor.as_tensor_variable(x)
+            if x_.type.ndim not in (3,4):
+                raise TypeError('Greyscaling a tensor with unexpected number of dimensions',
+                        x_.type.ndim)
+            z_type = theano.tensor.TensorType(
+                    dtype=x.dtype,  # NOTE(review): reads the raw argument `x`, not the converted `x_` -- confirm
+                    broadcastable = x_.type.broadcastable[:-1])  # output type drops the last dimension
+            return theano.Apply(self, [x_], [z_type()])
+
+        def perform(self, node, (x,), (z,)):
+            #TODO: Use PIL for real greyscale
+            z[0] = numpy.asarray(x.mean(axis=x.ndim-1), dtype=node.outputs[0].type.dtype)  # plain mean over the last axis
+
+        def grad(self, (x,), (z,)):
+            # TODO: this op is actually differentiable...
+            #       when perform is done with PIL, then TODO is to look up the constants of the RGB
+            #       weights, and put them here in the grad function.
+            return [None]
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/dataset_ops/tests/test_cifar10.py	Thu Oct 22 18:51:20 2009 -0400
@@ -0,0 +1,121 @@
+import numpy
+import theano
+from theano.compile.sandbox import pfunc, shared
+from theano import tensor
+
+from pylearn.dataset_ops.cifar10 import cifar10
+
+def test_single():  # the symbolic image variable must carry the requested dtype for each of the three dtypes
+
+    s_idx = theano.tensor.iscalar()
+
+    for dtype in ('uint8', 'float64', 'float32'):
+        x, y = cifar10(s_idx, split='train', dtype=dtype, rasterized=False, color='grey')
+        assert x.dtype == dtype
+
+def test_shape_range():
+    """Test that the image numbers come out in the right range for various dtypes"""
+    s_idx = theano.tensor.iscalar()
+
+    #uint8 not-rasterized grey
+    x, y = cifar10(s_idx, split='train', dtype='uint8', rasterized=False, color='grey')
+    f = pfunc([s_idx], [x,y])
+    xval, yval = f(0)
+    assert str(xval.dtype) == 'uint8'
+    assert xval.min() >= 0
+    assert xval.max() < 256
+    assert xval.max() > 1
+    assert xval.shape == (32,32)
+
+    #uint8 not-rasterized rgb
+    x, y = cifar10(s_idx, split='train', dtype='uint8', rasterized=False, color='rgb')
+    f = pfunc([s_idx], [x,y])
+    xval, yval = f(0)
+    assert str(xval.dtype) == 'uint8'
+    assert xval.min() >= 0
+    assert xval.max() < 256
+    assert xval.max() > 1
+    assert xval.shape == (32,32, 3)
+
+    #uint8 rasterized grey
+    x, y = cifar10(s_idx, split='train', dtype='uint8', rasterized=True, color='grey')
+    f = pfunc([s_idx], [x,y])
+    xval, yval = f(0)
+    assert str(xval.dtype) == 'uint8'
+    assert xval.min() >= 0
+    assert xval.max() < 256
+    assert xval.max() > 1
+    assert xval.shape == (1024,)
+
+    #uint8 rasterized rgb
+    x, y = cifar10(s_idx, split='train', dtype='uint8', rasterized=True, color='rgb')
+    f = pfunc([s_idx], [x,y])
+    xval, yval = f(0)
+    assert str(xval.dtype) == 'uint8'
+    assert xval.min() >= 0
+    assert xval.max() < 256
+    assert xval.max() > 1
+    assert xval.shape == (1024, 3)
+
+    # ranges are handled independently from shapes, so I'll consider the shapes have been
+    # tested above, and now I just look at ranges for floating-point dtypes
+
+    #float32 not-rasterized grey
+    x, y = cifar10(s_idx, split='train', dtype='float32', rasterized=False, color='grey')
+    f = pfunc([s_idx], [x,y])
+    xval, yval = f(0)
+    assert str(xval.dtype) == 'float32'
+    assert xval.min() >= 0.0
+    assert xval.max() <= 1.0
+    assert xval.max() > 0.01
+    assert xval.shape == (32,32)
+
+    #float64 rasterized rgb, indexed with a vector of 5 consecutive indexes (a batch)
+    x, y = cifar10(s_idx + range(5), split='train', dtype='float64', rasterized=True, color='rgb')
+    f = pfunc([s_idx], [x,y])
+    xval, yval = f(0)
+    assert str(xval.dtype) == 'float64'
+    assert xval.min() >= 0.0
+    assert xval.max() <= 1.0
+    assert xval.max() > 0.01
+    assert xval.shape == (5, 1024, 3)
+
+def test_split_different():  # example 0 of the train, valid and test splits must not be the same image
+    s_idx = theano.tensor.iscalar()
+    x, y = cifar10(s_idx, split='train', dtype='uint8', rasterized=False, color='grey')
+    f = pfunc([s_idx], [x,y])
+    train_xval, train_yval = f(0)
+
+    x, y = cifar10(s_idx, split='valid', dtype='uint8', rasterized=False, color='grey')
+    f = pfunc([s_idx], [x,y])
+    valid_xval, valid_yval = f(0)
+
+    x, y = cifar10(s_idx, split='test', dtype='uint8', rasterized=False, color='grey')
+    f = pfunc([s_idx], [x,y])
+    test_xval, test_yval = f(0)
+
+    assert not numpy.all(train_xval == valid_xval)
+    assert not numpy.all(train_xval == test_xval)
+    assert not numpy.all(valid_xval == test_xval)
+
+
+def test_split_length():
+    """test that each split has the correct length"""
+    s_idx = theano.tensor.iscalar()
+    for bsize in [1, 3, 5]:
+        for (split, goodlen) in [('train', 40000), ('valid', 10000), ('test', 10000)]:
+            if bsize == 1:
+                x, y = cifar10(s_idx, split=split, dtype='uint8', rasterized=False, color='grey')
+            else:
+                x, y = cifar10(s_idx*bsize + range(bsize), split=split, dtype='uint8', rasterized=False, color='grey')
+
+            f = pfunc([s_idx], [x,y])
+            i = 0
+            while i < 900000:  # the cap only stops the loop if no IndexError is ever raised
+                try:
+                    f(i)
+                except IndexError:
+                    break
+                i += 1
+            assert i == (goodlen / bsize) # when goodlen % bsize != 0, the incomplete final batch is not counted
+