comparison datasets/ftfile.py @ 615:337253b82409

Repair the class/function that allows reading pnist07 and other datasets by letting them read gzipped files.
author Frederic Bastien <nouiz@nouiz.org>
date Fri, 07 Jan 2011 11:44:23 -0500
parents 212b142dcfc8
children
comparison
equal deleted inserted replaced
614:212b142dcfc8 615:337253b82409
1 from itertools import izip
2 import os
3
4 import numpy
1 from pylearn.io.filetensor import _read_header, _prod 5 from pylearn.io.filetensor import _read_header, _prod
2 import numpy, theano 6
3 from dataset import DataSet 7 from dataset import DataSet
4 from dsetiter import DataIterator 8 from dsetiter import DataIterator
5 from itertools import izip, imap 9
6 10
7 class FTFile(object): 11 class FTFile(object):
8 def __init__(self, fname, scale=1, dtype=None): 12 def __init__(self, fname, scale=1, dtype=None):
9 r""" 13 r"""
10 Tests: 14 Tests:
11 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') 15 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
12 """ 16 """
13 self.file = open(fname, 'rb') 17 if os.path.exists(fname):
14 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) 18 self.file = open(fname, 'rb')
19 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
20 self.gz=False
21 else:
22 import gzip
23 self.file = gzip.open(fname+'.gz','rb')
24 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False, True)
25 self.gz=True
26
15 self.size = self.dim[0] 27 self.size = self.dim[0]
16 self.scale = scale 28 self.scale = scale
17 self.dtype = dtype 29 self.dtype = dtype
18 30
19 def skip(self, num): 31 def skip(self, num):
79 """ 91 """
80 if num > self.size: 92 if num > self.size:
81 num = self.size 93 num = self.size
82 self.dim[0] = num 94 self.dim[0] = num
83 self.size -= num 95 self.size -= num
84 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) 96 if self.gz:
97 d = self.file.read(_prod(self.dim)*self.elsize)
98 res = numpy.fromstring(d, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
99 else:
100 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
85 if self.dtype is not None: 101 if self.dtype is not None:
86 res = res.astype(self.dtype) 102 res = res.astype(self.dtype)
87 if self.scale != 1: 103 if self.scale != 1:
88 res /= self.scale 104 res /= self.scale
89 return res 105 return res