comparison datasets/ftfile.py @ 262:716c99f4eb3a

merge
author Xavier Glorot <glorotxa@iro.umontreal.ca>
date Wed, 17 Mar 2010 16:41:51 -0400
parents 966272e7f14b
children a92ec9939e4f
comparison
equal deleted inserted replaced
261:6d16a2bf142b 262:716c99f4eb3a
87 if self.scale != 1: 87 if self.scale != 1:
88 res /= self.scale 88 res /= self.scale
89 return res 89 return res
90 90
class FTSource(object):
    def __init__(self, file, skip=0, size=None, maxsize=None,
                 dtype=None, scale=1):
        r"""
        Create a data source from a possible subset of a .ft file.

        Nothing is read here; the parameters are only stored and applied
        by `open()`.

        Parameters:
            `file` -- (string) the filename
            `skip` -- (int, optional) amount of examples to skip from
                      the start of the file. If negative, it counts
                      from the end of the file: skips filesize - |skip|
                      examples (handled by FTFile.skip).
            `size` -- (int, optional) truncates number of examples
                      read (after skipping). If negative, drops the
                      last |size| examples instead, i.e. keeps
                      filesize - |size| of the examples remaining after
                      skipping, never fewer than 0.
            `maxsize` -- (int, optional) upper bound on the number of
                      examples exposed, applied after `skip` and `size`.
            `dtype` -- (dtype, optional) convert the data to this
                      dtype after reading.
            `scale` -- (number, optional) scale (that is divide) the
                      data by this number (after dtype conversion, if
                      any).

        Tests:
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
        """
        self.file = file
        self.skip = skip
        self.size = size
        self.dtype = dtype
        self.scale = scale
        self.maxsize = maxsize

    def open(self):
        r"""
        Returns an FTFile that corresponds to this dataset.

        The returned FTFile is positioned after `skip` and its `size`
        is reduced according to `size` and `maxsize`.

        Tests:
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> f = s.open()
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
            >>> len(s.open().read(2))
            1
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
            >>> s.open().size
            1000
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
            >>> s.open().size
            1
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10)
            >>> s.open().size
            58636
        """
        f = FTFile(self.file, scale=self.scale, dtype=self.dtype)
        if self.skip != 0:
            f.skip(self.skip)
        if self.size is not None and self.size < f.size:
            if self.size < 0:
                # A negative size drops examples from the end. Clamp so an
                # oversized negative value gives an empty source instead
                # of a negative example count.
                f.size = max(f.size + self.size, 0)
            else:
                f.size = self.size
        if self.maxsize is not None:
            # maxsize may arrive as a float (e.g. after being divided by
            # the number of files), but f.size is an example count.
            maxsize = int(self.maxsize)
            if f.size > maxsize:
                f.size = maxsize
        return f
153 158
class FTData(object):
    r"""
    This is a list of FTSources.

    Holds one FTSource per data file (`inputs`) and one per label file
    (`outputs`). All of them share the same skip/size/maxsize window and
    the respective in/out scale and dtype.
    """
    def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None,
                 inscale=1, indtype=None, outscale=1, outdtype=None):
        if maxsize is not None and datafiles:
            # `maxsize` bounds the whole set, so split it evenly across the
            # data files. Floor division keeps it an integer example count
            # (plain `/` yields a float under true division). The guard
            # avoids dividing by zero when no data files are given.
            maxsize //= len(datafiles)
        self.inputs = [FTSource(f, skip, size, maxsize,
                                scale=inscale, dtype=indtype)
                       for f in datafiles]
        self.outputs = [FTSource(f, skip, size, maxsize,
                                 scale=outscale, dtype=outdtype)
                        for f in labelfiles]

    def open_inputs(self):
        """Return a list of opened FTFiles, one per data file."""
        return [f.open() for f in self.inputs]

    def open_outputs(self):
        """Return a list of opened FTFiles, one per label file."""
        return [f.open() for f in self.outputs]
170 177
171 178
class FTDataSet(DataSet):
    def __init__(self, train_data, train_lbl, test_data, test_lbl,
                 valid_data=None, valid_lbl=None, indtype=None, outdtype=None,
                 inscale=1, outscale=1, maxsize=None):
        r"""
        Defines a DataSet from a bunch of files.

        Parameters:
            `train_data` -- list of train data files
            `train_lbl` -- list of train label files, paired with
                           `train_data`
            `test_data`, `test_lbl` -- same thing for test. The number
                           of files can differ from train.
            `valid_data`, `valid_lbl` -- same thing again for validation.
                           (optional)
            `indtype`, `outdtype`, -- see FTSource.__init__()
            `inscale`, `outscale` (optional)
            `maxsize` -- (optional) maximum number of examples in each of
                           the train, valid and test sets, split evenly
                           across that set's files


        If `valid_data` and `valid_lbl` are not supplied then a sample
        approximately equal in size to the test set (capped by `maxsize`)
        is taken from the end of each train file.
        """
        common = dict(inscale=inscale, outscale=outscale,
                      indtype=indtype, outdtype=outdtype, maxsize=maxsize)
        if valid_data is None:
            total_valid_size = sum(FTFile(td).size for td in test_data)
            # Only cap when a limit was given: min(x, None) is None under
            # Python 2 (and a TypeError under Python 3), which broke the
            # default maxsize=None case.
            if maxsize is not None:
                total_valid_size = min(total_valid_size, maxsize)
            valid_size = total_valid_size // len(train_data)
            if valid_size > 0:
                # Train keeps all but the last valid_size examples of each
                # file; valid is exactly those trailing examples.
                self._train = FTData(train_data, train_lbl,
                                     size=-valid_size, **common)
                self._valid = FTData(train_data, train_lbl,
                                     skip=-valid_size, **common)
            else:
                # size=-0 / skip=-0 would mean "empty train" / "whole file
                # as valid". With nothing to carve out, train keeps every
                # example and valid is empty.
                self._train = FTData(train_data, train_lbl, **common)
                self._valid = FTData(train_data, train_lbl, size=0, **common)
        else:
            self._train = FTData(train_data, train_lbl, **common)
            self._valid = FTData(valid_data, valid_lbl, **common)
        self._test = FTData(test_data, test_lbl, **common)

    def _return_it(self, batchsize, bufsize, ftdata):
        r"""
        Iterate over `ftdata` in lockstep: yields (inputs, outputs) pairs
        taken from two DataIterators built with `batchsize` and `bufsize`,
        one over the opened input files and one over the opened label files.
        """
        return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
                    DataIterator(ftdata.open_outputs(), batchsize, bufsize))