Mercurial > ift6266
comparison datasets/ftfile.py @ 266:1e4e60ddadb1
Merge. Ah, et dans le dernier commit, j'avais oublié de mentionner que j'ai ajouté du code pour gérer l'isolation de différents clones pour rouler des expériences et modifier le code en même temps.
author | fsavard |
---|---|
date | Fri, 19 Mar 2010 10:56:16 -0400 |
parents | 966272e7f14b |
children | a92ec9939e4f |
comparison
equal
deleted
inserted
replaced
265:c8fe09a65039 | 266:1e4e60ddadb1 |
---|---|
87 if self.scale != 1: | 87 if self.scale != 1: |
88 res /= self.scale | 88 res /= self.scale |
89 return res | 89 return res |
90 | 90 |
91 class FTSource(object): | 91 class FTSource(object): |
92 def __init__(self, file, skip=0, size=None, dtype=None, scale=1): | 92 def __init__(self, file, skip=0, size=None, maxsize=None, |
93 dtype=None, scale=1): | |
93 r""" | 94 r""" |
94 Create a data source from a possible subset of a .ft file. | 95 Create a data source from a possible subset of a .ft file. |
95 | 96 |
96 Parameters: | 97 Parameters: |
97 `file` (string) -- the filename | 98 `file` -- (string) the filename |
98 `skip` (int, optional) -- amount of examples to skip from | 99 `skip` -- (int, optional) amount of examples to skip from |
99 the start of the file. If | 100 the start of the file. If negative, skips |
100 negative, skips filesize - skip. | 101 filesize - skip. |
101 `size` (int, optional) -- truncates number of examples | 102 `size` -- (int, optional) truncates number of examples |
102 read (after skipping). If | 103 read (after skipping). If negative truncates to |
103 negative truncates to | 104 filesize - size (also after skipping). |
104 filesize - size | 105 `maxsize` -- (int, optional) the maximum size of the file |
105 (also after skipping). | 106 `dtype` -- (dtype, optional) convert the data to this |
106 `dtype` (dtype, optional) -- convert the data to this | 107 dtype after reading. |
107 dtype after reading. | 108 `scale` -- (number, optional) scale (that is divide) the |
108 `scale` (number, optional) -- scale (that is divide) the | 109 data by this number (after dtype conversion, if |
109 data by this number (after | 110 any). |
110 dtype conversion, if any). | 111 |
111 | 112 Tests: |
112 Tests: | 113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') |
113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') | 114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) |
114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) | 115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) |
115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) | 116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) |
116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) | |
117 """ | 117 """ |
118 self.file = file | 118 self.file = file |
119 self.skip = skip | 119 self.skip = skip |
120 self.size = size | 120 self.size = size |
121 self.dtype = dtype | 121 self.dtype = dtype |
122 self.scale = scale | 122 self.scale = scale |
123 self.maxsize = maxsize | |
123 | 124 |
124 def open(self): | 125 def open(self): |
125 r""" | 126 r""" |
126 Returns an FTFile that corresponds to this dataset. | 127 Returns an FTFile that corresponds to this dataset. |
127 | 128 |
128 Tests: | 129 Tests: |
129 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') | 130 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') |
130 >>> f = s.open() | 131 >>> f = s.open() |
131 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) | 132 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1) |
132 >>> len(s.open().read(2)) | 133 >>> len(s.open().read(2)) |
133 1 | 134 1 |
134 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) | 135 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646) |
135 >>> s.open().size | 136 >>> s.open().size |
136 1000 | 137 1000 |
137 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) | 138 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) |
138 >>> s.open().size | 139 >>> s.open().size |
139 1 | 140 1 |
140 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) | 141 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) |
141 >>> s.open().size | 142 >>> s.open().size |
142 58636 | 143 58636 |
143 """ | 144 """ |
144 f = FTFile(self.file, scale=self.scale, dtype=self.dtype) | 145 f = FTFile(self.file, scale=self.scale, dtype=self.dtype) |
145 if self.skip != 0: | 146 if self.skip != 0: |
146 f.skip(self.skip) | 147 f.skip(self.skip) |
147 if self.size is not None and self.size < f.size: | 148 if self.size is not None and self.size < f.size: |
148 if self.size < 0: | 149 if self.size < 0: |
149 f.size += self.size | 150 f.size += self.size |
151 if f.size < 0: | |
152 f.size = 0 | |
150 else: | 153 else: |
151 f.size = self.size | 154 f.size = self.size |
155 if self.maxsize is not None and f.size > self.maxsize: | |
156 f.size = self.maxsize | |
152 return f | 157 return f |
153 | 158 |
154 class FTData(object): | 159 class FTData(object): |
155 r""" | 160 r""" |
156 This is a list of FTSources. | 161 This is a list of FTSources. |
157 """ | 162 """ |
158 def __init__(self, datafiles, labelfiles, skip=0, size=None, | 163 def __init__(self, datafiles, labelfiles, skip=0, size=None, maxsize=None, |
159 inscale=1, indtype=None, outscale=1, outdtype=None): | 164 inscale=1, indtype=None, outscale=1, outdtype=None): |
160 self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) | 165 if maxsize is not None: |
166 maxsize /= len(datafiles) | |
167 self.inputs = [FTSource(f, skip, size, maxsize, scale=inscale, dtype=indtype) | |
161 for f in datafiles] | 168 for f in datafiles] |
162 self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) | 169 self.outputs = [FTSource(f, skip, size, maxsize, scale=outscale, dtype=outdtype) |
163 for f in labelfiles] | 170 for f in labelfiles] |
164 | 171 |
165 def open_inputs(self): | 172 def open_inputs(self): |
166 return [f.open() for f in self.inputs] | 173 return [f.open() for f in self.inputs] |
167 | 174 |
168 def open_outputs(self): | 175 def open_outputs(self): |
169 return [f.open() for f in self.outputs] | 176 return [f.open() for f in self.outputs] |
170 | 177 |
171 | 178 |
172 class FTDataSet(DataSet): | 179 class FTDataSet(DataSet): |
173 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): | 180 def __init__(self, train_data, train_lbl, test_data, test_lbl, |
181 valid_data=None, valid_lbl=None, indtype=None, outdtype=None, | |
182 inscale=1, outscale=1, maxsize=None): | |
174 r""" | 183 r""" |
175 Defines a DataSet from a bunch of files. | 184 Defines a DataSet from a bunch of files. |
176 | 185 |
177 Parameters: | 186 Parameters: |
178 `train_data` -- list of train data files | 187 `train_data` -- list of train data files |
182 can differ from train. | 191 can differ from train. |
183 `valid_data`, `valid_labels` -- same thing again for validation. | 192 `valid_data`, `valid_labels` -- same thing again for validation. |
184 (optional) | 193 (optional) |
185 `indtype`, `outdtype`, -- see FTSource.__init__() | 194 `indtype`, `outdtype`, -- see FTSource.__init__() |
186 `inscale`, `outscale` (optional) | 195 `inscale`, `outscale` (optional) |
196 `maxsize` -- maximum size of the set returned | |
187 | 197 |
188 | 198 |
189 If `valid_data` and `valid_labels` are not supplied then a sample | 199 If `valid_data` and `valid_labels` are not supplied then a sample |
190 approximately equal in size to the test set is taken from the train | 200 approximately equal in size to the test set is taken from the train |
191 set. | 201 set. |
192 """ | 202 """ |
193 if valid_data is None: | 203 if valid_data is None: |
194 total_valid_size = sum(FTFile(td).size for td in test_data) | 204 total_valid_size = min(sum(FTFile(td).size for td in test_data), maxsize) |
195 valid_size = total_valid_size/len(train_data) | 205 valid_size = total_valid_size/len(train_data) |
196 self._train = FTData(train_data, train_lbl, size=-valid_size, | 206 self._train = FTData(train_data, train_lbl, size=-valid_size, |
197 inscale=inscale, outscale=outscale, indtype=indtype, | 207 inscale=inscale, outscale=outscale, |
198 outdtype=outdtype) | 208 indtype=indtype, outdtype=outdtype, |
209 maxsize=maxsize) | |
199 self._valid = FTData(train_data, train_lbl, skip=-valid_size, | 210 self._valid = FTData(train_data, train_lbl, skip=-valid_size, |
200 inscale=inscale, outscale=outscale, indtype=indtype, | 211 inscale=inscale, outscale=outscale, |
201 outdtype=outdtype) | 212 indtype=indtype, outdtype=outdtype, |
213 maxsize=maxsize) | |
202 else: | 214 else: |
203 self._train = FTData(train_data, train_lbl,inscale=inscale, | 215 self._train = FTData(train_data, train_lbl, maxsize=maxsize, |
204 outscale=outscale, indtype=indtype, outdtype=outdtype) | 216 inscale=inscale, outscale=outscale, |
205 self._valid = FTData(valid_data, valid_lbl,inscale=inscale, | 217 indtype=indtype, outdtype=outdtype) |
206 outscale=outscale, indtype=indtype, outdtype=outdtype) | 218 self._valid = FTData(valid_data, valid_lbl, maxsize=maxsize, |
207 self._test = FTData(test_data, test_lbl,inscale=inscale, | 219 inscale=inscale, outscale=outscale, |
208 outscale=outscale, indtype=indtype, outdtype=outdtype) | 220 indtype=indtype, outdtype=outdtype) |
221 self._test = FTData(test_data, test_lbl, maxsize=maxsize, | |
222 inscale=inscale, outscale=outscale, | |
223 indtype=indtype, outdtype=outdtype) | |
209 | 224 |
210 def _return_it(self, batchsize, bufsize, ftdata): | 225 def _return_it(self, batchsize, bufsize, ftdata): |
211 return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), | 226 return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), |
212 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) | 227 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) |