Mercurial > pylearn
changeset 794:951272679910
get the mnist data from the pmat file and not the amat file
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Wed, 15 Jul 2009 13:18:55 -0400 |
parents | 4e70f509ec01 |
children | f30bb746f279 |
files | pylearn/datasets/MNIST.py |
diffstat | 1 files changed, 5 insertions(+), 9 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/datasets/MNIST.py Mon Jul 13 16:56:31 2009 -0400 +++ b/pylearn/datasets/MNIST.py Wed Jul 15 13:18:55 2009 -0400 @@ -6,7 +6,7 @@ import os import numpy -from ..io.amat import AMat +from ..io.pmat import PMat from .config import data_root # config from .dataset import Dataset @@ -18,17 +18,13 @@ is the label of the i'th row of x. """ - path = os.path.join(data_root(), 'mnist','mnist_with_header.amat') if path is None else path + path = os.path.join(data_root(), 'mnist','mnist_all.pmat') if path is None else path - dat = AMat(path=path, head=n) + dat = PMat(fname=path) - try: - assert dat.input.shape[0] == n - assert dat.target.shape[0] == n - except Exception , e: - raise Exception("failed to read MNIST data", (dat, e)) + rows=dat.getRows(0,n) - return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0]) + return rows[:,0:-1], numpy.asarray(rows[:,-1], dtype='int64') def all(path=None): return head(n=None, path=path)