Mercurial > pylearn
annotate pylearn/datasets/MNIST.py @ 653:d3d8f5a17909
print warning on undefined PYLEARN_DATA_ROOT
author | bergstra@mlp4.ais.sandbox |
---|---|
date | Wed, 11 Feb 2009 01:42:58 -0500 |
parents | ec27e19bb6eb |
children | 6d927441a38f |
rev | line source |
---|---|
470
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
1 """ |
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
2 Various routines to load/access MNIST data. |
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
3 """ |
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
4 from __future__ import absolute_import |
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
5 |
505
74b3e65f5f24
added smallNorb dataset, switched to PYLEARN_DATA_ROOT
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
504
diff
changeset
|
6 import os |
470
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
7 import numpy |
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
8 |
537
b054271b2504
new file structure layout, factories, etc.
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
511
diff
changeset
|
9 from ..io.amat import AMat |
653
d3d8f5a17909
print warning on undefined PYLEARN_DATA_ROOT
bergstra@mlp4.ais.sandbox
parents:
627
diff
changeset
|
10 from .config import data_root # config |
627
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
11 from .dataset import Dataset |
470
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
12 |
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
def head(n=10, path=None):
    """Load the first MNIST examples.

    Returns two matrices: x, y. x has N rows of 784 columns. Each row of x represents the
    28x28 grey-scale pixels in raster order. y is a vector of N integers. Each element y[i]
    is the label of the i'th row of x.

    :param n: number of examples to read; forwarded to ``AMat`` as ``head``.  ``None`` is
        forwarded unchanged (this is what ``all()`` passes); in that case the row count is
        not compared with ``n``, only inputs and targets are required to agree.
    :param path: path of the ``.amat`` file.  Defaults to
        ``os.path.join(data_root(), 'mnist', 'mnist_with_header.amat')``.
    :returns: ``(x, y)`` where ``x`` is ``dat.input`` and ``y`` is the targets converted to
        an int64 vector of length N.
    :raises Exception: ``("failed to read MNIST data", (dat, detail))`` when the loaded data
        cannot be inspected or does not hold the expected number of rows.
    """
    if path is None:
        path = os.path.join(data_root(), 'mnist', 'mnist_with_header.amat')

    dat = AMat(path=path, head=n)

    try:
        n_input = dat.input.shape[0]
        n_target = dat.target.shape[0]
    except Exception as e:
        raise Exception("failed to read MNIST data", (dat, e))

    # Explicit check rather than `assert`, so validation still happens under `python -O`.
    # With n=None (read everything) there is no requested count to compare against, so
    # only require that inputs and targets line up.
    expected = n_input if n is None else n
    if n_input != expected or n_target != expected:
        raise Exception("failed to read MNIST data",
                        (dat, ValueError('expected %s rows, got %s inputs and %s targets'
                                         % (expected, n_input, n_target))))

    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(n_target)
bd937e845bbb
new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
diff
changeset
|
32 |
537
b054271b2504
new file structure layout, factories, etc.
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
511
diff
changeset
|
def all(path=None):
    """Return ``head`` called with no row limit (``n=None``) on ``path``.

    NOTE(review): this relies on ``AMat(head=None)`` meaning "read every row" -- confirm.
    The public name shadows the builtin ``all``; it is kept for existing callers.
    """
    no_limit = None
    return head(no_limit, path=path)
b054271b2504
new file structure layout, factories, etc.
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
511
diff
changeset
|
35 |
475
11e0357f06f4
typo in MNIST.train_valid_test
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
471
diff
changeset
|
def train_valid_test(ntrain=50000, nvalid=10000, ntest=10000, path=None):
    """Split the first ntrain+nvalid+ntest MNIST examples into consecutive subsets.

    Examples [0, ntrain) become ``.train``, the next ``nvalid`` become ``.valid`` and the
    following ``ntest`` become ``.test``.  Each subset is a ``Dataset.Obj`` with ``x``
    (inputs) and ``y`` (labels) taken from ``head``.  The returned ``Dataset`` also carries
    ``n_classes = 10`` and ``img_shape = (28, 28)``.

    :param path: forwarded to ``head``.
    """
    valid_start = ntrain
    test_start = ntrain + nvalid
    test_stop = test_start + ntest

    inputs, targets = head(test_stop, path=path)

    def _subset(lo, hi):
        # One contiguous slice [lo, hi) of the loaded examples.
        return Dataset.Obj(x=inputs[lo:hi], y=targets[lo:hi])

    rval = Dataset()
    rval.train = _subset(0, valid_start)
    rval.valid = _subset(valid_start, test_start)
    rval.test = _subset(test_start, test_stop)
    rval.n_classes = 10
    rval.img_shape = (28, 28)
    return rval
475
11e0357f06f4
typo in MNIST.train_valid_test
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
471
diff
changeset
|
51 |
11e0357f06f4
typo in MNIST.train_valid_test
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
471
diff
changeset
|
52 |
627
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
def full():
    """Return ``train_valid_test`` with all of its default split sizes and path."""
    dataset = train_valid_test()
    return dataset
537
b054271b2504
new file structure layout, factories, etc.
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
511
diff
changeset
|
55 |
627
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
def first_1k():
    """Return a small split: 1000 train, 200 valid, 200 test examples (default path)."""
    # Positional order is (ntrain, nvalid, ntest).
    return train_valid_test(1000, 200, 200)
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
58 |
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
def first_10k():
    """Return a medium split: 10000 train, 2000 valid, 2000 test examples (default path)."""
    # Positional order is (ntrain, nvalid, ntest).
    return train_valid_test(10000, 2000, 2000)
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
61 |
ec27e19bb6eb
moving away from mnist_factory
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
563
diff
changeset
|
#old method from factory idea days... delete when ready -JB20090119
def mnist_factory(variant="", ntrain=None, nvalid=None, ntest=None):
    """Deprecated dispatcher; prefer full(), first_1k(), first_10k() or train_valid_test().

    :param variant: "" -> ``full()``, "1k" -> ``first_1k()``, "10k" -> ``first_10k()``,
        "custom" -> ``train_valid_test`` with the given sizes.
    :param ntrain, nvalid, ntest: split sizes, used only by the "custom" variant.  Any size
        left as ``None`` falls back to ``train_valid_test``'s own default.
    :raises Exception: ``('Unknown MNIST variant', variant)`` for any other variant.
    """
    if variant == "":
        return full()
    elif variant == "1k":
        return first_1k()
    elif variant == "10k":
        return first_10k()
    elif variant == "custom":
        # Forward only the sizes the caller supplied: passing None through would make
        # train_valid_test fail on `None + None` when computing the total row count.
        given = (('ntrain', ntrain), ('nvalid', nvalid), ('ntest', ntest))
        sizes = dict((name, value) for name, value in given if value is not None)
        return train_valid_test(**sizes)
    else:
        raise Exception('Unknown MNIST variant', variant)