diff datasets/ftfile.py @ 613:5e481b224117

fix the reading of PNIST dataset following Dumi compression of the data.
author Frederic Bastien <nouiz@nouiz.org>
date Thu, 06 Jan 2011 13:57:05 -0500
parents a92ec9939e4f
children 212b142dcfc8
line wrap: on
line diff
--- a/datasets/ftfile.py	Mon Dec 20 11:54:35 2010 -0500
+++ b/datasets/ftfile.py	Thu Jan 06 13:57:05 2011 -0500
@@ -1,8 +1,12 @@
+from itertools import izip
+import os
+
+import numpy
 from pylearn.io.filetensor import _read_header, _prod
-import numpy, theano
+
 from dataset import DataSet
 from dsetiter import DataIterator
-from itertools import izip, imap
+
 
 class FTFile(object):
     def __init__(self, fname, scale=1, dtype=None):
@@ -10,8 +14,17 @@
         Tests:
             >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
         """
-        self.file = open(fname, 'rb')
-        self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
+        if os.path.exists(fname):
+            self.file = open(fname, 'rb')
+            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
+            self.gz=False
+        else:
+            import gzip
+            self.file = gzip.open(fname+'.gz','rb')
+            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file.read(100), False, True)
+            self.file.seek(0)
+            self.gz=True
+
         self.size = self.dim[0]
         self.scale = scale
         self.dtype = dtype
@@ -81,7 +94,10 @@
             num = self.size
         self.dim[0] = num
         self.size -= num
-        res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+        if self.gz:
+            res = numpy.fromstring(self.file.read(), dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+        else:
+            res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
         if self.dtype is not None:
             res = res.astype(self.dtype)
         if self.scale != 1: