Mercurial > pylearn
comparison pylearn/datasets/icml07.py @ 1471:281efa9a4463
icml07_loaders uses get_filepath_in_roots
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 18 May 2011 10:51:11 -0400 |
parents | ba8a32b71356 |
children | e7401822d596 |
comparison
equal
deleted
inserted
replaced
1470:94268a161925 | 1471:281efa9a4463 |
---|---|
1 """ Functions related to the datasets used in Larochelle et al. 2007 (incl. modified MNIST). | 1 """ Functions related to the datasets used in Larochelle et al. 2007 (incl. modified MNIST). |
2 """ | 2 """ |
3 import os, sys | 3 import os, sys |
4 import numpy | 4 import numpy |
5 | 5 |
6 from config import get_filepath_in_roots | |
6 from pylearn.io.amat import AMat | 7 from pylearn.io.amat import AMat |
7 | 8 |
8 class DatasetLoader(object): | 9 class DatasetLoader(object): |
9 """ | 10 """ |
10 A class for loading an ICML07 dataset into memory. | 11 A class for loading an ICML07 dataset into memory. |
66 assert inputs.shape == (self.n_train + self.n_valid + self.n_test, self.n_inputs) | 67 assert inputs.shape == (self.n_train + self.n_valid + self.n_test, self.n_inputs) |
67 assert labels.shape[0] == inputs.shape[0] | 68 assert labels.shape[0] == inputs.shape[0] |
68 assert numpy.all(labels < self.n_classes) | 69 assert numpy.all(labels < self.n_classes) |
69 return inputs, labels | 70 return inputs, labels |
70 | 71 |
71 def icml07_loaders(new_version=True, rootdir='.'): | 72 def icml07_loaders(new_version=True, rootdir=None): |
73 if rootdir is None: | |
74 rootdir = get_filepath_in_roots('icml07data_twiki') | |
72 rval = dict( | 75 rval = dict( |
73 mnist_basic=DatasetLoader( | 76 mnist_basic=DatasetLoader( |
74 http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist.zip', | 77 http_source='http://www.iro.umontreal.ca/~lisa/icml2007data/mnist.zip', |
75 amat_filename_root=os.path.join(rootdir, 'mnist'), | 78 amat_filename_root=os.path.join(rootdir, 'mnist'), |
76 npy_filename_root=os.path.join(rootdir, 'mnist_basic'), | 79 npy_filename_root=os.path.join(rootdir, 'mnist_basic'), |