Mercurial > ift6266
comparison datasets/ftfile.py @ 187:c03692aa6158
Merge
author | fsavard |
---|---|
date | Mon, 01 Mar 2010 11:45:59 -0500 |
parents | f0f47b045cbf |
children | 1faae5079522 |
comparison
equal
deleted
inserted
replaced
186:d364a130b221 | 187:c03692aa6158 |
---|---|
1 from pylearn.io.filetensor import _read_header, _prod | 1 from pylearn.io.filetensor import _read_header, _prod |
2 import numpy | 2 import numpy, theano |
3 from dataset import DataSet | 3 from dataset import DataSet |
4 from dsetiter import DataIterator | 4 from dsetiter import DataIterator |
5 from itertools import izip, imap | |
5 | 6 |
6 class FTFile(object): | 7 class FTFile(object): |
7 def __init__(self, fname): | 8 def __init__(self, fname, scale=1, dtype=None): |
8 r""" | 9 r""" |
9 Tests: | 10 Tests: |
10 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') | 11 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') |
11 """ | 12 """ |
12 self.file = open(fname, 'rb') | 13 self.file = open(fname, 'rb') |
13 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) | 14 self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False) |
14 self.size = self.dim[0] | 15 self.size = self.dim[0] |
16 self.scale = scale | |
17 self.dtype = dtype | |
15 | 18 |
16 def skip(self, num): | 19 def skip(self, num): |
17 r""" | 20 r""" |
18 Skips `num` items in the file. | 21 Skips `num` items in the file. |
22 | |
23 If `num` is negative, skips size + num items (i.e. all but the last -num). | |
19 | 24 |
20 Tests: | 25 Tests: |
21 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') | 26 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') |
22 >>> f.size | 27 >>> f.size |
23 58646 | 28 58646 |
28 >>> f.skip(1000) | 33 >>> f.skip(1000) |
29 >>> f.file.tell() | 34 >>> f.file.tell() |
30 4020 | 35 4020 |
31 >>> f.size | 36 >>> f.size |
32 57646 | 37 57646 |
33 """ | 38 >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft') |
39 >>> f.size | |
40 58646 | |
41 >>> f.file.tell() | |
42 20 | |
43 >>> f.skip(-1000) | |
44 >>> f.file.tell() | |
45 230604 | |
46 >>> f.size | |
47 1000 | |
48 """ | |
49 if num < 0: | |
50 num += self.size | |
51 if num < 0: | |
52 raise ValueError('Skipping past the start of the file') | |
34 if num >= self.size: | 53 if num >= self.size: |
35 self.size = 0 | 54 self.size = 0 |
36 else: | 55 else: |
37 self.size -= num | 56 self.size -= num |
38 f_start = self.file.tell() | 57 f_start = self.file.tell() |
60 """ | 79 """ |
61 if num > self.size: | 80 if num > self.size: |
62 num = self.size | 81 num = self.size |
63 self.dim[0] = num | 82 self.dim[0] = num |
64 self.size -= num | 83 self.size -= num |
65 return numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) | 84 res = numpy.fromfile(self.file, dtype=self.magic_t, count=_prod(self.dim)).reshape(self.dim) |
85 if self.dtype is not None: | |
86 res = res.astype(self.dtype) | |
87 if self.scale != 1: | |
88 res /= self.scale | |
89 return res | |
66 | 90 |
67 class FTSource(object): | 91 class FTSource(object): |
68 def __init__(self, file, skip=0, size=None): | 92 def __init__(self, file, skip=0, size=None, dtype=None, scale=1): |
69 r""" | 93 r""" |
70 Create a data source from a possible subset of a .ft file. | 94 Create a data source from a possible subset of a .ft file. |
71 | 95 |
72 Parameters: | 96 Parameters: |
73 `file` (string) -- the filename | 97 `file` (string) -- the filename |
74 `skip` (int, optional) -- amount of examples to skip from the start of the file | 98 `skip` (int, optional) -- amount of examples to skip from |
75 `size` (int, optional) -- truncates number of examples read (after skipping) | 99 the start of the file. If |
76 | 100 negative, skips filesize + skip (all but the last -skip). |
101 `size` (int, optional) -- truncates number of examples | |
102 read (after skipping). If | |
103 negative, truncates to | |
104 filesize + size | |
105 (also after skipping). | |
106 `dtype` (dtype, optional) -- convert the data to this | |
107 dtype after reading. | |
108 `scale` (number, optional) -- scale (that is divide) the | |
109 data by this number (after | |
110 dtype conversion, if any). | |
111 | |
77 Tests: | 112 Tests: |
78 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') | 113 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft') |
79 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) | 114 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000) |
80 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) | 115 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10) |
81 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) | 116 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120) |
82 """ | 117 """ |
83 self.file = file | 118 self.file = file |
84 self.skip = skip | 119 self.skip = skip |
85 self.size = size | 120 self.size = size |
121 self.dtype = dtype | |
122 self.scale = scale | |
86 | 123 |
87 def open(self): | 124 def open(self): |
88 r""" | 125 r""" |
89 Returns an FTFile that corresponds to this dataset. | 126 Returns an FTFile that corresponds to this dataset. |
90 | 127 |
98 >>> s.open().size | 135 >>> s.open().size |
99 1000 | 136 1000 |
100 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) | 137 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1) |
101 >>> s.open().size | 138 >>> s.open().size |
102 1 | 139 1 |
103 """ | 140 >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=-10) |
104 f = FTFile(self.file) | 141 >>> s.open().size |
142 58636 | |
143 """ | |
144 f = FTFile(self.file, scale=self.scale, dtype=self.dtype) | |
105 if self.skip != 0: | 145 if self.skip != 0: |
106 f.skip(self.skip) | 146 f.skip(self.skip) |
107 if self.size is not None and self.size < f.size: | 147 if self.size is not None and self.size < f.size: |
108 f.size = self.size | 148 if self.size < 0: |
149 f.size += self.size | |
150 else: | |
151 f.size = self.size | |
109 return f | 152 return f |
110 | 153 |
111 class FTData(object): | 154 class FTData(object): |
112 r""" | 155 r""" |
113 This is a list of FTSources. | 156 This is a list of FTSources. |
114 """ | 157 """ |
115 def __init__(self, datafiles, labelfiles, skip=0, size=None): | 158 def __init__(self, datafiles, labelfiles, skip=0, size=None, |
116 self.inputs = [FTSource(f, skip, size) for f in datafiles] | 159 inscale=1, indtype=None, outscale=1, outdtype=None): |
117 self.outputs = [FTSource(f, skip, size) for f in labelfiles] | 160 self.inputs = [FTSource(f, skip, size, scale=inscale, dtype=indtype) |
161 for f in datafiles] | |
162 self.outputs = [FTSource(f, skip, size, scale=outscale, dtype=outdtype) | |
163 for f in labelfiles] | |
118 | 164 |
119 def open_inputs(self): | 165 def open_inputs(self): |
120 return [f.open() for f in self.inputs] | 166 return [f.open() for f in self.inputs] |
121 | 167 |
122 def open_outputs(self): | 168 def open_outputs(self): |
123 return [f.open() for f in self.outputs] | 169 return [f.open() for f in self.outputs] |
124 | 170 |
125 | 171 |
126 class FTDataSet(DataSet): | 172 class FTDataSet(DataSet): |
127 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None): | 173 def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None, indtype=None, outdtype=None, inscale=1, outscale=1): |
128 r""" | 174 r""" |
129 Defines a DataSet from a bunch of files. | 175 Defines a DataSet from a bunch of files. |
130 | 176 |
131 Parameters: | 177 Parameters: |
132 `train_data` -- list of train data files | 178 `train_data` -- list of train data files |
134 `test_data`, `test_labels` -- same thing as train, but for | 180 `test_data`, `test_labels` -- same thing as train, but for |
135 test. The number of files | 181 test. The number of files |
136 can differ from train. | 182 can differ from train. |
137 `valid_data`, `valid_labels` -- same thing again for validation. | 183 `valid_data`, `valid_labels` -- same thing again for validation. |
138 (optional) | 184 (optional) |
185 `indtype`, `outdtype`, -- see FTSource.__init__() | |
186 `inscale`, `outscale` (optional) | |
187 | |
139 | 188 |
140 If `valid_data` and `valid_labels` are not supplied then a sample | 189 If `valid_data` and `valid_labels` are not supplied then a sample |
141 approximately equal in size to the test set is taken from the train | 190 approximately equal in size to the test set is taken from the train |
142 set. | 191 set. |
143 """ | 192 """ |
144 if valid_data is None: | 193 if valid_data is None: |
145 total_valid_size = sum(FTFile(td).size for td in test_data) | 194 total_valid_size = sum(FTFile(td).size for td in test_data) |
146 valid_size = total_valid_size/len(train_data) | 195 valid_size = total_valid_size/len(train_data) |
147 self._train = FTData(train_data, train_lbl, skip=valid_size) | 196 self._train = FTData(train_data, train_lbl, size=-valid_size) |
148 self._valid = FTData(train_data, train_lbl, size=valid_size) | 197 self._valid = FTData(train_data, train_lbl, skip=-valid_size) |
149 else: | 198 else: |
150 self._train = FTData(train_data, train_lbl) | 199 self._train = FTData(train_data, train_lbl) |
151 self._valid = FTData(valid_data, valid_lbl) | 200 self._valid = FTData(valid_data, valid_lbl) |
152 self._test = FTData(test_data, test_lbl) | 201 self._test = FTData(test_data, test_lbl) |
153 | 202 |
154 def _return_it(self, batchsize, bufsize, ftdata): | 203 def _return_it(self, batchsize, bufsize, ftdata): |
155 return zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), | 204 return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), |
156 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) | 205 DataIterator(ftdata.open_outputs(), batchsize, bufsize)) |