Mercurial > pylearn
changeset 936:f732ec90e249
added code comments and "all" attribute to datasets.cifar10
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Mon, 10 May 2010 00:25:42 -0400 |
parents | 7305246f21f8 |
children | 594f4fda4829 daa355332b66 |
files | pylearn/datasets/cifar10.py |
diffstat | 1 files changed, 34 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/cifar10.py Sat Apr 17 18:33:53 2010 -0400 +++ b/pylearn/datasets/cifar10.py Mon May 10 00:25:42 2010 -0400 @@ -7,19 +7,46 @@ import numpy import cPickle +import logging +_logger = logging.getLogger('pylearn.datasets.cifar10') + from pylearn.datasets.config import data_root # config -from pylearn.datasets.dataset import Dataset +from pylearn.datasets.dataset import Dataset # dataset.py def unpickle(file): - path = os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py') - fname = os.path.join(path, file) - print 'loading file %s' % fname + fname = os.path.join(data_root(), + 'cifar10', + 'cifar-10-batches-py', + file) + _logger.info('loading file %s' % fname) fo = open(fname, 'rb') dict = cPickle.load(fo) fo.close() return dict -class cifar10(): +class cifar10(object): + """ + + This class gives access to meta-data of cifar10 dataset. + The constructor loads it from <data>/cifar10/cifar-10-batches-py/ + where <data> is the pylearn data root (os.getenv('PYLEARN_DATA_ROOT')). + + Attributes: + + self.img_shape - the unrasterized image shape of each row in all.x + self.img_size - the number of pixels in (aka length of) each row + self.n_classes - the number of labels in the dataset (10) + + self.all.x matrix - all train and test images as rasterized rows + self.all.y vector - all train and test labels as integers + self.train.x matrix - first ntrain rows of all.x + self.train.y matrix - first ntrain elements of all.y + self.valid.x matrix - rows ntrain to ntrain+nvalid of all.x + self.valid.y vector - elements ntrain to ntrain+nvalid of all.y + self.test.x matrix - rows ntrain+valid to end of all.x + self.test.y vector - elements ntrain+valid to end of all.y + + """ def __init__(self, dtype='uint8', ntrain=40000, nvalid=10000, ntest=10000): assert ntrain + nvalid <= 50000 @@ -44,6 +71,8 @@ nloaded += 10000 if nloaded >= ntrain + nvalid + ntest: break; + + self.all = Dataset.Obj(x=x, y=y) self.train = Dataset.Obj(x=x[0:ntrain], y=y[0:ntrain]) self.valid = Dataset.Obj(x=x[ntrain:ntrain+nvalid],