diff datasets/ftfile.py @ 615:337253b82409

Repair the class/function that allows reading pnist07 and other datasets by allowing them to read gzipped files.
author Frederic Bastien <nouiz@nouiz.org>
date Fri, 07 Jan 2011 11:44:23 -0500
parents 212b142dcfc8
children
line wrap: on
line diff
--- a/datasets/ftfile.py	Thu Jan 06 14:23:41 2011 -0500
+++ b/datasets/ftfile.py	Fri Jan 07 11:44:23 2011 -0500
@@ -1,8 +1,12 @@
+from itertools import izip
+import os
+
+import numpy
 from pylearn.io.filetensor import _read_header, _prod
-import numpy, theano
+
 from dataset import DataSet
 from dsetiter import DataIterator
-from itertools import izip, imap
+
 
 class FTFile(object):
     def __init__(self, fname, scale=1, dtype=None):
@@ -10,8 +14,16 @@
         Tests:
             >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
         """
-        self.file = open(fname, 'rb')
-        self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
+        if os.path.exists(fname):
+            self.file = open(fname, 'rb')
+            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
+            self.gz=False
+        else:
+            import gzip
+            self.file = gzip.open(fname+'.gz','rb')
+            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False, True)
+            self.gz=True
+
         self.size = self.dim[0]
         self.scale = scale
         self.dtype = dtype
@@ -81,7 +93,11 @@
             num = self.size
         self.dim[0] = num
         self.size -= num
-        res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+        if self.gz:
+            d = self.file.read(_prod(self.dim)*self.elsize)
+            res = numpy.fromstring(d, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
+        else:
+            res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
         if self.dtype is not None:
             res = res.astype(self.dtype)
         if self.scale != 1: