Mercurial > pylearn
view pylearn/datasets/icml07.py @ 1479:1b69d435f09f
fix error string.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Wed, 25 May 2011 09:26:47 -0400 |
parents | 1e4dc99a3b13 |
children |
line wrap: on
line source
""" Functions related to the datasets used in Larochelle et al. 2007 (incl. modified MNIST). """ import os, sys import numpy from config import get_filepath_in_roots from pylearn.io.amat import AMat from pylearn.datasets.config import data_root # config from pylearn.datasets.dataset import Dataset class MNIST_rotated_background(object): def __init__(self, n_train=10000, n_valid=2000, n_test=50000): basedir = os.path.join(data_root(), 'icml07data', 'npy') x_all = numpy.load(os.path.join(basedir, 'mnist_rotated_background_images_inputs.npy')) y_all = numpy.load(os.path.join(basedir, 'mnist_rotated_background_images_labels.npy')) vstart = n_train tstart = n_train + n_valid self.train = Dataset.Obj(x=x_all[:n_train], y=y_all[:n_train]) self.valid = Dataset.Obj(x=x_all[vstart:vstart+n_valid], y=y_all[vstart:vstart+n_valid]) self.test = Dataset.Obj(x=x_all[tstart:tstart+n_test], y=y_all[tstart:tstart+n_test]) self.n_classes = 10 self.img_shape = (28,28) class DatasetLoader(object): """ A class for loading an ICML07 dataset into memory. The class has functionality to - download the dataset from the internet (in amat format) - convert the dataset from amat format to npy format - load the dataset from either amat or npy source files """ def __init__(self, http_source, n_inputs, n_classes, n_train, n_valid, n_test, npy_filename_root, amat_filename_root=None, amat_filename_train=None, amat_filename_test=None, amat_filename_all=None, ): self.__dict__.update(locals()) del self.__dict__['self'] def download(self, todir): #TODO: write a system call to wget to dl the file from self.http_source raise NotImplementedError() def load_from_amat(self): if self.amat_filename_all is not None: amat_all = AMat(self.amat_filename_all) allmat = amat_all.all assert allmat.shape[0] == self.n_train + self.n_valid + self.n_test, allmat.shape else: if self.amat_filename_root is not None: amat_train = AMat(self.amat_filename_root+'_train.amat') amat_test = AMat(self.amat_filename_root+'_test.amat') else: amat_train = AMat(self.amat_filename_train) amat_test = AMat(self.amat_filename_test) assert amat_train.all.shape[0] == self.n_train + self.n_valid assert amat_test.all.shape[0] == self.n_test allmat = numpy.vstack((amat_train.all, amat_test.all)) # CHECKPOINT: allmat has been computed by this point. assert allmat.shape[1] == self.n_inputs+1 inputs = allmat[:, :self.n_inputs].astype('float32') labels = allmat[:, self.n_inputs].astype('int8') assert numpy.allclose(labels, allmat[:, self.n_inputs]) assert numpy.all(labels < self.n_classes) return inputs, labels def load_from_amat_save_to_numpy(self): inputs, labels = self.load_from_amat() numpy.save(self.npy_filename_root+'_inputs.npy', inputs) numpy.save(self.npy_filename_root+'_labels.npy', labels) return inputs, labels def load_from_numpy(self, mmap_mode='r'): """Much faster than load_from_amat""" inputs = numpy.load(self.npy_filename_root+'_inputs.npy', mmap_mode=mmap_mode) labels = numpy.load(self.npy_filename_root+'_labels.npy', mmap_mode=mmap_mode) assert inputs.shape == (self.n_train + self.n_valid + self.n_test, self.n_inputs) assert labels.shape[0] == inputs.shape[0] assert numpy.all(labels < self.n_classes) return inputs, labels def icml07_loaders(new_version=True, rootdir=None): if rootdir is None: rootdir = get_filepath_in_roots('icml07data_twiki') if rootdir is None: raise IOError('dataset not found (no icml07data_twiki folder in PYLEARN_DATA_ROOT or DBPATH environment variable.') rval = dict( mnist_basic=DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist.zip', amat_filename_root=os.path.join(rootdir, 'mnist'), npy_filename_root=os.path.join(rootdir, 'mnist_basic'), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=50000 ), mnist_background_images=DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_background_images.zip', amat_filename_root=os.path.join(rootdir, 'mnist_background_images'), npy_filename_root=os.path.join(rootdir, 'mnist_background_images'), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=50000 ), mnist_background_random=DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_background_random.zip', amat_filename_root=os.path.join(rootdir, 'mnist_background_random'), npy_filename_root=os.path.join(rootdir, 'mnist_background_random'), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=50000 ), rectangles=DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/rectangles.zip', amat_filename_root=os.path.join(rootdir, 'rectangles'), npy_filename_root=os.path.join(rootdir, 'rectangles'), n_inputs=784, n_classes=10, n_train=1000, n_valid=200, n_test=50000 ), rectangles_images=DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/rectangles_images.zip', amat_filename_root=os.path.join(rootdir, 'rectangles_im'), npy_filename_root=os.path.join(rootdir, 'rectangles_images'), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=50000 ), convex=DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/convex.zip', amat_filename_root=os.path.join(rootdir, 'convex'), npy_filename_root=os.path.join(rootdir, 'convex'), n_inputs=784, n_classes=10, n_train=6500, #not sure about this train/valid split n_valid=1500, n_test=50000 ), ) for level in range(1,7): rval['mnist_noise_%i'%level] = DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_noise_variation.tar.gz', amat_filename_all=os.path.join(rootdir, 'mnist_noise_variations_all_%i.amat'%level), npy_filename_root=os.path.join(rootdir, 'mnist_noise_%i'%level), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=2000 ) if new_version: rval['mnist_rotated'] = DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip', amat_filename_test=os.path.join(rootdir, 'mnist_all_rotation_normalized_float_test.amat'), amat_filename_train=os.path.join(rootdir, 'mnist_all_rotation_normalized_float_train_valid.amat'), npy_filename_root=os.path.join(rootdir, 'mnist_rotated'), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=50000 ) rval['mnist_rotated_background_images'] = DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_back_image_new.zip', amat_filename_test=os.path.join(rootdir, 'mnist_all_background_images_rotation_normalized_test.amat'), amat_filename_train=os.path.join(rootdir, 'mnist_all_background_images_rotation_normalized_train_valid.amat'), npy_filename_root=os.path.join(rootdir, 'mnist_rotated_background_images'), n_inputs=784, n_classes=10, n_train=10000, n_valid=2000, n_test=50000 ) else: raise NotImplementedError('TODO: what are the amat_filenames here') rval['mnist_rotated'] = DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation.zip') rval['mnist_rotated_background_images'] = DatasetLoader( http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_back_image.zip') return rval