view pylearn/dataset_ops/cifar10.py @ 1282:f36f59e53c28

cifar10 op - made splits constructors a parameter
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 15 Sep 2010 17:46:03 -0400
parents 5d70dfc70ec0
children a73db8d65abb
line wrap: on
line source

"""
CIFAR-10 dataset of small (32x32) labeled colour images in 10 classes.

For details see either:

  - http://www.cs.toronto.edu/~kriz/cifar.html, or

  - /data/lisa/data/cifar10/cifar-10-batches-py/readme.html

"""
import cPickle, os, sys, numpy
from pylearn.datasets.config import data_root
import theano

from protocol import TensorFnDataset # protocol.py __init__.py
from .memo import memo # memo.py

def _unpickle(filename, dtype):
    #implements loading as well as dtype-conversion and dtype-scaling
    fo = open(filename, 'rb')
    dict = cPickle.load(fo)
    fo.close()
    data, labels = numpy.asarray(dict['data'], dtype=dtype), numpy.asarray(dict['labels'], dtype='int32')
    if dtype in ('float32', 'float64'):
        data /= 255
    return data, labels

@memo
def train_data_labels(dtype='uint8'):
    """Load the five CIFAR-10 training batches and stack them together.

    :param dtype: forwarded to ``_unpickle`` for every batch.

    :returns: ``(data, labels)``: the per-batch data arrays stacked row-wise
        with ``numpy.vstack`` and the per-batch label arrays joined with
        ``numpy.hstack``, in the order data_batch_1 .. data_batch_5.

    Wrapped by ``memo``; ``forget()`` calls ``train_data_labels.forget()``.
    """
    batches = []
    for batch_no in range(1, 6):
        batch_file = os.path.join(data_root(), 'cifar10',
                'cifar-10-batches-py', 'data_batch_%i' % batch_no)
        batches.append(_unpickle(batch_file, dtype))
    stacked_data = numpy.vstack([pair[0] for pair in batches])
    stacked_labels = numpy.hstack([pair[1] for pair in batches])
    return stacked_data, stacked_labels
@memo
def test_data_labels(dtype='uint8'):
    """Load the CIFAR-10 test batch.

    :param dtype: forwarded to ``_unpickle``.

    :returns: the ``(data, labels)`` pair produced by ``_unpickle``.

    Wrapped by ``memo``; ``forget()`` calls ``test_data_labels.forget()``.
    """
    test_file = os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py',
            'test_batch')
    return _unpickle(test_file, dtype)

def forget():
    """Call ``.forget()`` on the memoized train and test loaders."""
    for loader in (train_data_labels, test_data_labels):
        loader.forget()


# functions for TensorFnDataset
#
# The rows returned by train_data_labels() are split into a training part
# (the first _N_TRAIN rows) and a validation part (all remaining rows).  The
# boundary is defined once so the two splits stay disjoint and complementary.
_N_TRAIN = 40000


def train_data(dtype):
    """Return the training-split images, loaded as ``dtype``."""
    return train_data_labels(dtype)[0][:_N_TRAIN]


def train_labels():
    """Return the training-split labels.

    Uses train_data_labels() with its default dtype ('uint8'), so the uint8
    copy of the data is loaded regardless of the image dtype in use.
    """
    return train_data_labels()[1][:_N_TRAIN]


def valid_data(dtype):
    """Return the validation-split images, loaded as ``dtype``."""
    return train_data_labels(dtype)[0][_N_TRAIN:]


def valid_labels():
    """Return the validation-split labels (via the default 'uint8' load)."""
    return train_data_labels()[1][_N_TRAIN:]


def test_data(dtype):
    """Return the test-batch images, loaded as ``dtype``."""
    return test_data_labels(dtype)[0]


def test_labels():
    """Return the test-batch labels (via the default 'uint8' load)."""
    return test_data_labels()[1]


def cifar10(s_idx, split, dtype='float64', rasterized=False, color='grey',
        split_options=None):
    """
    Returns a pair (img, label) of theano expressions for cifar-10 samples

    :param s_idx: index expression applied to both the image and the label
        TensorFnDataset ops.  If the resulting image expression is a single
        3072-vector (presumably a scalar index) one example is returned; if it
        is an N x 3072 matrix (presumably a vector of N indexes) a batch with
        leading dimension N is returned.

    :param split: which subset to draw from; must be a key of
        ``split_options`` ('train', 'valid' or 'test' by default).

    :param dtype: dtype of the image data, forwarded to the split's data
        function.  With the default split functions, 'float32'/'float64'
        pixel values are divided by 255.  'uint8' greyscale results are cast
        back to uint8 after the luminance weighting.

    :param rasterized: return examples as vectors (True) or 32x32 matrices
        (False)

    :param color: control how to deal with the color in the images
      - grey   greyscale (with luminance weighting)
      - rgb    add a trailing dimension of length 3 with rgb colour channels

    :param split_options: dict mapping each split name to a
        ``(data_fn, labels_fn)`` pair; ``data_fn(dtype)`` supplies the image
        rows and ``labels_fn()`` the labels.  ``None`` (the default) selects
        this module's train/valid/test accessors.

    :returns: ``(x, y)``.  Per-example shape of ``x``:

        - color='grey', rasterized=True  -> (1024,)
        - color='grey', rasterized=False -> (32, 32)
        - color='rgb',  rasterized=True  -> (1024, 3)
        - color='rgb',  rasterized=False -> (32, 32, 3)

        Batched results have an extra leading dimension N.  ``y`` is int32.

    :raises ValueError: for an unknown ``split`` or ``color``, or if the image
        expression has more than 2 dimensions.
    """
    # Build the default per call instead of sharing one mutable default dict.
    if split_options is None:
        split_options = {'train': (train_data, train_labels),
                'valid': (valid_data, valid_labels),
                'test': (test_data, test_labels)}

    if split not in split_options:
        raise ValueError('invalid split option', (split, split_options.keys()))

    color_options = ('grey', 'rgb')
    if color not in color_options:
        raise ValueError('invalid color option', (color, color_options))

    x_fn, y_fn = split_options[split]

    x_op = TensorFnDataset(dtype, (False,), (x_fn, (dtype,)), (3072,))
    y_op = TensorFnDataset('int32', (), y_fn)

    x = x_op(s_idx)
    y = y_op(s_idx)

    # Y = 0.3R + 0.59G + 0.11B from
    # http://gimp-savvy.com/BOOK/index.html?node54.html
    rgb_dtype = 'float32'
    if dtype == 'float64':
        rgb_dtype = dtype
    r = numpy.asarray(.3, dtype=rgb_dtype)
    g = numpy.asarray(.59, dtype=rgb_dtype)
    b = numpy.asarray(.11, dtype=rgb_dtype)

    # Each row holds the 1024 red values, then 1024 green, then 1024 blue.
    if x.ndim == 1:
        if color == 'grey':
            x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
            if dtype == 'uint8':
                x = theano.tensor.cast(x, 'uint8')
            if not rasterized:
                x = x.reshape((32, 32))
        elif color == 'rgb':
            # the strides aren't what you'd expect between channels,
            # but theano is all about weird strides
            if rasterized:
                x = x.reshape((3, 32 * 32)).T
            else:
                x = x.reshape((3, 32, 32)).dimshuffle(1, 2, 0)
        else:
            raise NotImplementedError('color', color)
    elif x.ndim == 2:
        N = x.shape[0] # symbolic
        if color == 'grey':
            x = r * x[:, :1024] + g * x[:, 1024:2048] + b * x[:, 2048:]
            if dtype == 'uint8':
                x = theano.tensor.cast(x, 'uint8')
            if not rasterized:
                # reshape returns a new expression; it must be kept
                x = x.reshape((N, 32, 32))
        elif color == 'rgb':
            # the strides aren't what you'd expect between channels,
            # but theano is all about weird strides
            if rasterized:
                x = x.reshape((N, 3, 32 * 32)).dimshuffle(0, 2, 1)
            else:
                x = x.reshape((N, 3, 32, 32)).dimshuffle(0, 2, 3, 1)
        else:
            raise NotImplementedError('color', color)
    else:
        raise ValueError('x has too many dimensions', x.ndim)

    return x, y

# Number of distinct label classes in CIFAR-10.
nclasses = 10

def glviewer(split='train'):
    """Browse one split of CIFAR-10 interactively with ``GlViewer``.

    Compiles a theano function that maps an integer index to that example's
    uint8 RGB image of shape (32, 32, 3), and hands it to ``GlViewer``.
    """
    from glviewer import GlViewer
    idx = theano.tensor.iscalar()
    image, _label = cifar10(idx, split, dtype='uint8', rasterized=False, color='rgb')
    show_image = theano.function([idx], image)
    GlViewer(show_image).main()


if 0:
    # NOTE(review): everything in this block is disabled by ``if 0`` and never
    # runs.  It refers to ``data_batch_test`` and ``data_batches``, which are
    # not defined in this module as shown.  It would need updating before
    # being re-enabled.
    def datarow_to_greyscale_28by28(row, max_scale=1.0):
        """Convert one 3072-value CIFAR row into a flat 28x28 greyscale vector.

        The three 1024-value colour planes are averaged (unweighted) and
        scaled by ``max_scale / 255.0``.  The result is reshaped to 32x32, and
        rows/columns 1..28 are kept.  The crop is therefore off-centre: one
        pixel is dropped at the top/left and three at the bottom/right.
        Returns a vector of length 784.
        """
        assert row.shape == (3072,)
        rgb = row.reshape((3, 1024))
        grey = numpy.mean(rgb, axis=0) * max_scale / 255.0
        assert grey.shape == (1024,)
        grey_arr = grey.reshape((32,32))
        middle = grey_arr[1:29,1:29]
        middle_flat = middle.reshape((784,))
        return middle_flat


    def batch_iter(b_idx, max_scale=1.0):
        """Yield (greyscale 784-vector, label) pairs for one batch.

        ``b_idx`` is either a batch index or the string 'test'.
        """
        if b_idx == 'test':
            data, labels = data_batch_test()
        else:
            data, labels = data_batches[b_idx]()
        assert len(data) == 10000
        assert len(labels) == 10000

        for i in xrange(len(labels)):
            yield datarow_to_greyscale_28by28(data[i], max_scale=max_scale), labels[i]

    def train_iter(scale=1.0):
        """Cycle forever over batches 0..3."""
        while True:
            for b_idx in xrange(4):
                for d, l in batch_iter(b_idx, max_scale=scale):
                    yield d, l
    def valid_iter(scale=1.0):
        """Iterate once over batch 4."""
        for d, l in batch_iter(4, max_scale=scale):
            yield d, l
    def test_iter(scale=1.0):
        """Iterate once over the test batch."""
        for d, l in batch_iter('test', max_scale=scale):
            yield d, l



    # the following Op is patterned after the MNIST.mnist function
    class GreyScale(theano.Op):
        """Reduce the trailing (colour) axis of a 3-d or 4-d tensor by its mean."""

        def __eq__(self, other):
            return type(self) == type(other)

        def __hash__(self):
            return hash(type(self))

        def make_node(self, x):
            x_ = theano.tensor.as_tensor_variable(x)
            if x_.type.ndim not in (3,4):
                raise TypeError('Greyscaling a tensor with unexpected number of dimensions',
                        x_.type.ndim)
            # Use the converted variable's dtype: the raw input ``x`` may not
            # have a ``dtype`` attribute (e.g. a nested list).
            z_type = theano.tensor.TensorType(
                    dtype=x_.type.dtype,
                    broadcastable = x_.type.broadcastable[:-1])
            return theano.Apply(self, [x_], [z_type()])

        def perform(self, node, inputs, output_storage):
            # explicit unpacking instead of Python-2-only tuple parameters
            x, = inputs
            z, = output_storage
            #TODO: Use PIL for real greyscale
            z[0] = numpy.asarray(x.mean(axis=x.ndim-1), dtype=node.outputs[0].type.dtype)

        def grad(self, inputs, output_grads):
            # TODO: this op is actually differentiable...
            #       when perform is done with PIL, then TODO is to look up the constants of the RGB
            #       weights, and put them here in the grad function.
            return [None]