comparison datasets/dsetiter.py @ 163:4b28d7382dbf

Add initial implementation of datasets. For the moment only nist_digits is defined.
author Arnaud Bergeron <abergeron@gmail.com>
date Thu, 25 Feb 2010 18:40:01 -0500
parents
children 938bd350dbf0
comparison
equal deleted inserted replaced
162:050c7ff6b449 163:4b28d7382dbf
1 import numpy
2
class DummyFile(object):
    r"""
    Fake dataset reader used for testing.

    Exposes the minimal reader interface expected by `DataIterator`:
    a `read(num)` method that returns at most `num` zero-filled
    examples of shape (n, 3, 2), draining an internal counter until
    the file is exhausted (after which it returns empty arrays).
    """

    def __init__(self, size):
        # Number of examples still available for reading.
        self.size = size

    def read(self, num):
        r"""
        Return up to `num` zero examples, reducing the remaining size.
        """
        count = min(num, self.size)
        self.size -= count
        return numpy.zeros((count, 3, 2))
12
class DataIterator(object):
    r"""
    Iterator over a sequence of numpy readers that yields examples in
    fixed-size batches, buffering reads internally and spilling across
    file boundaries.
    """

    def __init__(self, files, batchsize, bufsize=None):
        r"""
        Makes an iterator which will read examples from `files`
        and return them in `batchsize` lots.

        Parameters:
        files -- list of numpy readers
        batchsize -- (int) the size of returned batches
        bufsize -- (int, default=None) internal read buffer size.

        Tests:
        >>> d = DataIterator([DummyFile(930)], 10, 100)
        >>> d.batchsize
        10
        >>> d.bufsize
        100
        >>> d = DataIterator([DummyFile(1)], 10)
        >>> d.batchsize
        10
        >>> d.bufsize
        10000
        >>> d = DataIterator([DummyFile(1)], 99)
        >>> d.batchsize
        99
        >>> d.bufsize
        9999
        >>> d = DataIterator([DummyFile(1)], 10, 121)
        >>> d.batchsize
        10
        >>> d.bufsize
        120
        >>> d = DataIterator([DummyFile(1)], 10, 1)
        >>> d.batchsize
        10
        >>> d.bufsize
        10
        >>> d = DataIterator([DummyFile(1)], 2000)
        >>> d.batchsize
        2000
        >>> d.bufsize
        20000
        >>> d = DataIterator([DummyFile(1)], 2000, 31254)
        >>> d.batchsize
        2000
        >>> d.bufsize
        30000
        >>> d = DataIterator([DummyFile(1)], 2000, 10)
        >>> d.batchsize
        2000
        >>> d.bufsize
        2000
        """
        self.batchsize = batchsize
        if bufsize is None:
            # Default buffer: ten batches, but never less than 10000 examples.
            self.bufsize = max(10*batchsize, 10000)
        else:
            self.bufsize = bufsize
        # Round the buffer down to a whole number of batches, but keep
        # room for at least one batch.
        self.bufsize -= self.bufsize % self.batchsize
        if self.bufsize < self.batchsize:
            self.bufsize = self.batchsize
        self.files = iter(files)
        # `next(...)` builtin instead of the Python-2-only `.next()`
        # method, matching the `__next__`/`next` alias below.
        self.curfile = next(self.files)
        self.empty = False
        self._fill_buf()

    def _fill_buf(self):
        r"""
        Fill the internal buffer.

        Will fill across files in case the current one runs out.

        Raises StopIteration when all files are exhausted and no data
        remains.

        Test:
        >>> d = DataIterator([DummyFile(20)], 10, 10)
        >>> d._fill_buf()
        >>> d.curpos
        0
        >>> len(d.buffer)
        10
        >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
        >>> d._fill_buf()
        >>> len(d.buffer)
        10
        >>> d._fill_buf()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
        >>> d._fill_buf()
        >>> len(d.buffer)
        9
        >>> d._fill_buf()
        Traceback (most recent call last):
        ...
        StopIteration
        """
        if self.empty:
            raise StopIteration
        self.buffer = self.curfile.read(self.bufsize)

        while len(self.buffer) < self.bufsize:
            try:
                # Current file ran short; advance to the next reader.
                self.curfile = next(self.files)
            except StopIteration:
                # No more files: mark exhausted.  A partial buffer is
                # still served; an empty one means iteration is over.
                self.empty = True
                if len(self.buffer) == 0:
                    raise StopIteration
                self.curpos = 0
                return
            tmpbuf = self.curfile.read(self.bufsize - len(self.buffer))
            # vstack replaces row_stack, which was removed in NumPy 2.0
            # (they are documented equivalents).
            self.buffer = numpy.vstack((self.buffer, tmpbuf))
        self.curpos = 0

    def __next__(self):
        r"""
        Returns the next portion of the dataset.

        Test:
        >>> d = DataIterator([DummyFile(20)], 10, 20)
        >>> len(d.next())
        10
        >>> len(d.next())
        10
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration

        """
        # Compare against the actual buffer length, not `bufsize`:
        # the final buffer may be shorter than `bufsize`, and checking
        # against `bufsize` used to yield empty trailing batches
        # before StopIteration was raised.
        if self.curpos >= len(self.buffer):
            self._fill_buf()
        res = self.buffer[self.curpos:self.curpos+self.batchsize]
        self.curpos += self.batchsize
        return res

    next = __next__

    def __iter__(self):
        return self