Mercurial > pylearn
view pylearn/datasets/cifar10.py @ 1479:1b69d435f09f
fix error string.
author | Frederic Bastien <nouiz@nouiz.org> |
---|---|
date | Wed, 25 May 2011 09:26:47 -0400 |
parents | 5ae77ac21609 |
children |
line wrap: on
line source
""" Various routines to load/access CIFAR-10 data. """
from __future__ import absolute_import

import os
import numpy
import cPickle
import logging

_logger = logging.getLogger('pylearn.datasets.cifar10')

from pylearn.datasets.config import data_root  # config
from pylearn.datasets.dataset import Dataset  # dataset.py

# Number of examples stored in each 'data_batch_<i>' file.
_BATCH_SIZE = 10000
# Number of 'data_batch_<i>' files (i = 1..5) holding train/valid data.
_N_TRAIN_BATCHES = 5


def unpickle(file):
    """Load and return the object pickled in one CIFAR-10 batch file.

    :param file: name of a file inside
        <data_root>/cifar10/cifar-10-batches-py/ (e.g. 'data_batch_1' or
        'test_batch').
    :returns: the unpickled object; this module indexes it with the keys
        'data' and 'labels'.
    """
    fname = os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py', file)
    _logger.info('loading file %s', fname)
    # NOTE: unpickling can execute arbitrary code; only point data_root()
    # at trusted copies of the dataset.
    fo = open(fname, 'rb')
    try:
        batch = cPickle.load(fo)
    finally:
        # Release the file handle even if unpickling fails.
        fo.close()
    return batch


class cifar10(object):
    """
    This class gives access to meta-data of cifar10 dataset.

    The constructor loads it from <data>/cifar10/cifar-10-batches-py/
    where <data> is the pylearn data root (os.getenv('PYLEARN_DATA_ROOT')).

    Attributes:
    self.img_shape - the unrasterized image shape (3, 32, 32) of each row
        in all.x
    self.img_size - the number of values (3*32*32) in each rasterized row
    self.n_classes - the number of labels in the dataset (10)

    self.all.x matrix - all rows read from the 'data_batch_<i>' files.
        Only whole batches of 10000 rows are read, so this has
        ceil((ntrain+nvalid)/10000)*10000 rows, possibly more than
        ntrain+nvalid. Test images are NOT included.
    self.all.y vector - labels matching the rows of all.x
    self.train.x matrix - first ntrain rows of all.x
    self.train.y vector - first ntrain elements of all.y
    self.valid.x matrix - rows ntrain to ntrain+nvalid of all.x
    self.valid.y vector - elements ntrain to ntrain+nvalid of all.y
    self.test.x matrix - first ntest rows of the separate 'test_batch' file
    self.test.y - first ntest labels of 'test_batch', as stored in the file

    Note: train/valid data (images and labels) are cast to `dtype`; test
    data are returned as stored in the pickle, without conversion.
    """

    def __init__(self, dtype='uint8', ntrain=40000, nvalid=10000, ntest=10000):
        """Load the requested slices of CIFAR-10.

        :param dtype: dtype of the train/valid arrays (images and labels).
        :param ntrain: number of training examples, taken from the start of
            the training batches.
        :param nvalid: number of validation examples, taken right after the
            training ones. ntrain + nvalid must not exceed 50000.
        :param ntest: number of test examples (at most 10000), taken from
            the start of 'test_batch'.
        """
        assert ntrain + nvalid <= _N_TRAIN_BATCHES * _BATCH_SIZE, \
            (ntrain, nvalid)
        assert ntest <= 10000, ntest

        self.img_shape = (3, 32, 32)
        self.img_size = numpy.prod(self.img_shape)
        self.n_classes = 10

        # Batch files can only be read whole, so read just enough of them to
        # cover ntrain + nvalid rows. ntest is deliberately not counted here:
        # test examples live in the separate 'test_batch' file.
        # int() matters: numpy.zeros rejects non-integer shapes.
        nbatches = int(numpy.ceil((ntrain + nvalid) / float(_BATCH_SIZE)))
        lenx = nbatches * _BATCH_SIZE
        x = numpy.zeros((lenx, self.img_size), dtype=dtype)
        y = numpy.zeros(lenx, dtype=dtype)

        # load train and validation data
        for i in range(nbatches):
            data = unpickle('data_batch_%i' % (i + 1))
            start = i * _BATCH_SIZE
            stop = start + _BATCH_SIZE
            x[start:stop, :] = data['data']
            y[start:stop] = data['labels']

        self.all = Dataset.Obj(x=x, y=y)
        self.train = Dataset.Obj(x=x[0:ntrain], y=y[0:ntrain])
        self.valid = Dataset.Obj(x=x[ntrain:ntrain + nvalid],
                                 y=y[ntrain:ntrain + nvalid])

        # load test data
        data = unpickle('test_batch')
        self.test = Dataset.Obj(x=data['data'][0:ntest],
                                y=data['labels'][0:ntest])

    def preprocess(self, x):
        """Return `x` divided by 255.0, converted to float64.

        For uint8 pixel data this maps values into [0, 1].
        """
        return numpy.float64(x * 1.0 / 255.0)


def first_1k(dtype='uint8', ntrain=1000, nvalid=200, ntest=200):
    """Return a small `cifar10` instance (by default 1000 train,
    200 valid, 200 test examples).

    With the default sizes only the first training batch file is read.
    """
    return cifar10(dtype, ntrain, nvalid, ntest)


def tile_rasterized_examples(X, img_shape=(32, 32)):
    """Returns an ndarray that is ready to be passed to
    `image_tiling.save_tiled_raster_images`

    This function is for the `x` matrices in the cifar dataset, or for the
    weight matrices (filters) used to multiply them.

    :param X: 2-d array; each row holds three consecutive planes (r, g, b)
        of img_shape[0]*img_shape[1] values each.
    :param img_shape: (rows, cols) of a single plane.
    :returns: the result of `tile_raster_images` on the (r, g, b) planes,
        computed in float32.
    """
    ndim = img_shape[0] * img_shape[1]
    assert ndim * 3 == X.shape[1], (ndim, X.shape)
    X = X.astype('float32')
    r = X[:, :ndim]
    g = X[:, ndim:ndim * 2]
    b = X[:, ndim * 2:]
    # Function-level import kept as in the original; the 4th tuple entry
    # (None) is presumably an absent alpha channel -- see tile_raster_images.
    from pylearn.io.image_tiling import tile_raster_images
    rval = tile_raster_images((r, g, b, None), img_shape=img_shape)
    return rval