Mercurial > pylearn
comparison datasets/MNIST.py @ 475:11e0357f06f4
typo in MNIST.train_valid_test
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 23 Oct 2008 18:06:21 -0400 |
parents | 45b3eb429c15 |
children | 19ab9ce916e3 |
comparison
equal
deleted
inserted
replaced
474:40c8a46b3da7 | 475:11e0357f06f4 |
---|---|
23 | 23 |
24 dat = AMat(path=path, head=n) | 24 dat = AMat(path=path, head=n) |
25 | 25 |
26 return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0]) | 26 return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0]) |
27 | 27 |
28 def all(path=None): | 28 def train_valid_test(ntrain=50000, nvalid=10000, ntest=10000, path=None): |
29 return head(n=None, path=path) | 29 all_x, all_targ = head(ntrain+nvalid+ntest, path=path) |
30 | |
31 | |
32 def train_valid_test(path=None, ntrain=50000, nvalid=10000, ntest=10000): | |
33 all_x, all_targ = all(path=path) | |
34 | 30 |
35 train = all_x[0:ntrain], all_targ[0:ntrain] | 31 train = all_x[0:ntrain], all_targ[0:ntrain] |
36 valid = all_x[ntrain:ntrain+nvalid], all_targ[ntrain:ntrain+nvalid] | 32 valid = all_x[ntrain:ntrain+nvalid], all_targ[ntrain:ntrain+nvalid] |
37 test = all_x[ntrain+nvalid:ntrain+nvalid+ntest], all_targ[ntrain+nvalid:ntrain+nvalid+ntest] | 33 test = all_x[ntrain+nvalid:ntrain+nvalid+ntest], all_targ[ntrain+nvalid:ntrain+nvalid+ntest] |
38 | 34 |
39 return train, valid, test | 35 return train, valid, test |
40 | 36 |
37 def all(path=None): | |
38 return head(n=None, path=path) | |
39 | |
40 |