Mercurial > ift6266
comparison datasets/ftfile.py @ 262:716c99f4eb3a
merge
author | Xavier Glorot <glorotxa@iro.umontreal.ca> |
---|---|
date | Wed, 17 Mar 2010 16:41:51 -0400 |
parents | 966272e7f14b |
children | a92ec9939e4f |
comparison
equal
deleted
inserted
replaced
261:6d16a2bf142b | 262:716c99f4eb3a |
---|---|
87 if self.scale != 1: | 87 if self.scale != 1: |
88 res /= self.scale | 88 res /= self.scale |
89 return res | 89 return res |
90 | 90 |
91 class FTSource(object): | 91 class FTSource(object): |
92 def __init__(self, file, skip=0, size=None, dtype=None, scale=1): | 92 def __init__(self, file, skip=0, size=None, maxsize=None, |
93 dtype=None, scale=1): | |
93 r""" | 94 r""" |
94 Create a data source from a possible subset of a .ft file. | 95 Create a data source from a possible subset of a .ft file. |
95 | 96 |
96 Parameters: | 97 Parameters: |
97 `file` (string) -- the filename | 98 `file` -- (string) the filename |
98 `skip` (int, optional) -- amount of examples to skip from | 99 `skip` -- (int, optional) amount of examples to skip from |
99 the start of the file. If | 100 the start of the file. If negative, skips |
100 negative, skips filesize - skip. | 101 filesize - skip. |
101 `size` (int, optional) -- truncates number of examples | 102 `size` -- (int, optional) truncates number of examples |
102 read (after skipping). If | 103 read (after skipping). If negative truncates to |
103 negative truncates to | 104 filesize - size (also after skipping). |
104 filesize - size | 105 `maxsize` -- (int, optional) the maximum size of the file |
105 (also after skipping). | 106 `dtype` -- (dtype, optional) convert the data to this |
106 `dtype` (dtype, optional) -- convert the data to this | 107 dtype after reading. |
107 dtype after reading. | 108 `scale` -- (number, optional) scale (that is divide) the |
108 `scale` (number, optional) -- scale (that is divide) the | 109 data by this number (after dtype conversion, if |
109 data by this number (after | 110 any). |
110 dtype conversion, if any). | 111 |
111 | 112 Tests: |
112 Tests: | 113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') |
113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') | 114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) |
114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) | 115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) |
115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) | 116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) |
116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) | |
117 """ | 117 """ |
118 self.file = file | 118 self.file = file |
119 self.skip = skip | 119 self.skip = skip |
120 self.size = size | 120 self.size = size |
121 self.dtype = dtype | 121 self.dtype = dtype |
122 self.scale = scale | 122 self.scale = scale |
123 self.maxsize = maxsize | |
123 | 124 |
124 def open(self): | 125 def open(self): |
125 r""" | 126 r""" |
126 Returns an FTFile that corresponds to this dataset. | 127 Returns an FTFile that corresponds to this dataset. |
127 | 128 |
128 Tests: | 129 Tests: |
129 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') | 130 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') |
130 >>> f = s.open() | 131 >>> f = s.open() |
131 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) | 132 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) |
132 >>> len(s.open().read(2)) | 133 >>> len(s.open().read(2)) |
133 1 | 134 1 |
134 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) | 135 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) |
135 >>> s.open().size | 136 >>> s.open().size |
136 1000 | 137 1000 |
137 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) | 138 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) |
138 >>> s.open().size | 139 >>> s.open().size |
139 1 | 140 1 |
140 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) | 141 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) |
141 >>> s.open().size | 142 >>> s.open().size |
142 58636 | 143 58636 |
143 """ | 144 """ |
144 f = FTFile(self.file, scale=self.scale, dtype=self.dtype) | 145 f = FTFile(self.file, scale=self.scale, dtype=self.dtype) |
145 if self.skip != 0: | 146 if self.skip != 0: |
146 f.skip(self.skip) | 147 f.skip(self.skip) |
147 if self.size is not None and self.size < f.size: | 148 if self.size is not None and self.size < f.size: |
148 if self.size < 0: | 149 if self.size < 0: |
149 f.size += self.size | 150 f.size += self.size |
151 if f.size < 0: | |
152 f.size = 0 | |
150 else: | 153 else: |
151 f.size = self.size | 154 f.size = self.size |
155 if self.maxsize is not None and f.size > self.maxsize: | |
156 f.size = self.maxsize | |
152 return f | 157 return f |
153 | 158 |
154 class FTData(object): | 159 class FTData(object): |
155 r""" | 160 r""" |
156 This is a list of FTSources. | 161 This is a list of FTSources. |
157 """ | 162 """ |
158 def __init__(self, datafiles, labelfiles, skip=0, size=None, | 163 def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None, |
159 inscale=1, indtype=None, outscale=1, outdtype=None): | 164 inscale=1, indtype=None, outscale=1, outdtype=None): |
160 self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) | 165 if maxsize is not None: |
166 maxsize /= len(datafiles) | |
167 self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype) | |
161 for f in datafiles] | 168 for f in datafiles] |
162 self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) | 169 self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype) |
163 for f in labelfiles] | 170 for f in labelfiles] |
164 | 171 |
165 def open_inputs(self): | 172 def open_inputs(self): |
166 return [f.open() for f in self.inputs] | 173 return [f.open() for f in self.inputs] |
167 | 174 |
168 def open_outputs(self): | 175 def open_outputs(self): |
169 return [f.open() for f in self.outputs] | 176 return [f.open() for f in self.outputs] |
170 | 177 |
171 | 178 |
172 class FTDataSet(DataSet): | 179 class FTDataSet(DataSet): |
173 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): | 180 def __init__(self, train_data, train_lbl, test_data, test_lbl, |
181 valid_data=None, valid_lbl=None, indtype=None, outdtype=None, | |
182 inscale=1, outscale=1, maxsize=None): | |
174 r""" | 183 r""" |
175 Defines a DataSet from a bunch of files. | 184 Defines a DataSet from a bunch of files. |
176 | 185 |
177 Parameters: | 186 Parameters: |
178 `train_data` -- list of train data files | 187 `train_data` -- list of train data files |
182 can differ from train. | 191 can differ from train. |
183 `valid_data`, `valid_labels` -- same thing again for validation. | 192 `valid_data`, `valid_labels` -- same thing again for validation. |
184 (optional) | 193 (optional) |
185 `indtype`, `outdtype`, -- see FTSource.__init__() | 194 `indtype`, `outdtype`, -- see FTSource.__init__() |
186 `inscale`, `outscale` (optional) | 195 `inscale`, `outscale` (optional) |
196 `maxsize` -- maximum size of the set returned | |
187 | 197 |
188 | 198 |
189 If `valid_data` and `valid_labels` are not supplied then a sample | 199 If `valid_data` and `valid_labels` are not supplied then a sample |
190 approximately equal in size to the test set is taken from the train | 200 approximately equal in size to the test set is taken from the train |
191 set. | 201 set. |
192 """ | 202 """ |
193 if valid_data is None: | 203 if valid_data is None: |
194 total_valid_size = sum(FTFile(td).size for td in test_data) | 204 total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize) |
195 valid_size = total_valid_size/len(train_data) | 205 valid_size = total_valid_size/len(train_data) |
196 self._train = FTData(train_data, train_lbl, size=-valid_size, | 206 self._train = FTData(train_data, train_lbl, size=-valid_size, |
197 inscale=inscale, outscale=outscale, indtype=indtype, | 207 inscale=inscale, outscale=outscale, |
198 outdtype=outdtype) | 208 indtype=indtype, outdtype=outdtype, |
209 maxsize=maxsize) | |
199 self._valid = FTData(train_data, train_lbl, skip=-valid_size, | 210 self._valid = FTData(train_data, train_lbl, skip=-valid_size, |
200 inscale=inscale, outscale=outscale, indtype=indtype, | 211 inscale=inscale, outscale=outscale, |
201 outdtype=outdtype) | 212 indtype=indtype, outdtype=outdtype, |
213 maxsize=maxsize) | |
202 else: | 214 else: |
203 self._train = FTData(train_data, train_lbl,inscale=inscale, | 215 self._train = FTData(train_data, train_lbl, maxsize=maxsize, |
204 outscale=outscale, indtype=indtype, outdtype=outdtype) | 216 inscale=inscale, outscale=outscale, |
205 self._valid = FTData(valid_data, valid_lbl,inscale=inscale, | 217 indtype=indtype, outdtype=outdtype) |
206 outscale=outscale, indtype=indtype, outdtype=outdtype) | 218 self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize, |
207 self._test = FTData(test_data, test_lbl,inscale=inscale, | 219 inscale=inscale, outscale=outscale, |
208 outscale=outscale, indtype=indtype, outdtype=outdtype) | 220 indtype=indtype, outdtype=outdtype) |
221 self._test = FTData(test_data, test_lbl, maxsize=maxsize, | |
222 inscale=inscale, outscale=outscale, | |
223 indtype=indtype, outdtype=outdtype) | |
209 | 224 |
210 def _return_it(self, batchsize, bufsize, ftdata): | 225 def _return_it(self, batchsize, bufsize, ftdata): |
211 return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), | 226 return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), |
212 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) | 227 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) |