view pylearn/dataset_ops/MNIST.py @ 998:8ba8b08e0442

added the image_patches dataset used in RanzatoHinton2010 modified mcRBM to use it.
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 16:51:53 -0400
parents 223caaea433a
children
line wrap: on
line source

"""Regular MNIST using the dataset protocol
"""
import os, numpy
import theano
from pylearn.datasets.config import data_root # config
from pylearn.io.ubyte import read_ubyte_matrix
from protocol import TensorFnDataset # protocol.py __init__.py
from .memo import memo

@memo
def get_train_img_u8_rasterized():
    """Load the MNIST training images as a 60000 x 784 uint8 matrix.

    The data is read from 'train-images-idx3-ubyte' in the 'mnist'
    directory under data_root().
    """
    img_path = os.path.join(data_root(), 'mnist', 'train-images-idx3-ubyte')
    # 60000 rows of 784 values; the 16 is presumably the size of the idx3
    # header that read_ubyte_matrix skips -- confirm against pylearn.io.ubyte.
    return read_ubyte_matrix(img_path, 60000, 784, 16,
                             write=False, align=True, as_dtype='uint8')
@memo
def get_test_img_u8_rasterized():
    """Load the MNIST test images as a 10000 x 784 uint8 matrix.

    The data is read from 't10k-images-idx3-ubyte' in the 'mnist'
    directory under data_root().
    """
    img_path = os.path.join(data_root(), 'mnist', 't10k-images-idx3-ubyte')
    # 10000 rows of 784 values; the 16 is presumably the size of the idx3
    # header that read_ubyte_matrix skips -- confirm against pylearn.io.ubyte.
    return read_ubyte_matrix(img_path, 10000, 784, 16,
                             write=False, align=True, as_dtype='uint8')
@memo
def get_train_labels():
    """Load the MNIST training labels as a length-60000 int32 vector.

    The labels are actually uint8, but they are requested as int32 because
    the nnet classif code is for ints.
    """
    label_path = os.path.join(data_root(), 'mnist', 'train-labels-idx1-ubyte')
    # Read as a 60000 x 1 matrix, then flatten to one label per example.
    column = read_ubyte_matrix(label_path, 60000, 1, 8,
                               write=False, align=True, as_dtype='int32')
    return column.reshape(60000)
@memo
def get_test_labels():
    """Load the MNIST test labels as a length-10000 int32 vector.

    The labels are actually uint8, but they are requested as int32 because
    the nnet classif code is for ints.
    """
    label_path = os.path.join(data_root(), 'mnist', 't10k-labels-idx1-ubyte')
    # Read as a 10000 x 1 matrix, then flatten to one label per example.
    column = read_ubyte_matrix(label_path, 10000, 1, 8,
                               write=False, align=True, as_dtype='int32')
    return column.reshape(10000)

# NOTE: the float getters below are built from the cached uint8 getters, so asking for a
# float version keeps both the uint8 array and its float copy in the cache.
# For larger datasets, it would be better to use Theano's cast(x, dtype) to do this
# conversion on the fly.
@memo
def get_train_img_f32_rasterized():
    """Return the rasterized training images divided by a float32 255."""
    pixels = get_train_img_u8_rasterized()
    divisor = numpy.asarray(255, dtype='float32')
    return pixels / divisor
@memo
def get_train_img_f64_rasterized():
    """Return the rasterized training images divided by a float64 255."""
    pixels = get_train_img_u8_rasterized()
    divisor = numpy.asarray(255, dtype='float64')
    return pixels / divisor
@memo
def get_test_img_f32_rasterized():
    """Return the rasterized test images divided by a float32 255."""
    pixels = get_test_img_u8_rasterized()
    divisor = numpy.asarray(255, dtype='float32')
    return pixels / divisor
@memo
def get_test_img_f64_rasterized():
    """Return the rasterized test images divided by a float64 255."""
    pixels = get_test_img_u8_rasterized()
    divisor = numpy.asarray(255, dtype='float64')
    return pixels / divisor

def mnist(s_idx, split, dtype='float64', rasterized=False):
    """Build (image, label) expressions for MNIST examples of a given split.

    :param s_idx: index of the requested example(s) within `split`.  It is
        combined with integer constants via `%` and `+` and handed to
        TensorFnDataset; presumably a Theano integer scalar or vector --
        TODO confirm against protocol.TensorFnDataset.

    :param split: 'train', 'valid' or 'test'.  'test' reads the test-set
        loaders with `s_idx` unchanged.  'train' and 'valid' both read the
        training-set loaders: 'train' uses `s_idx % 50000` and 'valid' uses
        `s_idx + 50000`.

    :param dtype: 'uint8', 'float32' or 'float64'; selects which image loader
        is used (the float loaders are the uint8 data divided by 255).

    :param rasterized: return examples as vectors (True) or 28x28 matrices (False)

    :returns: the pair `(x, y)` where `x` is the image expression and `y` the
        int32 label expression produced by TensorFnDataset.

    :raises ValueError: if `split` or `dtype` is not one of the values above.
    :raises AssertionError: if the image expression is neither 1- nor
        2-dimensional.
    """
    if split not in ('train', 'valid', 'test'):
        raise ValueError('split should be train, valid, or test', split)

    # 'valid' is carved out of the training files, so only 'test' uses the
    # test-set loaders.
    if split == 'test':
        l_fn = get_test_labels
        image_loaders = (
            ('uint8', get_test_img_u8_rasterized),
            ('float32', get_test_img_f32_rasterized),
            ('float64', get_test_img_f64_rasterized),
        )
    else:
        l_fn = get_train_labels
        image_loaders = (
            ('uint8', get_train_img_u8_rasterized),
            ('float32', get_train_img_f32_rasterized),
            ('float64', get_train_img_f64_rasterized),
        )

    # Compare with `==` (rather than a dict lookup) so that anything that
    # compares equal to the dtype name keeps being accepted.
    for dtype_name, loader in image_loaders:
        if dtype == dtype_name:
            i_fn = loader
            break
    else:
        raise ValueError('invalid dtype', dtype)

    if split == 'test':
        idx = s_idx
    elif split == 'train':
        idx = s_idx % 50000
    else:  # valid
        idx = s_idx + 50000

    x = TensorFnDataset(dtype, (False,), i_fn, (784,))(idx)
    y = TensorFnDataset('int32', (), l_fn)(idx)

    # Raise explicitly instead of `assert False`: a bare assert is stripped
    # under `python -O`, which would silently return an un-reshaped x.
    if x.ndim not in (1, 2):
        raise AssertionError('what happened?')

    if not rasterized:
        if x.ndim == 1:
            # single example -> one 28x28 image
            x = x.reshape((28, 28))
        else:
            # batch of examples -> stack of 28x28 images
            x = x.reshape((x.shape[0], 28, 28))

    return x, y
# Number of distinct label classes in this dataset.
nclasses = 10

def glviewer(split='train'):
    """Run a GlViewer over the uint8, non-rasterized images of `split`.

    A Theano function mapping an integer scalar index to the corresponding
    image expression from mnist() is built and handed to GlViewer, whose
    main() is then called.
    """
    from glviewer import GlViewer
    index = theano.tensor.iscalar()
    image, _label = mnist(index, split, dtype='uint8', rasterized=False)
    fetch_image = theano.function([index], image)
    viewer = GlViewer(fetch_image)
    viewer.main()