annotate datasets/MNIST.py @ 471:45b3eb429c15

added train_valid_test
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 23 Oct 2008 13:26:11 -0400
parents bd937e845bbb
children 11e0357f06f4
470
bd937e845bbb new stuff: algorithms/logistic_regression, datasets/MNIST
James Bergstra <bergstrj@iro.umontreal.ca>
parents:
"""
Various routines to load/access MNIST data.
"""
from __future__ import absolute_import

import numpy

from ..amat import AMat

default_path = '/u/bergstrj/pub/data/mnist.amat'
"""The location of a file containing MNIST data in .amat format."""

def head(n=10, path=None):
    """Load the first n MNIST examples.

    Returns x, y. x is a matrix with n rows of 784 columns; each row of x holds the
    28x28 grey-scale pixels of one image in raster order. y is a vector of n integers;
    each element y[i] is the label of the i'th row of x.

    If n is None, every example in the file is loaded. If path is None,
    default_path is used.
    """
    path = path if path is not None else default_path

    dat = AMat(path=path, head=n)

    # Reshape the target column into a 1-D vector of int64 labels.
    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0])

def all(path=None):
    """Load every example in the MNIST file; see head() for the return format.

    Note that this name shadows the builtin all() inside this module.
    """
    return head(n=None, path=path)

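The docstring of `head` describes each row of x as 28x28 grey-scale pixels in raster order and y as a flat vector of integer labels. The following is a minimal sketch of what that means, using NumPy on synthetic arrays; `fake_x` and `fake_target` are made-up stand-ins (nothing here reads the .amat file), and the (N, 1) shape of the target column is an assumption inferred from the `reshape(dat.target.shape[0])` call.

```python
import numpy

# Hypothetical stand-in for two MNIST rows: 784 pixel values each.
fake_x = numpy.arange(2 * 784, dtype='float64').reshape(2, 784)
# Assumed target layout: a single column of shape (N, 1).
fake_target = numpy.array([[7.0], [2.0]])

# Raster order: pixel (r, c) of the 28x28 image is column r * 28 + c of the row.
image0 = fake_x[0].reshape(28, 28)
r, c = 3, 5
assert image0[r, c] == fake_x[0, r * 28 + c]

# The same target conversion head() performs: a 1-D vector of int64 labels.
labels = numpy.asarray(fake_target, dtype='int64').reshape(fake_target.shape[0])
print(labels)        # [7 2]
print(labels.dtype)  # int64
```

Because NumPy reshapes in row-major (C) order by default, `row.reshape(28, 28)` recovers the image directly from a raster-ordered row.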
471
45b3eb429c15 added train_valid_test
James Bergstra <bergstrj@iro.umontreal.ca>
parents: 470
def train_valid_test(path=None, ntrain=50000, nvalid=10000, ntest=10000):
    """Split the MNIST examples into consecutive train, validation and test blocks.

    Returns three (x, y) pairs: the first ntrain examples, the next nvalid,
    and the next ntest, in file order (no shuffling). The defaults need at
    least 70000 examples in the file, the size of the full MNIST set. Slicing
    does not check these bounds, so a shorter file yields smaller, possibly
    empty, splits.
    """
    all_x, all_targ = all(path=path)

    train = all_x[0:ntrain], all_targ[0:ntrain]
    valid = all_x[ntrain:ntrain+nvalid], all_targ[ntrain:ntrain+nvalid]
    test = all_x[ntrain+nvalid:ntrain+nvalid+ntest], all_targ[ntrain+nvalid:ntrain+nvalid+ntest]

    return train, valid, test
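`train_valid_test` takes consecutive blocks of the data in file order. The sketch below reproduces the same slicing on synthetic NumPy arrays so the resulting sizes can be checked without the MNIST file. `split_consecutive` is a hypothetical helper written for illustration, not part of this module, and the synthetic x uses 2 features instead of 784 only to keep memory small.

```python
import numpy


def split_consecutive(x, y, ntrain, nvalid, ntest):
    """Hypothetical helper mirroring the slicing in train_valid_test."""
    a = ntrain
    b = ntrain + nvalid
    c = ntrain + nvalid + ntest
    train = x[0:a], y[0:a]
    valid = x[a:b], y[a:b]
    test = x[b:c], y[b:c]
    return train, valid, test


# Synthetic data: 70000 rows, enough for the default split sizes.
x = numpy.zeros((70000, 2), dtype='float32')
y = numpy.arange(70000, dtype='int64')  # each label equals its row index

train, valid, test = split_consecutive(x, y, 50000, 10000, 10000)
print(train[0].shape, valid[0].shape, test[0].shape)
# (50000, 2) (10000, 2) (10000, 2)
print(valid[1][0], test[1][0])  # 50000 60000

# Slicing does not check bounds: with only 65000 rows the test block is short.
_, _, short_test = split_consecutive(x[:65000], y[:65000], 50000, 10000, 10000)
print(short_test[0].shape)  # (5000, 2)
```

Since labels here equal row indices, the first label of each block shows exactly where that block starts, which makes the consecutive, unshuffled layout easy to confirm.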