Mercurial > ift6266
comparison datasets/ftfile.py @ 180:76bc047df5ee
Add dtype conversion and rescaling to the read path.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Sat, 27 Feb 2010 16:50:16 -0500 |
parents | 938bd350dbf0 |
children | f0f47b045cbf |
comparison
equal
deleted
inserted
replaced
179:defd388aba0c | 180:76bc047df5ee |
---|---|
3 from dataset import DataSet | 3 from dataset import DataSet |
4 from dsetiter import DataIterator | 4 from dsetiter import DataIterator |
5 from itertools import izip, imap | 5 from itertools import izip, imap |
6 | 6 |
7 class FTFile(object): | 7 class FTFile(object): |
8 def __init__(self, fname): | 8 def __init__(self, fname, scale=1, dtype=None): |
9 r""" | 9 r""" |
10 Tests: | 10 Tests: |
11 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') | 11 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') |
12 """ | 12 """ |
13 self.file = open(fname, 'rb') | 13 self.file = open(fname, 'rb') |
14 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) | 14 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) |
15 self.size = self.dim[0] | 15 self.size = self.dim[0] |
16 self.scale = scale | |
17 self.dtype = dtype | |
16 | 18 |
17 def skip(self, num): | 19 def skip(self, num): |
18 r""" | 20 r""" |
19 Skips `num` items in the file. | 21 Skips `num` items in the file. |
20 | 22 |
77 """ | 79 """ |
78 if num > self.size: | 80 if num > self.size: |
79 num = self.size | 81 num = self.size |
80 self.dim[0] = num | 82 self.dim[0] = num |
81 self.size -= num | 83 self.size -= num |
82 return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) | 84 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) |
85 if self.dtype is not None: | |
86 res = res.astype(self.dtype) | |
87 if self.scale != 1: | |
88 res /= self.scale | |
89 return res | |
83 | 90 |
84 class FTSource(object): | 91 class FTSource(object): |
85 def __init__(self, file, skip=0, size=None): | 92 def __init__(self, file, skip=0, size=None, dtype=None, scale=1): |
86 r""" | 93 r""" |
87 Create a data source from a possible subset of a .ft file. | 94 Create a data source from a possible subset of a .ft file. |
88 | 95 |
89 Parameters: | 96 Parameters: |
90 `file` (string) -- the filename | 97 `file` (string) -- the filename |
94 `size` (int, optional) -- truncates number of examples | 101 `size` (int, optional) -- truncates number of examples |
95 read (after skipping). If | 102 read (after skipping). If |
96 negative truncates to | 103 negative truncates to |
97 filesize - size | 104 filesize - size |
98 (also after skipping). | 105 (also after skipping). |
99 | 106 `dtype` (dtype, optional) -- convert the data to this |
107 dtype after reading. | |
108 `scale` (number, optional) -- scale (that is divide) the | |
109 data by this number (after | |
110 dtype conversion, if any). | |
111 | |
100 Tests: | 112 Tests: |
101 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') | 113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') |
102 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) | 114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) |
103 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) | 115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) |
104 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) | 116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) |
105 """ | 117 """ |
106 self.file = file | 118 self.file = file |
107 self.skip = skip | 119 self.skip = skip |
108 self.size = size | 120 self.size = size |
121 self.dtype = dtype | |
122 self.scale = scale | |
109 | 123 |
110 def open(self): | 124 def open(self): |
111 r""" | 125 r""" |
112 Returns an FTFile that corresponds to this dataset. | 126 Returns an FTFile that corresponds to this dataset. |
113 | 127 |
125 1 | 139 1 |
126 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) | 140 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) |
127 >>> s.open().size | 141 >>> s.open().size |
128 58636 | 142 58636 |
129 """ | 143 """ |
130 f = FTFile(self.file) | 144 f = FTFile(self.file, scale=self.scale, dtype=self.dtype) |
131 if self.skip != 0: | 145 if self.skip != 0: |
132 f.skip(self.skip) | 146 f.skip(self.skip) |
133 if self.size is not None and self.size < f.size: | 147 if self.size is not None and self.size < f.size: |
134 if self.size < 0: | 148 if self.size < 0: |
135 f.size += self.size | 149 f.size += self.size |
139 | 153 |
140 class FTData(object): | 154 class FTData(object): |
141 r""" | 155 r""" |
142 This is a list of FTSources. | 156 This is a list of FTSources. |
143 """ | 157 """ |
144 def __init__(self, datafiles, labelfiles, skip=0, size=None): | 158 def __init__(self, datafiles, labelfiles, skip=0, size=None, |
145 self.inputs = [FTSource(f, skip, size) for f in datafiles] | 159 inscale=1, indtype=None, outscale=1, outdtype=None): |
146 self.outputs = [FTSource(f, skip, size) for f in labelfiles] | 160 self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) |
161 for f in datafiles] | |
162 self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) | |
163 for f in labelfiles] | |
147 | 164 |
148 def open_inputs(self): | 165 def open_inputs(self): |
149 return [f.open() for f in self.inputs] | 166 return [f.open() for f in self.inputs] |
150 | 167 |
151 def open_outputs(self): | 168 def open_outputs(self): |
152 return [f.open() for f in self.outputs] | 169 return [f.open() for f in self.outputs] |
153 | 170 |
154 | 171 |
155 class FTDataSet(DataSet): | 172 class FTDataSet(DataSet): |
156 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None): | 173 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): |
157 r""" | 174 r""" |
158 Defines a DataSet from a bunch of files. | 175 Defines a DataSet from a bunch of files. |
159 | 176 |
160 Parameters: | 177 Parameters: |
161 `train_data` -- list of train data files | 178 `train_data` -- list of train data files |
163 `test_data`, `test_labels` -- same thing as train, but for | 180 `test_data`, `test_labels` -- same thing as train, but for |
164 test. The number of files | 181 test. The number of files |
165 can differ from train. | 182 can differ from train. |
166 `valid_data`, `valid_labels` -- same thing again for validation. | 183 `valid_data`, `valid_labels` -- same thing again for validation. |
167 (optional) | 184 (optional) |
185 `indtype`, `outdtype`, -- see FTSource.__init__() | |
186 `inscale`, `outscale` (optional) | |
187 | |
168 | 188 |
169 If `valid_data` and `valid_labels` are not supplied then a sample | 189 If `valid_data` and `valid_labels` are not supplied then a sample |
170 approximately equal in size to the test set is taken from the train | 190 approximately equal in size to the test set is taken from the train |
171 set. | 191 set. |
172 """ | 192 """ |