comparison datasets/ftfile.py @ 187:c03692aa6158

Merge
author fsavard
date Mon, 01 Mar 2010 11:45:59 -0500
parents f0f47b045cbf
children 1faae5079522
comparison
equal deleted inserted replaced
186:d364a130b221 187:c03692aa6158
1 from pylearn.io.filetensor import _read_header, _prod 1 from pylearn.io.filetensor import _read_header, _prod
2 import numpy 2 import numpy, theano
3 from dataset import DataSet 3 from dataset import DataSet
4 from dsetiter import DataIterator 4 from dsetiter import DataIterator
5 from itertools import izip, imap
5 6
6 class FTFile(object): 7 class FTFile(object):
7 def __init__(self, fname): 8 def __init__(self, fname, scale=1, dtype=None):
8 r""" 9 r"""
9 Tests: 10 Tests:
10 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') 11 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
11 """ 12 """
12 self.file = open(fname, 'rb') 13 self.file = open(fname, 'rb')
13 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) 14 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
14 self.size = self.dim[0] 15 self.size = self.dim[0]
16 self.scale = scale
17 self.dtype = dtype
15 18
16 def skip(self, num): 19 def skip(self, num):
17 r""" 20 r"""
18 Skips `num` items in the file. 21 Skips `num` items in the file.
22
23 If `num` is negative, skips size+num items (i.e. all but the last -num).
19 24
20 Tests: 25 Tests:
21 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') 26 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
22 >>> f.size 27 >>> f.size
23 58646 28 58646
28 >>> f.skip(1000) 33 >>> f.skip(1000)
29 >>> f.file.tell() 34 >>> f.file.tell()
30 4020 35 4020
31 >>> f.size 36 >>> f.size
32 57646 37 57646
33 """ 38 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
39 >>> f.size
40 58646
41 >>> f.file.tell()
42 20
43 >>> f.skip(-1000)
44 >>> f.file.tell()
45 230604
46 >>> f.size
47 1000
48 """
49 if num < 0:
50 num += self.size
51 if num < 0:
52 raise ValueError('Skipping past the start of the file')
34 if num >= self.size: 53 if num >= self.size:
35 self.size = 0 54 self.size = 0
36 else: 55 else:
37 self.size -= num 56 self.size -= num
38 f_start = self.file.tell() 57 f_start = self.file.tell()
60 """ 79 """
61 if num > self.size: 80 if num > self.size:
62 num = self.size 81 num = self.size
63 self.dim[0] = num 82 self.dim[0] = num
64 self.size -= num 83 self.size -= num
65 return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) 84 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim)
85 if self.dtype is not None:
86 res = res.astype(self.dtype)
87 if self.scale != 1:
88 res /= self.scale
89 return res
66 90
67 class FTSource(object): 91 class FTSource(object):
68 def __init__(self, file, skip=0, size=None): 92 def __init__(self, file, skip=0, size=None, dtype=None, scale=1):
69 r""" 93 r"""
70 Create a data source from a possible subset of a .ft file. 94 Create a data source from a possible subset of a .ft file.
71 95
72 Parameters: 96 Parameters:
73 `file` (string) -- the filename 97 `file` (string) -- the filename
74 `skip` (int, optional) -- amount of examples to skip from the start of the file 98 `skip` (int, optional) -- amount of examples to skip from
75 `size` (int, optional) -- truncates number of examples read (after skipping) 99 the start of the file. If
100 negative, skips filesize + skip (all but the last -skip).
101 `size` (int, optional) -- truncates number of examples
102 read (after skipping). If
103 negative truncates to
104 filesize + size, dropping the last -size examples
105 (also after skipping).
106 `dtype` (dtype, optional) -- convert the data to this
107 dtype after reading.
108 `scale` (number, optional) -- scale (that is divide) the
109 data by this number (after
110 dtype conversion, if any).
111
77 Tests: 112 Tests:
78 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') 113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
79 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) 114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
80 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) 115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
81 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) 116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
82 """ 117 """
83 self.file = file 118 self.file = file
84 self.skip = skip 119 self.skip = skip
85 self.size = size 120 self.size = size
121 self.dtype = dtype
122 self.scale = scale
86 123
87 def open(self): 124 def open(self):
88 r""" 125 r"""
89 Returns an FTFile that corresponds to this dataset. 126 Returns an FTFile that corresponds to this dataset.
90 127
98 >>> s.open().size 135 >>> s.open().size
99 1000 136 1000
100 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) 137 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
101 >>> s.open().size 138 >>> s.open().size
102 1 139 1
103 """ 140 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
104 f = FTFile(self.file) 141 >>> s.open().size
142 58636
143 """
144 f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
105 if self.skip != 0: 145 if self.skip != 0:
106 f.skip(self.skip) 146 f.skip(self.skip)
107 if self.size is not None and self.size < f.size: 147 if self.size is not None and self.size < f.size:
108 f.size = self.size 148 if self.size < 0:
149 f.size += self.size
150 else:
151 f.size = self.size
109 return f 152 return f
110 153
111 class FTData(object): 154 class FTData(object):
112 r""" 155 r"""
113 This is a list of FTSources. 156 This is a list of FTSources.
114 """ 157 """
115 def __init__(self, datafiles, labelfiles, skip=0, size=None): 158 def __init__(self, datafiles, labelfiles, skip=0, size=None,
116 self.inputs = [FTSource(f, skip, size) for f in datafiles] 159 inscale=1, indtype=None, outscale=1, outdtype=None):
117 self.outputs = [FTSource(f, skip, size) for f in labelfiles] 160 self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype)
161 for f in datafiles]
162 self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype)
163 for f in labelfiles]
118 164
119 def open_inputs(self): 165 def open_inputs(self):
120 return [f.open() for f in self.inputs] 166 return [f.open() for f in self.inputs]
121 167
122 def open_outputs(self): 168 def open_outputs(self):
123 return [f.open() for f in self.outputs] 169 return [f.open() for f in self.outputs]
124 170
125 171
126 class FTDataSet(DataSet): 172 class FTDataSet(DataSet):
127 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None): 173 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1):
128 r""" 174 r"""
129 Defines a DataSet from a bunch of files. 175 Defines a DataSet from a bunch of files.
130 176
131 Parameters: 177 Parameters:
132 `train_data` -- list of train data files 178 `train_data` -- list of train data files
134 `test_data`, `test_labels` -- same thing as train, but for 180 `test_data`, `test_labels` -- same thing as train, but for
135 test. The number of files 181 test. The number of files
136 can differ from train. 182 can differ from train.
137 `valid_data`, `valid_labels` -- same thing again for validation. 183 `valid_data`, `valid_labels` -- same thing again for validation.
138 (optional) 184 (optional)
185 `indtype`, `outdtype`, -- see FTSource.__init__()
186 `inscale`, `outscale` (optional)
187
139 188
140 If `valid_data` and `valid_labels` are not supplied then a sample 189 If `valid_data` and `valid_labels` are not supplied then a sample
141 approximately equal in size to the test set is taken from the train 190 approximately equal in size to the test set is taken from the train
142 set. 191 set.
143 """ 192 """
144 if valid_data is None: 193 if valid_data is None:
145 total_valid_size = sum(FTFile(td).size for td in test_data) 194 total_valid_size = sum(FTFile(td).size for td in test_data)
146 valid_size = total_valid_size/len(train_data) 195 valid_size = total_valid_size/len(train_data)
147 self._train = FTData(train_data, train_lbl, skip=valid_size) 196 self._train = FTData(train_data, train_lbl, size=-valid_size)
148 self._valid = FTData(train_data, train_lbl, size=valid_size) 197 self._valid = FTData(train_data, train_lbl, skip=-valid_size)
149 else: 198 else:
150 self._train = FTData(train_data, train_lbl) 199 self._train = FTData(train_data, train_lbl)
151 self._valid = FTData(valid_data, valid_lbl) 200 self._valid = FTData(valid_data, valid_lbl)
152 self._test = FTData(test_data, test_lbl) 201 self._test = FTData(test_data, test_lbl)
153 202
154 def _return_it(self, batchsize, bufsize, ftdata): 203 def _return_it(self, batchsize, bufsize, ftdata):
155 return zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), 204 return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
156 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) 205 DataIterator(ftdata.open_outputs(), batchsize, bufsize))