Mercurial > pylearn
comparison pylearn/datasets/MNIST.py @ 537:b054271b2504
new file structure layout, factories, etc.
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 12 Nov 2008 21:57:54 -0500 |
parents | datasets/MNIST.py@58810b63292b |
children | 16f91ca016b1 |
comparison
equal
deleted
inserted
replaced
518:4aa7f74ea93f | 537:b054271b2504 |
---|---|
1 """ | |
2 Various routines to load/access MNIST data. | |
3 """ | |
4 from __future__ import absolute_import | |
5 | |
6 import os | |
7 import numpy | |
8 | |
9 from ..io.amat import AMat | |
10 from .config import data_root | |
11 from .dataset import dataset_factory, Dataset | |
12 | |
def head(n=10, path=None):
    """Load the first `n` MNIST examples.

    Returns two matrices: x, y. x has N rows of 784 columns. Each row of x
    represents the 28x28 grey-scale pixels in raster order. y is a vector of
    N int64 integers. Each element y[i] is the label of the i'th row of x.

    :param n: number of examples to read; forwarded to AMat as `head`.
        None means "no limit" (this is what all() passes) -- presumably AMat
        then reads the whole file; confirm against AMat.  When n is None the
        row count is not compared against n; x and y must still agree.
    :param path: .amat file to read.  Defaults to
        os.path.join(data_root(), 'mnist', 'mnist_with_header.amat').
    :raises Exception: ("failed to read MNIST data", (dat, err)) when the
        loaded object has no usable input/target arrays, or when the number
        of rows read does not match what was requested.
    """
    if path is None:
        path = os.path.join(data_root(), 'mnist', 'mnist_with_header.amat')

    dat = AMat(path=path, head=n)

    # Explicit checks instead of `assert`: asserts are stripped under
    # `python -O`, which would silently disable this validation.
    try:
        n_x = dat.input.shape[0]
        n_y = dat.target.shape[0]
    except Exception as e:
        # e.g. input/target missing or not array-like after a failed parse
        raise Exception("failed to read MNIST data", (dat, e))

    # With n=None there is no requested count to compare against (the old
    # `shape[0] == None` test always failed, breaking all()); only require
    # that inputs and targets line up.
    expected = n_x if n is None else n
    if n_x != expected or n_y != expected:
        raise Exception("failed to read MNIST data",
                        (dat, ValueError("expected %s rows, got %s inputs and %s targets"
                                         % (expected, n_x, n_y))))

    # Flatten the targets into a 1-D int64 vector of length n_y.
    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(n_y)
32 | |
def all(path=None):
    """Load every MNIST example stored in `path`.

    Thin wrapper: delegates to head() with no row limit (n=None).  See head()
    for the returned (x, y) pair and for the default path.

    NOTE: this module-level name shadows the builtin all() within this module.
    """
    return head(path=path, n=None)
35 | |
def train_valid_test(ntrain=50000, nvalid=10000, ntest=10000, path=None):
    """Split the leading MNIST examples into train / valid / test subsets.

    The first ntrain + nvalid + ntest rows are loaded with head(), then cut
    into three consecutive, non-overlapping ranges: rows [0, ntrain) form
    `train`, the next `nvalid` rows form `valid`, and the following `ntest`
    rows form `test`.

    Returns a Dataset with attributes `train`, `valid` and `test` (each a
    Dataset.Obj with fields `x` and `y`) and `n_classes` set to 10.
    """
    train_end = ntrain
    valid_end = train_end + nvalid
    test_end = valid_end + ntest

    inputs, labels = head(test_end, path=path)

    splits = Dataset()
    bounds = (('train', 0, train_end),
              ('valid', train_end, valid_end),
              ('test', valid_end, test_end))
    for split_name, lo, hi in bounds:
        setattr(splits, split_name, Dataset.Obj(x=inputs[lo:hi], y=labels[lo:hi]))

    splits.n_classes = 10
    return splits
50 | |
51 | |
52 | |
@dataset_factory('MNIST')
def mnist_factory(variant="", ntrain=None, nvalid=None, ntest=None):
    """Build an MNIST train/valid/test Dataset for the named variant.

    Variants:
      ""        train_valid_test() with its default split sizes
      "1k"      1000 train / 200 valid / 200 test
      "10k"     10000 train / 2000 valid / 2000 test
      "custom"  sizes taken from ntrain / nvalid / ntest; any size left as
                None falls back to the corresponding train_valid_test default

    ntrain, nvalid and ntest are only read for the "custom" variant.

    :raises Exception: ('Unknown MNIST variant', variant) for any other variant.
    """
    if variant == "":
        return train_valid_test()
    elif variant == "1k":
        return train_valid_test(ntrain=1000, nvalid=200, ntest=200)
    elif variant == "10k":
        return train_valid_test(ntrain=10000, nvalid=2000, ntest=2000)
    elif variant == "custom":
        # Forward only the sizes the caller actually supplied.  Passing None
        # through made train_valid_test fail with an opaque `None + int`
        # TypeError whenever a size was omitted.
        given = (('ntrain', ntrain), ('nvalid', nvalid), ('ntest', ntest))
        sizes = dict((key, value) for key, value in given if value is not None)
        return train_valid_test(**sizes)
    else:
        raise Exception('Unknown MNIST variant', variant)