Mercurial > ift6266
view datasets/dsetiter.py @ 178:938bd350dbf0
Make the datasets iterators return theano shared slices with the appropriate types.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Sat, 27 Feb 2010 15:09:02 -0500 |
parents | 4b28d7382dbf |
children | defd388aba0c |
line wrap: on
line source
import numpy
import theano


class DummyFile(object):
    r"""
    In-memory stand-in for a dataset reader, used by the doctests below.

    It pretends to hold `size` examples and hands them out as all-zero
    arrays of shape (num, 3, 2).
    """

    def __init__(self, size):
        # Number of examples that are still available for reading.
        self.size = size

    def read(self, num):
        r"""
        Return up to `num` examples and consume them.

        Fewer than `num` rows (possibly zero) are returned once the
        remaining `size` is smaller than the request.
        """
        if num > self.size:
            num = self.size
        self.size -= num
        return numpy.zeros((num, 3, 2))


class DataIterator(object):
    r"""
    Iterator yielding fixed-size batches read from a sequence of readers.

    Examples are pulled from the readers into an internal buffer (stored
    as a theano shared variable) and handed out as slices of that buffer.
    """

    def __init__(self, files, batchsize, bufsize=None):
        r"""
        Makes an iterator which will read examples from `files`
        and return them in `batchsize` lots.

        Parameters:
            files -- list of numpy readers (objects whose `read(n)`
                     returns an array with at most `n` rows)
            batchsize -- (int) the size of returned batches
            bufsize -- (int, default=None) internal read buffer size.
                       When None, max(10*batchsize, 10000) is used.
                       The value is rounded down to a multiple of
                       `batchsize`, but never below `batchsize`.

        The first buffer is filled immediately, so StopIteration is
        raised from the constructor if `files` is empty or holds no
        examples at all.

        Tests:
            >>> d = DataIterator([DummyFile(930)], 10, 100)
            >>> d.batchsize
            10
            >>> d.bufsize
            100
            >>> d = DataIterator([DummyFile(1)], 10)
            >>> d.batchsize
            10
            >>> d.bufsize
            10000
            >>> d = DataIterator([DummyFile(1)], 99)
            >>> d.batchsize
            99
            >>> d.bufsize
            9999
            >>> d = DataIterator([DummyFile(1)], 10, 121)
            >>> d.batchsize
            10
            >>> d.bufsize
            120
            >>> d = DataIterator([DummyFile(1)], 10, 1)
            >>> d.batchsize
            10
            >>> d.bufsize
            10
            >>> d = DataIterator([DummyFile(1)], 2000)
            >>> d.batchsize
            2000
            >>> d.bufsize
            20000
            >>> d = DataIterator([DummyFile(1)], 2000, 31254)
            >>> d.batchsize
            2000
            >>> d.bufsize
            30000
            >>> d = DataIterator([DummyFile(1)], 2000, 10)
            >>> d.batchsize
            2000
            >>> d.bufsize
            2000
        """
        self.batchsize = batchsize
        if bufsize is None:
            self.bufsize = max(10*batchsize, 10000)
        else:
            self.bufsize = bufsize
        # Keep the buffer a whole number of batches so that a batch never
        # straddles two buffers.
        self.bufsize -= self.bufsize % self.batchsize
        if self.bufsize < self.batchsize:
            self.bufsize = self.batchsize
        self.files = iter(files)
        # Built-in next() works on both Python 2.6+ and Python 3 iterators.
        self.curfile = next(self.files)
        self.empty = False
        self._fill_buf()

    def _fill_buf(self):
        r"""
        Fill the internal buffer.

        Will fill across files in case the current one runs out.
        Sets `buffer` (theano shared variable), `buflen` (number of
        examples actually buffered) and resets `curpos` to 0.

        Raises StopIteration when no example is left to read.

        Test:
            >>> d = DataIterator([DummyFile(20)], 10, 10)
            >>> d._fill_buf()
            >>> d.curpos
            0
            >>> len(d.buffer.value)
            10
            >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
            >>> d._fill_buf()
            >>> len(d.buffer.value)
            10
            >>> d._fill_buf()
            Traceback (most recent call last):
               ...
            StopIteration
            >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
            >>> d._fill_buf()
            >>> len(d.buffer.value)
            9
            >>> d._fill_buf()
            Traceback (most recent call last):
               ...
            StopIteration
        """
        self.buffer = None
        # No data is buffered until the fill below succeeds; keeping this
        # at 0 makes every later __next__ call come back here and raise.
        self.buflen = 0
        if self.empty:
            raise StopIteration
        buf = self.curfile.read(self.bufsize)

        while len(buf) < self.bufsize:
            try:
                self.curfile = next(self.files)
            except StopIteration:
                self.empty = True
                if len(buf) == 0:
                    # Nothing at all was read: propagate the end of data.
                    raise
                # Keep the partial buffer; the next fill will stop.
                break
            tmpbuf = self.curfile.read(self.bufsize - len(buf))
            # vstack is the non-deprecated name of row_stack.
            buf = numpy.vstack((buf, tmpbuf))

        self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX))
        self.buflen = len(buf)
        self.curpos = 0

    def __next__(self):
        r"""
        Returns the next portion of the dataset.

        The result is a slice of the shared buffer covering at most
        `batchsize` examples; the very last batch may be shorter.
        Raises StopIteration once all the examples have been returned.

        Test:
            >>> d = DataIterator([DummyFile(20)], 10, 20)
            >>> d.next()
            Subtensor{0:10:}.0
            >>> d.next()
            Subtensor{10:20:}.0
            >>> d.next()
            Traceback (most recent call last):
               ...
            StopIteration
            >>> d.next()
            Traceback (most recent call last):
               ...
            StopIteration
            >>> d = DataIterator([DummyFile(5)], 10, 20)
            >>> batch = d.next()
            >>> d.next()
            Traceback (most recent call last):
               ...
            StopIteration
        """
        # Compare against the number of examples really buffered, not the
        # nominal bufsize: a short final buffer would otherwise produce
        # empty slices before the iteration stops.
        if self.curpos >= self.buflen:
            self._fill_buf()
        res = self.buffer[self.curpos:self.curpos+self.batchsize]
        self.curpos += self.batchsize
        return res

    # Python 2 iterator protocol name.
    next = __next__

    def __iter__(self):
        r"""Return self; the object is its own iterator."""
        return self