Mercurial > ift6266
diff datasets/dsetiter.py @ 163:4b28d7382dbf
Add initial implementation of datasets.
For the moment only nist_digits is defined.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Thu, 25 Feb 2010 18:40:01 -0500 |
parents | |
children | 938bd350dbf0 |
line wrap: on
line diff
import numpy

class DummyFile(object):
    r"""
    Fake file-like reader used for the doctests below.

    Hands out zero-filled arrays of shape (num, 3, 2) until a total of
    `size` examples has been read, after which reads return empty
    arrays (mimicking a numpy reader that has hit end-of-file).
    """
    def __init__(self, size):
        # Number of examples still available for reading.
        self.size = size

    def read(self, num):
        # Never hand out more examples than remain; at EOF this
        # returns a length-0 array.
        if num > self.size:
            num = self.size
        self.size -= num
        return numpy.zeros((num, 3, 2))

class DataIterator(object):

    def __init__(self, files, batchsize, bufsize=None):
        r"""
        Makes an iterator which will read examples from `files`
        and return them in `batchsize` lots.

        Parameters:
        files -- list of numpy readers
        batchsize -- (int) the size of returned batches
        bufsize -- (int, default=None) internal read buffer size.
                   Rounded down to a multiple of `batchsize` (but
                   never below `batchsize`).  When None, defaults
                   to max(10*batchsize, 10000).

        Raises StopIteration from the constructor if `files` is empty.

        Tests:
        >>> d = DataIterator([DummyFile(930)], 10, 100)
        >>> d.batchsize, d.bufsize
        (10, 100)
        >>> DataIterator([DummyFile(1)], 10).bufsize
        10000
        >>> DataIterator([DummyFile(1)], 99).bufsize
        9999
        >>> DataIterator([DummyFile(1)], 10, 121).bufsize
        120
        >>> DataIterator([DummyFile(1)], 10, 1).bufsize
        10
        >>> DataIterator([DummyFile(1)], 2000).bufsize
        20000
        >>> DataIterator([DummyFile(1)], 2000, 31254).bufsize
        30000
        >>> DataIterator([DummyFile(1)], 2000, 10).bufsize
        2000
        """
        self.batchsize = batchsize
        if bufsize is None:
            self.bufsize = max(10*batchsize, 10000)
        else:
            self.bufsize = bufsize
        # Keep the buffer an exact multiple of the batch size so that
        # a full buffer always splits evenly into batches; never let it
        # drop below a single batch.
        self.bufsize -= self.bufsize % self.batchsize
        if self.bufsize < self.batchsize:
            self.bufsize = self.batchsize
        self.files = iter(files)
        # next(it) rather than it.next(): equivalent on python >= 2.6
        # and also valid on python 3.
        self.curfile = next(self.files)
        self.empty = False
        self._fill_buf()

    def _fill_buf(self):
        r"""
        Fill the internal buffer.

        Will fill across files in case the current one runs out.
        Sets `self.curpos` to 0 on success.  Once every file is
        exhausted, sets `self.empty` and raises StopIteration (also
        raised immediately when called again after that point).

        Test:
        >>> d = DataIterator([DummyFile(20)], 10, 10)
        >>> d._fill_buf()
        >>> d.curpos
        0
        >>> len(d.buffer)
        10
        >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
        >>> d._fill_buf()
        >>> len(d.buffer)
        10
        >>> d._fill_buf()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
        >>> d._fill_buf()
        >>> len(d.buffer)
        9
        >>> d._fill_buf()
        Traceback (most recent call last):
        ...
        StopIteration
        """
        if self.empty:
            raise StopIteration
        self.buffer = self.curfile.read(self.bufsize)

        # Top up from the remaining files until the buffer is full or
        # everything has been read.
        while len(self.buffer) < self.bufsize:
            try:
                self.curfile = next(self.files)
            except StopIteration:
                self.empty = True
                if len(self.buffer) == 0:
                    raise StopIteration
                # Partial final buffer: serve what we have.
                self.curpos = 0
                return
            tmpbuf = self.curfile.read(self.bufsize - len(self.buffer))
            self.buffer = numpy.row_stack((self.buffer, tmpbuf))
        self.curpos = 0

    def __next__(self):
        r"""
        Returns the next portion of the dataset.

        The last batch may be smaller than `batchsize` when the data
        does not divide evenly.  Raises StopIteration when the data is
        exhausted.

        Test:
        >>> d = DataIterator([DummyFile(20)], 10, 20)
        >>> len(d.next())
        10
        >>> len(d.next())
        10
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration
        >>> d = DataIterator([DummyFile(10)], 10, 20)
        >>> len(d.next())
        10
        >>> d.next()
        Traceback (most recent call last):
        ...
        StopIteration
        """
        # Compare against the *actual* buffer length, not the nominal
        # bufsize: the final buffer may be partial, and using bufsize
        # here would yield a spurious empty batch before stopping.
        if self.curpos >= len(self.buffer):
            self._fill_buf()
        res = self.buffer[self.curpos:self.curpos+self.batchsize]
        self.curpos += self.batchsize
        return res

    # python 2 iterator-protocol alias
    next = __next__

    def __iter__(self):
        return self