Mercurial > ift6266
comparison datasets/dsetiter.py @ 163:4b28d7382dbf
Add initial implementation of datasets.
For the moment only nist_digits is defined.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 25 Feb 2010 18:40:01 -0500 |
parents | |
children | 938bd350dbf0 |
comparison
equal
deleted
inserted
replaced
162:050c7ff6b449 | 163:4b28d7382dbf |
---|---|
1 import numpy | |
2 | |
class DummyFile(object):
    r"""
    Fake numpy reader for the doctests.

    Holds `size` examples and serves them as all-zero arrays of shape
    (n, 3, 2) until none are left.
    """

    def __init__(self, size):
        # How many examples can still be served.
        self.size = size

    def read(self, num):
        # Serve at most what is left; an exhausted reader yields 0 rows.
        count = self.size if num > self.size else num
        self.size = self.size - count
        return numpy.zeros((count, 3, 2))
12 | |
class DataIterator(object):
    r"""
    Iterate over the examples of a sequence of numpy readers in batches.

    Examples are pulled from the readers into an internal buffer of
    `bufsize` rows (a multiple of `batchsize`), which is then handed
    out `batchsize` rows at a time.  When the readers run out, the
    final buffer may be partial, and the last batch may be shorter
    than `batchsize`.
    """

    def __init__(self, files, batchsize, bufsize=None):
        r"""
        Makes an iterator which will read examples from `files`
        and return them in `batchsize` lots.

        Parameters:
        files -- list of numpy readers (objects whose ``read(num)`` is
                 expected to return an array of at most `num` rows)
        batchsize -- (int) the size of returned batches
        bufsize -- (int, default=None) internal read buffer size.
                   Rounded down to a multiple of `batchsize`, but never
                   below `batchsize`.  When None,
                   max(10*batchsize, 10000) is used before rounding.

        Raises StopIteration if `files` is empty or contains no data.

        Tests:
        >>> d = DataIterator([DummyFile(930)], 10, 100)
        >>> d.batchsize
        10
        >>> d.bufsize
        100
        >>> d = DataIterator([DummyFile(1)], 10)
        >>> d.batchsize
        10
        >>> d.bufsize
        10000
        >>> d = DataIterator([DummyFile(1)], 99)
        >>> d.batchsize
        99
        >>> d.bufsize
        9999
        >>> d = DataIterator([DummyFile(1)], 10, 121)
        >>> d.batchsize
        10
        >>> d.bufsize
        120
        >>> d = DataIterator([DummyFile(1)], 10, 1)
        >>> d.batchsize
        10
        >>> d.bufsize
        10
        >>> d = DataIterator([DummyFile(1)], 2000)
        >>> d.batchsize
        2000
        >>> d.bufsize
        20000
        >>> d = DataIterator([DummyFile(1)], 2000, 31254)
        >>> d.batchsize
        2000
        >>> d.bufsize
        30000
        >>> d = DataIterator([DummyFile(1)], 2000, 10)
        >>> d.batchsize
        2000
        >>> d.bufsize
        2000
        """
        self.batchsize = batchsize
        if bufsize is None:
            self.bufsize = max(10*batchsize, 10000)
        else:
            self.bufsize = bufsize
        # Keep full buffers an exact multiple of the batch size so they
        # always split into whole batches.
        self.bufsize -= self.bufsize % self.batchsize
        if self.bufsize < self.batchsize:
            self.bufsize = self.batchsize
        self.files = iter(files)
        # Builtin next() works for both Python 2 and 3 iterators
        # (iterator.next() is Python 2 only).
        self.curfile = next(self.files)
        # Set once every reader has been exhausted.
        self.empty = False
        self._fill_buf()

    def _fill_buf(self):
        r"""
        Fill the internal buffer.

        Will fill across files in case the current one runs out.
        Afterwards `self.buffer` holds `bufsize` rows, unless the
        readers were exhausted.  In that case it holds whatever data
        remained, and every later call raises StopIteration.
        `self.curpos` is reset to 0 on success.

        Raises StopIteration when no data at all could be read.

        Test:
        >>> d = DataIterator([DummyFile(20)], 10, 10)
        >>> d._fill_buf()
        >>> d.curpos
        0
        >>> len(d.buffer)
        10
        >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
        >>> d._fill_buf()
        >>> len(d.buffer)
        10
        >>> d._fill_buf()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
        >>> d._fill_buf()
        >>> len(d.buffer)
        9
        >>> d._fill_buf()
        Traceback (most recent call last):
        ...
        StopIteration
        """
        if self.empty:
            raise StopIteration
        self.buffer = self.curfile.read(self.bufsize)

        while len(self.buffer) < self.bufsize:
            try:
                self.curfile = next(self.files)
            except StopIteration:
                # Out of readers: remember it so the next fill stops at
                # once, and keep any partial data gathered so far.
                self.empty = True
                if len(self.buffer) == 0:
                    raise StopIteration
                break
            tmpbuf = self.curfile.read(self.bufsize - len(self.buffer))
            # vstack is what the deprecated row_stack alias pointed to.
            self.buffer = numpy.vstack((self.buffer, tmpbuf))
        self.curpos = 0

    def __next__(self):
        r"""
        Returns the next portion of the dataset.

        Each batch is the next `batchsize` rows of the data.  The final
        batch may be shorter.  Raises StopIteration once all readers are
        exhausted, and keeps raising it on later calls.

        Test:
        >>> d = DataIterator([DummyFile(20)], 10, 20)
        >>> len(d.next())
        10
        >>> len(d.next())
        10
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d = DataIterator([DummyFile(25)], 10, 100)
        >>> [len(b) for b in d]
        [10, 10, 5]

        """
        # Compare against the real buffer length, not bufsize.  The last
        # buffer may be partial, and slicing past its end would otherwise
        # yield empty batches instead of stopping.
        if self.curpos >= len(self.buffer):
            self._fill_buf()
        res = self.buffer[self.curpos:self.curpos+self.batchsize]
        self.curpos += self.batchsize
        return res

    # Python 2 iterator protocol name.
    next = __next__

    def __iter__(self):
        """Return self: a DataIterator is its own iterator."""
        return self