Mercurial > pylearn
view pylearn/dataset_ops/MNIST.py @ 873:223caaea433a
cut commented code
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 16 Nov 2009 13:20:28 -0500 |
parents | 77e6b2d3e5e5 |
children |
line wrap: on
line source
"""Regular MNIST using the dataset protocol """ import os, numpy import theano from pylearn.datasets.config import data_root # config from pylearn.io.ubyte import read_ubyte_matrix from protocol import TensorFnDataset # protocol.py __init__.py from .memo import memo @memo def get_train_img_u8_rasterized(): """Returns 60000 x 784 MNIST train set""" return read_ubyte_matrix( os.path.join(data_root(), 'mnist', 'train-images-idx3-ubyte'), 60000, 784, 16, write=False, align=True, as_dtype='uint8') @memo def get_test_img_u8_rasterized(): """Returns 10000 x 784 MNIST test set""" return read_ubyte_matrix( os.path.join(data_root(), 'mnist', 't10k-images-idx3-ubyte'), 10000, 784, 16, write=False, align=True, as_dtype='uint8') @memo def get_train_labels(): # these are actually uint8, but the nnet classif code is for ints. return read_ubyte_matrix( os.path.join(data_root(), 'mnist', 'train-labels-idx1-ubyte'), 60000, 1, 8, write=False, align=True, as_dtype='int32').reshape(60000) @memo def get_test_labels(): # these are actually uint8, but the nnet classif code is for ints. return read_ubyte_matrix( os.path.join(data_root(), 'mnist', 't10k-labels-idx1-ubyte'), 10000, 1, 8, write=False, align=True, as_dtype='int32').reshape(10000) #This will cause both the uint8 version and the float version of the dataset to be cached. # For larger datasets, it would be better to use Theano's cast(x, dtype) to do this conversion # on the fly. 
@memo
def get_train_img_f32_rasterized():
    """Returns the uint8 train images divided by a float32 255"""
    return get_train_img_u8_rasterized() / numpy.asarray(255, dtype='float32')


@memo
def get_train_img_f64_rasterized():
    """Returns the uint8 train images divided by a float64 255"""
    return get_train_img_u8_rasterized() / numpy.asarray(255, dtype='float64')


@memo
def get_test_img_f32_rasterized():
    """Returns the uint8 test images divided by a float32 255"""
    return get_test_img_u8_rasterized() / numpy.asarray(255, dtype='float32')


@memo
def get_test_img_f64_rasterized():
    """Returns the uint8 test images divided by a float64 255"""
    return get_test_img_u8_rasterized() / numpy.asarray(255, dtype='float64')


def mnist(s_idx, split, dtype='float64', rasterized=False):
    """Build the (image, label) pair for example(s) `s_idx` of an MNIST split.

    :param s_idx: index into the requested split.  It is combined with
        integers via `%` and `+` and then handed to TensorFnDataset;
        `glviewer` passes a `theano.tensor.iscalar()`.  The result may have
        ndim 1 or 2 (presumably scalar vs. vector index -- confirm against
        TensorFnDataset).

    :param split: one of 'train', 'valid' or 'test'.  'test' reads the test
        files at `s_idx`.  'train' and 'valid' both read the train files:
        'train' at `s_idx % 50000` (so indices wrap), 'valid' at
        `s_idx + 50000` (no wrapping is applied).

    :param dtype: 'uint8', 'float32' or 'float64'; selects which cached image
        array is used ('float*' are the uint8 pixels divided by 255).

    :param rasterized: return examples as vectors (True) or 28x28 matrices (False)

    :returns: the pair `(x, y)` of image and label expressions.

    :raises ValueError: if `split` or `dtype` is not one of the values above.
    :raises AssertionError: if the image expression has an ndim other than
        1 or 2.
    """
    if split not in ('train', 'valid', 'test'):
        raise ValueError('split should be train, valid, or test', split)

    # Pick the label loader and the three image loaders for the underlying
    # file set; 'valid' lives in the train files.
    if split == 'test':
        l_fn = get_test_labels
        u8_fn = get_test_img_u8_rasterized
        f32_fn = get_test_img_f32_rasterized
        f64_fn = get_test_img_f64_rasterized
    else:
        l_fn = get_train_labels
        u8_fn = get_train_img_u8_rasterized
        f32_fn = get_train_img_f32_rasterized
        f64_fn = get_train_img_f64_rasterized

    # Compare with == (not a dict lookup) so that dtype-like objects which
    # compare equal to these strings keep working.
    if dtype == 'uint8':
        i_fn = u8_fn
    elif dtype == 'float32':
        i_fn = f32_fn
    elif dtype == 'float64':
        i_fn = f64_fn
    else:
        raise ValueError('invalid dtype', dtype)

    if split == 'test':
        idx = s_idx
    elif split == 'train':
        idx = s_idx % 50000
    else:  # valid
        idx = s_idx + 50000

    x = TensorFnDataset(dtype, (False,), i_fn, (784,))(idx)
    y = TensorFnDataset('int32', (), l_fn)(idx)

    if x.ndim == 1:
        if not rasterized:
            x = x.reshape((28, 28))
    elif x.ndim == 2:
        if not rasterized:
            x = x.reshape((x.shape[0], 28, 28))
    else:
        # Raised explicitly rather than `assert False`, so the guard is not
        # stripped when Python runs with -O.
        raise AssertionError('what happened?')
    return x, y


nclasses = 10


def glviewer(split='train'):
    """Run a GlViewer over the uint8, 28x28 images of `split`."""
    from glviewer import GlViewer
    i = theano.tensor.iscalar()
    f = theano.function([i], mnist(i, split, dtype='uint8', rasterized=False)[0])
    GlViewer(f).main()