# HG changeset patch
# User James Bergstra
# Date 1273465542 14400
# Node ID f732ec90e249c1a05fe7690b111ac95769f5ea93
# Parent  7305246f21f8359f2a4abca9ca6fa323b826508a
added code comments and "all" attribute to datasets.cifar10

diff -r 7305246f21f8 -r f732ec90e249 pylearn/datasets/cifar10.py
--- a/pylearn/datasets/cifar10.py	Sat Apr 17 18:33:53 2010 -0400
+++ b/pylearn/datasets/cifar10.py	Mon May 10 00:25:42 2010 -0400
@@ -7,19 +7,46 @@
 import numpy
 import cPickle
 
+import logging
+_logger = logging.getLogger('pylearn.datasets.cifar10')
+
 from pylearn.datasets.config import data_root # config
-from pylearn.datasets.dataset import Dataset
+from pylearn.datasets.dataset import Dataset # dataset.py
 
 def unpickle(file):
-    path = os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py')
-    fname = os.path.join(path, file)
-    print 'loading file %s' % fname
+    fname = os.path.join(data_root(),
+            'cifar10',
+            'cifar-10-batches-py',
+            file)
+    _logger.info('loading file %s' % fname)
     fo = open(fname, 'rb')
     dict = cPickle.load(fo)
     fo.close()
     return dict
 
-class cifar10():
+class cifar10(object):
+    """
+
+    This class gives access to meta-data of the cifar10 dataset.
+    The constructor loads it from <data_root>/cifar10/cifar-10-batches-py/,
+    where <data_root> is the pylearn data root (os.getenv('PYLEARN_DATA_ROOT')).
+
+    Attributes:
+
+    self.img_shape  - the unrasterized image shape of each row in all.x
+    self.img_size   - the number of pixels in (aka length of) each row
+    self.n_classes  - the number of labels in the dataset (10)
+
+    self.all.x    matrix - all train and test images as rasterized rows
+    self.all.y    vector - all train and test labels as integers
+    self.train.x  matrix - first ntrain rows of all.x
+    self.train.y  vector - first ntrain elements of all.y
+    self.valid.x  matrix - rows ntrain to ntrain+nvalid of all.x
+    self.valid.y  vector - elements ntrain to ntrain+nvalid of all.y
+    self.test.x   matrix - rows ntrain+nvalid to the end of all.x
+    self.test.y   vector - elements ntrain+nvalid to the end of all.y
+
+    """
     def __init__(self, dtype='uint8', ntrain=40000, nvalid=10000, ntest=10000):
 
         assert ntrain + nvalid <= 50000
@@ -44,6 +71,8 @@
             nloaded += 10000
             if nloaded >= ntrain + nvalid + ntest: break;
+
+        self.all = Dataset.Obj(x=x, y=y)
 
         self.train = Dataset.Obj(x=x[0:ntrain], y=y[0:ntrain])
         self.valid = Dataset.Obj(x=x[ntrain:ntrain+nvalid],
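
A minimal usage sketch of the new "all" attribute, assuming PYLEARN_DATA_ROOT is set and the
python-pickled CIFAR-10 batches are unpacked under <data_root>/cifar10/cifar-10-batches-py/;
attribute names and split sizes follow the class docstring and the constructor defaults above:

    # usage sketch, not part of the changeset; assumes the dataset files
    # are already downloaded and PYLEARN_DATA_ROOT is configured
    from pylearn.datasets.cifar10 import cifar10

    data = cifar10()   # default split: ntrain=40000, nvalid=10000, ntest=10000

    # all images as rasterized rows, with matching integer labels
    print data.all.x.shape, data.all.y.shape

    # train/valid/test are consecutive row slices of all.x / all.y
    print data.train.x.shape, data.valid.x.shape, data.test.x.shape
    print data.img_shape, data.img_size, data.n_classes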