comparison datasets/ftfile.py @ 613:5e481b224117

Fix the reading of the PNIST dataset following Dumi's compression of the data.
author Frederic Bastien <nouiz@nouiz.org>
date Thu, 06 Jan 2011 13:57:05 -0500
parents a92ec9939e4f
children 212b142dcfc8
comparison
equal deleted inserted replaced
612:21d53fd07f6e 613:5e481b224117
1 from itertools import izip
2 import os
3
4 import numpy
1 from pylearn.io.filetensor import _read_header, _prod 5 from pylearn.io.filetensor import _read_header, _prod
2 import numpy, theano 6
3 from dataset import DataSet 7 from dataset import DataSet
4 from dsetiter import DataIterator 8 from dsetiter import DataIterator
5 from itertools import izip, imap 9
6 10
7 class FTFile(object): 11 class FTFile(object):
8 def __init__(self, fname, scale=1, dtype=None): 12 def __init__(self, fname, scale=1, dtype=None):
9 r""" 13 r"""
10 Tests: 14 Tests:
11 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') 15 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
12 """ 16 """
13 self.file = open(fname, 'rb') 17 if os.path.exists(fname):
14 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) 18 self.file = open(fname, 'rb')
19 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
20 self.gz=False
21 else:
22 import gzip
23 self.file = gzip.open(fname+'.gz','rb')
24 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file.read(100), False, True)
25 self.file.seek(0)
26 self.gz=True
27
15 self.size = self.dim[0] 28 self.size = self.dim[0]
16 self.scale = scale 29 self.scale = scale
17 self.dtype = dtype 30 self.dtype = dtype
18 31
19 def skip(self, num): 32 def skip(self, num):
79 """ 92 """
80 if num > self.size: 93 if num > self.size:
81 num = self.size 94 num = self.size
82 self.dim[0] = num 95 self.dim[0] = num
83 self.size -= num 96 self.size -= num
84 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) 97 if self.gz:
98 res = numpy.fromstring(self.file.read(), dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
99 else:
100 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
85 if self.dtype is not None: 101 if self.dtype is not None:
86 res = res.astype(self.dtype) 102 res = res.astype(self.dtype)
87 if self.scale != 1: 103 if self.scale != 1:
88 res /= self.scale 104 res /= self.scale
89 return res 105 return res