view pylearn/dataset_ops/shapeset1.py @ 931:1c62fa857cab

forcing int32 label dtype in shapeset1
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 15 Apr 2010 10:50:10 -0400
parents 0f33afbf517e
children
line wrap: on
line source

"""A Theano Op to load/access Shapeset1
"""
import theano, numpy

from .protocol import TensorFnDataset

from ..datasets.shapeset1 import head_train, head_valid, head_test

#make global functions so Op can be pickled.
# Cache for the training split: image arrays are stored under their dtype
# string, the int32 labels under the key 'lbl'.
_train_cache = {}
def train_img(dtype):
    """Return the training images converted to `dtype`, loading on first use.

    :param dtype: dtype name as a string (e.g. 'float64', 'uint8'); it is
        also the cache key.

    For integer dtypes the pixel values are multiplied by 255 before the
    cast (presumably the loader yields floats in [0, 1] -- TODO confirm).
    The labels read alongside the images are cached under 'lbl' as int32.
    """
    if dtype not in _train_cache:
        x, y = head_train()
        if dtype.startswith(('uint', 'int')):
            # Scale into a new array: multiplying in place would also modify
            # the loader's array and any cache entry aliasing it
            # (numpy.asarray does not copy when the dtype already matches).
            x = x * 255
        _train_cache[dtype] = numpy.asarray(x, dtype=dtype)
        _train_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _train_cache[dtype]
def train_lbl():
    """Return the training labels as an int32 array, loading on first use."""
    if 'lbl' not in _train_cache:
        x, y = head_train()
        # Cache x now that it's read (it isn't that big).  The key must be
        # the dtype *name*: train_img looks images up by dtype string, and a
        # numpy.dtype object hashes differently from its name, so an entry
        # keyed by x.dtype itself would never be found.
        key = str(x.dtype)
        # Integer dtypes are skipped because train_img rescales those by 255
        # before caching, so the raw array would not be equivalent.
        if not key.startswith(('uint', 'int')):
            _train_cache.setdefault(key, x)
        _train_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _train_cache['lbl']
# Cache for the validation split: image arrays are stored under their dtype
# string, the int32 labels under the key 'lbl'.
_valid_cache = {}
def valid_img(dtype):
    """Return the validation images converted to `dtype`, loading on first use.

    :param dtype: dtype name as a string (e.g. 'float64', 'uint8'); it is
        also the cache key.

    For integer dtypes the pixel values are multiplied by 255 before the
    cast (presumably the loader yields floats in [0, 1] -- TODO confirm).
    The labels read alongside the images are cached under 'lbl' as int32.
    """
    if dtype not in _valid_cache:
        x, y = head_valid()
        if dtype.startswith(('uint', 'int')):
            # Scale into a new array: multiplying in place would also modify
            # the loader's array and any cache entry aliasing it
            # (numpy.asarray does not copy when the dtype already matches).
            x = x * 255
        _valid_cache[dtype] = numpy.asarray(x, dtype=dtype)
        _valid_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _valid_cache[dtype]
def valid_lbl():
    """Return the validation labels as an int32 array, loading on first use."""
    if 'lbl' not in _valid_cache:
        x, y = head_valid()
        # Cache x now that it's read (it isn't that big).  The key must be
        # the dtype *name*: valid_img looks images up by dtype string, and a
        # numpy.dtype object hashes differently from its name, so an entry
        # keyed by x.dtype itself would never be found.
        key = str(x.dtype)
        # Integer dtypes are skipped because valid_img rescales those by 255
        # before caching, so the raw array would not be equivalent.
        if not key.startswith(('uint', 'int')):
            _valid_cache.setdefault(key, x)
        _valid_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _valid_cache['lbl']
_test_cache = {}
def test_img(dtype):
    if dtype not in _test_cache:
        x, y = head_test()
        if dtype.startswith('uint') or dtype.startswith('int'):
            x *= 255
        _test_cache[dtype] = numpy.asarray(x, dtype=dtype)
        _test_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _test_cache[dtype]
def test_lbl():
    """Return the test labels as an int32 array, loading on first use."""
    if 'lbl' not in _test_cache:
        x, y = head_test()
        # Cache x now that it's read (it isn't that big).  The key must be
        # the dtype *name*: test_img looks images up by dtype string, and a
        # numpy.dtype object hashes differently from its name, so an entry
        # keyed by x.dtype itself would never be found.
        key = str(x.dtype)
        # Integer dtypes are skipped because test_img rescales those by 255
        # before caching, so the raw array would not be equivalent.
        if not key.startswith(('uint', 'int')):
            _test_cache.setdefault(key, x)
        _test_cache['lbl'] = numpy.asarray(y, dtype='int32')
    return _test_cache['lbl']

# Maps each split name to its (image loader, label loader) pair of
# module-level functions.
_split_fns = {
    'train': (train_img, train_lbl),
    'valid': (valid_img, valid_lbl),
    'test': (test_img, test_lbl),
}

def shapeset1(s_idx, split, dtype='float64', rasterized=False):
    """Build symbolic variables for Shapeset1 examples and their labels.

    :param s_idx: symbolic index handed to the TensorFnDataset ops
        (presumably an integer scalar or vector -- TODO confirm); the image
        variable must come out with ndim 1 or 2.

    :param split: which subset to read: 'train', 'valid' or 'test'.

    :param dtype: dtype name (string) of the returned images.

    :param rasterized: return examples as vectors of length 1024 (True) or
        as 32x32 matrices (False)

    :returns: the pair (x, y) of image and int32 label variables.

    :raises KeyError: if `split` is not a known split name.
    :raises ValueError: if the image variable has an unsupported ndim.
    """

    x_fn, y_fn = _split_fns[split]

    x = TensorFnDataset(dtype=dtype, bcast=(False,), fn=(x_fn, (dtype,)),
            single_shape=(1024,))(s_idx)
    y = TensorFnDataset(dtype='int32', bcast=(), fn=y_fn)(s_idx)
    if x.ndim not in (1, 2):
        # Raised explicitly (not asserted) so the check survives `python -O`.
        raise ValueError('what happened? unexpected image ndim: %r' % (x.ndim,))
    if not rasterized:
        if x.ndim == 1:
            # A single example: one flat row becomes one 32x32 image.
            x = x.reshape((32, 32))
        else:
            # A batch of examples: keep the leading axis.
            x = x.reshape((x.shape[0], 32, 32))
    return x, y
# Number of label classes (presumably the labels take values in
# range(nclasses) -- TODO confirm against the dataset).
nclasses = 10

def glviewer(split='train'):
    """Run a GlViewer on a compiled function mapping an index to an image.

    :param split: name of the split whose uint8, non-rasterized images are
        served to the viewer.
    """
    # Imported here rather than at module level, so glviewer is only
    # required when this function is actually called.
    from glviewer import GlViewer
    idx = theano.tensor.iscalar()
    images, _labels = shapeset1(idx, split, dtype='uint8', rasterized=False)
    fetch_image = theano.function([idx], images)
    GlViewer(fetch_image).main()