comparison datasets/ftfile.py @ 163:4b28d7382dbf

Add initial implementation of datasets. For the moment only nist_digits is defined.
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 25 Feb 2010 18:40:01 -0500
parents
children 954185d6002a
comparison
equal deleted inserted replaced
162:050c7ff6b449 163:4b28d7382dbf
1 from pylearn.io.filetensor import _read_header, _prod
2 import numpy
3 from dataset import DataSet
4 from dsetiter import DataIterator
5
class FTFile(object):
    r"""
    Sequential reader over the items stored in a single .ft file.

    The header is parsed once at construction time.  Afterwards `size`
    holds the number of items that can still be consumed through `skip`
    and `read`.  Instances can be used as context managers so that the
    underlying file is closed deterministically.

    Attributes:
      `file`    -- the open binary file object
      `magic_t` -- element type from the header, passed to numpy as dtype
      `elsize`  -- size in bytes of one element, used to compute offsets
      `dim`     -- dimensions from the header; `dim[0]` counts the items
      `size`    -- number of items left to skip/read
    """

    def __init__(self, fname):
        r"""
        Opens `fname` and reads its header.

        Parameters:
          `fname` (string) -- path of the .ft file

        Tests:
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        """
        self.file = open(fname, 'rb')
        try:
            self.magic_t, self.elsize, _, self.dim, _ = _read_header(self.file, False)
        except Exception:
            # Do not leak the handle when the header cannot be parsed.
            self.file.close()
            raise
        self.size = self.dim[0]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        # Never swallow an exception raised inside the `with` body.
        return False

    def close(self):
        r"""
        Closes the underlying file.
        """
        self.file.close()

    def _item_elems(self):
        r"""
        Returns the number of scalar elements in one item, i.e. the
        product of every dimension after the first (1 for a 1-d file).
        """
        count = 1
        for extent in self.dim[1:]:
            # int() keeps the arithmetic in arbitrary-precision Python
            # integers in case the header yields fixed-width ones.
            count *= int(extent)
        return count

    def skip(self, num):
        r"""
        Skips `num` items in the file.

        When fewer than `num` items are left, only the remaining ones
        are skipped, so the file position never moves past them.

        Tests:
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        >>> f.size
        58646
        >>> f.elsize
        4
        >>> f.file.tell()
        20
        >>> f.skip(1000)
        >>> f.file.tell()
        4020
        >>> f.size
        57646
        """
        if num >= self.size:
            # Clamp: there is nothing to gain by seeking beyond the
            # items this object is allowed to consume.
            num = self.size
            self.size = 0
        else:
            self.size -= num
        offset = int(self.elsize) * self._item_elems() * int(num)
        # whence=1: relative to the current position.
        self.file.seek(offset, 1)

    def read(self, num):
        r"""
        Reads `num` items from the file and returns the result as a
        numpy array whose first dimension is the number of items
        actually read. Last read is truncated.

        As a side effect `dim[0]` is set to the number of items read.

        Tests:
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_labels.ft')
        >>> f.read(1)
        array([6], dtype=int32)
        >>> f.read(10)
        array([7, 4, 7, 5, 6, 4, 8, 0, 9, 6], dtype=int32)
        >>> f.skip(58630)
        >>> f.read(10)
        array([9, 2, 4, 2, 8], dtype=int32)
        >>> f.read(10)
        array([], dtype=int32)
        >>> f = FTFile('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
        >>> f.read(1)
        array([[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)
        """
        if num > self.size:
            num = self.size
        self.dim[0] = num
        self.size -= num
        shape = tuple(int(extent) for extent in self.dim)
        count = int(num) * self._item_elems()
        return numpy.fromfile(self.file, dtype=self.magic_t, count=count).reshape(shape)
66
class FTSource(object):
    r"""
    Description of a (possibly partial) .ft file that is only opened
    on demand through `open`.
    """

    def __init__(self, file, skip=0, size=None):
        r"""
        Create a data source from a possible subset of a .ft file.

        Parameters:
          `file` (string) -- the filename
          `skip` (int, optional) -- amount of examples to skip from the start of the file
          `size` (int, optional) -- truncates number of examples read (after skipping)

        Tests:
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1000)
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=10)
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=100, size=120)
        """
        self.file, self.skip, self.size = file, skip, size

    def open(self):
        r"""
        Returns an FTFile that corresponds to this dataset.

        The returned file has already been advanced by `skip` examples
        and its `size` is capped at this source's `size`, if one was
        given and is smaller than what is left in the file.

        Tests:
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft')
        >>> f = s.open()
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', size=1)
        >>> len(s.open().read(2))
        1
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646)
        >>> s.open().size
        1000
        >>> s = FTSource('/data/lisa/data/nist/by_class/digits/digits_test_data.ft', skip=57646, size=1)
        >>> s.open().size
        1
        """
        handle = FTFile(self.file)
        if self.skip != 0:
            handle.skip(self.skip)
        limit = self.size
        # No cap requested, or the cap is not smaller than what remains.
        if limit is None or not (limit < handle.size):
            return handle
        handle.size = limit
        return handle
110
class FTData(object):
    r"""
    This is a list of FTSources.

    `inputs` holds one source per data file and `outputs` one source
    per label file; every source shares the same `skip` and `size`.
    """

    def __init__(self, datafiles, labelfiles, skip=0, size=None):
        r"""
        Parameters:
          `datafiles` -- iterable of data file names
          `labelfiles` -- iterable of label file names
          `skip` (int, optional) -- examples to skip at the start of each file
          `size` (int, optional) -- cap on the examples taken from each file
        """
        self.inputs = []
        for fname in datafiles:
            self.inputs.append(FTSource(fname, skip, size))
        self.outputs = []
        for fname in labelfiles:
            self.outputs.append(FTSource(fname, skip, size))

    def open_inputs(self):
        r"""
        Opens every input source and returns the results as a list.
        """
        opened = []
        for source in self.inputs:
            opened.append(source.open())
        return opened

    def open_outputs(self):
        r"""
        Opens every output source and returns the results as a list.
        """
        opened = []
        for source in self.outputs:
            opened.append(source.open())
        return opened
124
125
class FTDataSet(DataSet):
    r"""
    DataSet whose train, validation and test parts are backed by .ft files.
    """

    def __init__(self, train_data, train_lbl, test_data, test_lbl, valid_data=None, valid_lbl=None):
        r"""
        Defines a DataSet from a bunch of files.

        Parameters:
          `train_data` -- list of train data files
          `train_lbl` -- list of train label files (same length as `train_data`)
          `test_data`, `test_lbl` -- same thing as train, but for
                                     test.  The number of files
                                     can differ from train.
          `valid_data`, `valid_lbl` -- same thing again for validation.
                                       (optional)

        If `valid_data` is not supplied then a sample approximately
        equal in size to the test set is taken from the train set: the
        same number of leading examples of every train file goes to
        validation and is skipped for training.  `valid_lbl` is ignored
        in that case.

        Raises:
          ZeroDivisionError -- if `valid_data` is None and `train_data` is empty
        """
        if valid_data is None:
            total_valid_size = 0
            for td in test_data:
                # Only the header is needed here; close the handle right
                # away instead of leaving it to the garbage collector.
                probe = FTFile(td)
                try:
                    total_valid_size += int(probe.size)
                finally:
                    probe.file.close()
            # Floor division: the per-file share is used as an example
            # count and must stay an integer whatever `/` means.
            valid_size = total_valid_size // len(train_data)
            self._train = FTData(train_data, train_lbl, skip=valid_size)
            self._valid = FTData(train_data, train_lbl, size=valid_size)
        else:
            self._train = FTData(train_data, train_lbl)
            self._valid = FTData(valid_data, valid_lbl)
        self._test = FTData(test_data, test_lbl)

    def _return_it(self, batchsize, bufsize, ftdata):
        r"""
        Zips a DataIterator over the opened inputs of `ftdata` with one
        over its opened outputs.

        Parameters:
          `batchsize`, `bufsize` -- forwarded unchanged to both DataIterators
          `ftdata` -- object providing `open_inputs()` and `open_outputs()`
        """
        # NOTE(review): under Python 2 the builtin zip builds the whole
        # list up front, which would consume both iterators eagerly --
        # confirm whether a lazy pairing (itertools.izip) was intended.
        return zip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
                   DataIterator(ftdata.open_outputs(), batchsize, bufsize))