changeset 794:951272679910

get the mnist data from the pmat file and not the amat file
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Wed, 15 Jul 2009 13:18:55 -0400
parents 4e70f509ec01
children f30bb746f279
files pylearn/datasets/MNIST.py
diffstat 1 files changed, 5 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/datasets/MNIST.py	Mon Jul 13 16:56:31 2009 -0400
+++ b/pylearn/datasets/MNIST.py	Wed Jul 15 13:18:55 2009 -0400
@@ -6,7 +6,7 @@
 import os
 import numpy
 
-from ..io.amat import AMat
+from ..io.pmat import PMat
 from .config import data_root # config
 from .dataset import Dataset
 
@@ -18,17 +18,13 @@
     is the label of the i'th row of x.
     
     """
-    path = os.path.join(data_root(), 'mnist','mnist_with_header.amat') if path is None else path
+    path = os.path.join(data_root(), 'mnist','mnist_all.pmat') if path is None else path
 
-    dat = AMat(path=path, head=n)
+    dat = PMat(fname=path)
 
-    try:
-        assert dat.input.shape[0] == n
-        assert dat.target.shape[0] == n
-    except Exception , e:
-        raise Exception("failed to read MNIST data", (dat, e))
+    rows=dat.getRows(0,n)
 
-    return dat.input, numpy.asarray(dat.target, dtype='int64').reshape(dat.target.shape[0])
+    return rows[:,0:-1], numpy.asarray(rows[:,-1], dtype='int64')
 
 def all(path=None):
     return head(n=None, path=path)