Mercurial > ift6266
comparison datasets/ftfile.py @ 163:4b28d7382dbf
Add initial implementation of datasets.
For the moment only nist_digits is defined.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 25 Feb 2010 18:40:01 -0500 |
parents | |
children | 954185d6002a |
comparison
equal
deleted
inserted
replaced
162:050c7ff6b449 | 163:4b28d7382dbf |
---|---|
1 from pylearn.io.filetensor import _read_header, _prod | |
2 import numpy | |
3 from dataset import DataSet | |
4 from dsetiter import DataIterator | |
5 | |
class FTFile(object):
    r"""
    Sequential reader over one .ft file.

    The header is parsed once at construction time.  Afterwards `skip`
    and `read` move forward through the examples stored along the first
    dimension, and `size` holds the number of examples still available.

    The object can be used as a context manager; leaving the `with`
    block closes the underlying file.
    """

    def __init__(self, fname):
        r"""
        Opens `fname` in binary mode and reads its header.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        """
        self.file = open(fname, 'rb')
        try:
            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
        except Exception:
            # Do not leak the handle when the header cannot be parsed.
            self.file.close()
            raise
        # Keep the counter as a plain int: `dim` may hold fixed-width numpy
        # integers (TODO confirm against _read_header), and arithmetic on
        # those can wrap around.
        self.size = int(self.dim[0])

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def close(self):
        r"""
        Closes the underlying file.  Safe to call more than once.
        """
        self.file.close()

    def _example_len(self):
        r"""
        Number of elements in one example (product of all dimensions but
        the first; 1 for one-dimensional files).
        """
        length = 1
        for d in self.dim[1:]:
            # int() so the product is computed with Python integers.
            length *= int(d)
        return length

    def skip(self, num):
        r"""
        Skips `num` items in the file.

        If fewer than `num` items remain, only the remaining ones are
        skipped and `size` becomes 0.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.size
            58646
            >>> f.elsize
            4
            >>> f.file.tell()
            20
            >>> f.skip(1000)
            >>> f.file.tell()
            4020
            >>> f.size
            57646
        """
        if num >= self.size:
            # Never move past the logical end of the data: `size` may have
            # been lowered by a caller to expose only part of the file.
            num = self.size
        self.size -= num
        f_start = self.file.tell()
        self.file.seek(f_start + int(self.elsize) * self._example_len() * int(num))

    def read(self, num):
        r"""
        Reads `num` elements from the file and return the result as a
        numpy matrix.  Last read is truncated.

        As a side effect `self.dim[0]` is set to the number of examples
        actually returned.

        Tests:
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
            >>> f.read(1)
            array([6], dtype=int32)
            >>> f.read(10)
            array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32)
            >>> f.skip(58630)
            >>> f.read(10)
            array([9, 2, 4, 2, 8], dtype=int32)
            >>> f.read(10)
            array([], dtype=int32)
            >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> f.read(1)
            array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
        """
        if num > self.size:
            num = self.size
        # `dim` doubles as the shape of the array returned below.
        self.dim[0] = num
        self.size -= num
        count = int(num) * self._example_len()
        return numpy.fromfile(self.file, dtype=self.magic_t, count=count).reshape(self.dim)
66 | |
class FTSource(object):
    r"""
    Description of a (possibly partial) view on one .ft file.

    Nothing is opened until `open` is called.
    """

    def __init__(self, file, skip=0, size=None):
        r"""
        Create a data source from a possible subset of a .ft file.

        Parameters:
            `file` (string) -- the filename
            `skip` (int, optional) -- amount of examples to skip from the start of the file
            `size` (int, optional) -- truncates number of examples read (after skipping)

        Tests:
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
        """
        self.file, self.skip, self.size = file, skip, size

    def open(self):
        r"""
        Returns an FTFile that corresponds to this dataset.

        The file is positioned after the skipped examples and its `size`
        is capped to the requested number of examples, if any.

        Tests:
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
            >>> f = s.open()
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
            >>> len(s.open().read(2))
            1
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
            >>> s.open().size
            1000
            >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
            >>> s.open().size
            1
        """
        handle = FTFile(self.file)
        if self.skip != 0:
            handle.skip(self.skip)
        wanted = self.size
        if wanted is None:
            # No truncation requested: expose everything that is left.
            return handle
        if wanted < handle.size:
            handle.size = wanted
        return handle
110 | |
class FTData(object):
    r"""
    This is a list of FTSources.

    `inputs` holds one source per data file and `outputs` one source per
    label file; every source shares the same `skip` and `size`.
    """

    def __init__(self, datafiles, labelfiles, skip=0, size=None):
        data_sources = []
        for name in datafiles:
            data_sources.append(FTSource(name, skip, size))
        label_sources = []
        for name in labelfiles:
            label_sources.append(FTSource(name, skip, size))
        self.inputs = data_sources
        self.outputs = label_sources

    @staticmethod
    def _open_all(sources):
        r"""
        Opens every source in order and returns the results as a list.
        """
        opened = []
        for src in sources:
            opened.append(src.open())
        return opened

    def open_inputs(self):
        r"""
        Returns a list with one freshly opened file per data source.
        """
        return self._open_all(self.inputs)

    def open_outputs(self):
        r"""
        Returns a list with one freshly opened file per label source.
        """
        return self._open_all(self.outputs)
124 | |
125 | |
class FTDataSet(DataSet):
    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None):
        r"""
        Defines a DataSet from a bunch of files.

        Parameters:
            `train_data` -- list of train data files
            `train_lbl` -- list of train label files (same length as `train_data`)
            `test_data`, `test_lbl` -- same thing as train, but for
                                       test.  The number of files
                                       can differ from train.
            `valid_data`, `valid_lbl` -- same thing again for validation.
                                         (optional)

        If `valid_data` is not supplied then a sample approximately equal
        in size to the test set is taken from the train set: the first
        `valid_size` examples of every train file become the validation
        set and the train set starts right after them.
        """
        if valid_data is None:
            total_test_size = sum(self._count_examples(td) for td in test_data)
            # Floor division: the result is used as an example count, so
            # it has to stay an integer under true division as well.
            valid_size = total_test_size // len(train_data)
            self._train = FTData(train_data, train_lbl, skip=valid_size)
            self._valid = FTData(train_data, train_lbl, size=valid_size)
        else:
            self._train = FTData(train_data, train_lbl)
            self._valid = FTData(valid_data, valid_lbl)
        self._test = FTData(test_data, test_lbl)

    @staticmethod
    def _count_examples(fname):
        r"""
        Returns the number of examples stored in the .ft file `fname`.
        """
        ft = FTFile(fname)
        try:
            return int(ft.size)
        finally:
            # The file was opened only to read its header; release it now
            # instead of waiting for garbage collection.
            ft.file.close()

    def _return_it(self, batchsize, bufsize, ftdata):
        r"""
        Pairs up a DataIterator over the inputs of `ftdata` with one over
        its outputs, both built with the same `batchsize` and `bufsize`.
        """
        # NOTE(review): the builtin zip is lazy on Python 3 but builds the
        # whole list on Python 2 -- confirm which one callers expect.
        return zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
                   DataIterator(ftdata.open_outputs(), batchsize, bufsize))