Mercurial > ift6266
changeset 302:1adfafdc3d57
Fix concatenation of 1-dim datasets (such as int target vectors).
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Tue, 30 Mar 2010 14:40:54 -0400 |
parents | be45e7db7cd4 |
children | ef28cbb5f464 a21a174c1c18 |
files | datasets/dsetiter.py |
diffstat | 1 files changed, 26 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/datasets/dsetiter.py Mon Mar 29 18:14:30 2010 -0400 +++ b/datasets/dsetiter.py Tue Mar 30 14:40:54 2010 -0400 @@ -1,14 +1,15 @@ import numpy class DummyFile(object): - def __init__(self, size): + def __init__(self, size, shape=()): self.size = size + self.shape = shape def read(self, num): if num > self.size: num = self.size self.size -= num - return numpy.zeros((num, 3, 2)) + return numpy.zeros((num,)+self.shape) class DataIterator(object): @@ -84,6 +85,28 @@ Will fill across files in case the current one runs out. Test: + >>> d = DataIterator([DummyFile(20, (3,2))], 10, 10) + >>> d._fill_buf() + >>> d.curpos + 0 + >>> len(d.buffer) + 10 + >>> d = DataIterator([DummyFile(11, (3,2)), DummyFile(9, (3,2))], 10, 10) + >>> d._fill_buf() + >>> len(d.buffer) + 10 + >>> d._fill_buf() + Traceback (most recent call last): + ... + StopIteration + >>> d = DataIterator([DummyFile(10, (3,2)), DummyFile(9, (3,2))], 10, 10) + >>> d._fill_buf() + >>> len(d.buffer) + 9 + >>> d._fill_buf() + Traceback (most recent call last): + ... + StopIteration >>> d = DataIterator([DummyFile(20)], 10, 10) >>> d._fill_buf() >>> d.curpos @@ -121,7 +144,7 @@ raise break tmpbuf = self.curfile.read(self.bufsize - len(buf)) - buf = numpy.row_stack((buf, tmpbuf)) + buf = numpy.concatenate([buf, tmpbuf], axis=0) self.cursize = len(buf) self.buffer = buf