changeset 936:f732ec90e249

added code comments and "all" attribute to datasets.cifar10
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 10 May 2010 00:25:42 -0400
parents 7305246f21f8
children 594f4fda4829 daa355332b66
files pylearn/datasets/cifar10.py
diffstat 1 files changed, 34 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/cifar10.py	Sat Apr 17 18:33:53 2010 -0400
+++ b/pylearn/datasets/cifar10.py	Mon May 10 00:25:42 2010 -0400
@@ -7,19 +7,46 @@
 import numpy
 import cPickle
 
+import logging
+_logger = logging.getLogger('pylearn.datasets.cifar10')
+
 from pylearn.datasets.config import data_root # config
-from pylearn.datasets.dataset import Dataset
+from pylearn.datasets.dataset import Dataset # dataset.py
 
 def unpickle(file):
-    path = os.path.join(data_root(), 'cifar10', 'cifar-10-batches-py')
-    fname = os.path.join(path, file)
-    print 'loading file %s' % fname
+    fname = os.path.join(data_root(),
+            'cifar10', 
+            'cifar-10-batches-py',
+            file)
+    _logger.info('loading file %s' % fname)
     fo = open(fname, 'rb')
     dict = cPickle.load(fo)
     fo.close()
     return dict
 
-class cifar10():
+class cifar10(object):
+    """
+
+    This class gives access to meta-data of cifar10 dataset.
+    The constructor loads it from <data>/cifar10/cifar-10-batches-py/
+    where <data> is the pylearn data root (os.getenv('PYLEARN_DATA_ROOT')).
+
+    Attributes:
+
+    self.img_shape - the unrasterized image shape of each row in all.x
+    self.img_size - the number of pixels in (aka length of) each row
+    self.n_classes - the number of labels in the dataset (10)
+
+    self.all.x    matrix - all train and test images as rasterized rows
+    self.all.y    vector - all train and test labels as integers
+    self.train.x  matrix - first ntrain rows of all.x
+    self.train.y  matrix - first ntrain elements of all.y
+    self.valid.x  matrix - rows ntrain to ntrain+nvalid of all.x
+    self.valid.y  vector - elements ntrain to ntrain+nvalid of all.y
+    self.test.x   matrix - rows ntrain+valid to end of all.x
+    self.test.y   vector - elements ntrain+valid to end of all.y
+
+    """
 
     def __init__(self, dtype='uint8', ntrain=40000, nvalid=10000, ntest=10000):
         assert ntrain + nvalid <= 50000
@@ -44,6 +71,8 @@
 
             nloaded += 10000
             if nloaded >= ntrain + nvalid + ntest: break;
+
+        self.all = Dataset.Obj(x=x, y=y)
         
         self.train = Dataset.Obj(x=x[0:ntrain], y=y[0:ntrain])
         self.valid = Dataset.Obj(x=x[ntrain:ntrain+nvalid],