changeset 1353:2024c5618466

adding icml07 dataset
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 03 Nov 2010 12:49:24 -0400
parents cc3e3e596500
children be3030305d4b
files pylearn/datasets/icml07.py
diffstat 1 files changed, 172 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/datasets/icml07.py	Wed Nov 03 12:49:24 2010 -0400
@@ -0,0 +1,172 @@
+""" Functions related to the datasets used in Larochelle et al. 2007 (incl. modified MNIST). 
+"""
+import os, sys
+import numpy
+
+from pylearn.io.amat import AMat
+
+class DatasetLoader(object):
+    def __init__(self, http_source, 
+            n_inputs, n_classes,
+            n_train, n_valid, n_test,
+            npy_filename_root, 
+            amat_filename_root=None, 
+            amat_filename_train=None,
+            amat_filename_test=None,
+            amat_filename_all=None,
+            ):
+        self.__dict__.update(locals())
+        del self.__dict__['self']
+
+    def download(self, todir):
+        raise NotImplementedError()
+
+    def load_from_amat(self):
+        if self.amat_filename_all is not None:
+            raise NotImplementedError()
+        else:
+            if self.amat_filename_root is not None:
+                amat_train = AMat(self.amat_filename_root+'_train.amat')
+                amat_test = AMat(self.amat_filename_root+'_test.amat')
+            else:
+                amat_train = AMat(self.amat_filename_train)
+                amat_test = AMat(self.amat_filename_test)
+            assert amat_train.all.shape[0] == self.n_train + self.n_valid
+            assert amat_test.all.shape[0] == self.n_test
+            allmat = numpy.vstack((amat_train.all, amat_test.all))
+        # CHECKPOINT: allmat has been computed by this point.
+        assert allmat.shape[1] == self.n_inputs+1
+        inputs = allmat[:, :self.n_inputs].astype('float32')
+        labels = allmat[:, self.n_inputs].astype('int8')
+        assert numpy.allclose(labels, allmat[:, self.n_inputs])
+        assert numpy.all(labels < self.n_classes)
+        return inputs, labels
+
+    def load_from_amat_save_to_numpy(self):
+        inputs, labels = self.load_from_amat()
+        numpy.save(self.npy_filename_root+'_inputs.npy', inputs)
+        numpy.save(self.npy_filename_root+'_labels.npy', labels)
+        return inputs, labels
+
+    def load_from_numpy(self, mmap_mode='r'):
+        """Much faster than load_from_amat"""
+        inputs = numpy.load(self.npy_filename_root+'_inputs.npy', mmap_mode=mmap_mode)
+        labels = numpy.load(self.npy_filename_root+'_labels.npy', mmap_mode=mmap_mode)
+        assert inputs.shape == (self.n_train + self.n_valid + self.n_test, self.n_inputs)
+        assert labels.shape[0] == inputs.shape[0]
+        assert numpy.all(labels < self.n_classes)
+        return inputs, labels
+
+
+def icml07_loaders(new_version=True, rootdir='.'):
+    rval = dict(
+        mnist_basic=DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist.zip',
+            amat_filename_root=os.path.join(rootdir, 'mnist'),
+            npy_filename_root=os.path.join(rootdir, 'mnist_basic'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            ),
+        mnist_background_images=DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_background_images.zip',
+            amat_filename_root=os.path.join(rootdir, 'mnist_background_images'),
+            npy_filename_root=os.path.join(rootdir, 'mnist_background_images'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            ),
+        mnist_background_random=DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_background_random.zip',
+            amat_filename_root=os.path.join(rootdir, 'mnist_background_random'),
+            npy_filename_root=os.path.join(rootdir, 'mnist_background_random'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            ),
+        rectangles=DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/rectangles.zip',
+            amat_filename_root=os.path.join(rootdir, 'rectangles'),
+            npy_filename_root=os.path.join(rootdir, 'rectangles'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=1000,
+            n_valid=200,
+            n_test=50000
+            ),
+        rectangles_images=DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/rectangles_images.zip',
+            amat_filename_root=os.path.join(rootdir, 'rectangles_im'),
+            npy_filename_root=os.path.join(rootdir, 'rectangles_images'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            ),
+        convex=DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/convex.zip',
+            amat_filename_root=os.path.join(rootdir, 'convex'),
+            npy_filename_root=os.path.join(rootdir, 'convex'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=6500, #not sure about this train/valid split
+            n_valid=1500,
+            n_test=50000
+            ),
+        )
+    for level in range(1,6):
+        rval['mnist_noise_%i'%level] = DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_noise_variation.tar.gz',
+            amat_filename_all=os.path.join(rootdir,
+                'mnist_noise_variations_all_%i.amat'%level),
+            npy_filename_root=os.path.join(rootdir, 'mnist_noise_%i'%level),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            )
+
+    if new_version:
+        rval['mnist_rotated'] = DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_new.zip',
+            amat_filename_test=os.path.join(rootdir,
+                'mnist_all_rotation_normalized_float_test.amat'),
+            amat_filename_train=os.path.join(rootdir,
+                'mnist_all_rotation_normalized_float_train_valid.amat'),
+            npy_filename_root=os.path.join(rootdir, 'mnist_rotated'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            )
+        rval['mnist_rotated_background_images'] = DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_back_image_new.zip',
+            amat_filename_test=os.path.join(rootdir,
+                'mnist_all_background_images_rotation_normalized_test.amat'),
+            amat_filename_train=os.path.join(rootdir,
+                'mnist_all_background_images_rotation_normalized_train_valid.amat'),
+            npy_filename_root=os.path.join(rootdir, 'mnist_rotated_background_images'),
+            n_inputs=784,
+            n_classes=10,
+            n_train=10000,
+            n_valid=2000,
+            n_test=50000
+            )
+    else:
+        raise NotImplementedError('TODO: what are the amat_filenames here')
+        rval['mnist_rotated'] = DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation.zip')
+        rval['mnist_rotated_background_images'] = DatasetLoader(
+            http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist_rotation_back_image.zip')
+    return rval
+
+