# HG changeset patch # User Arnaud Bergeron # Date 1269974454 14400 # Node ID 1adfafdc3d5777bb02fc28cb13bdb86bff1ca2d9 # Parent be45e7db7cd4737063b5d4386e494767e7f411bd Fix concatenation of 1-dim datasets (such as int target vectors). diff -r be45e7db7cd4 -r 1adfafdc3d57 datasets/dsetiter.py --- a/datasets/dsetiter.py Mon Mar 29 18:14:30 2010 -0400 +++ b/datasets/dsetiter.py Tue Mar 30 14:40:54 2010 -0400 @@ -1,14 +1,15 @@ import numpy class DummyFile(object): - def __init__(self, size): + def __init__(self, size, shape=()): self.size = size + self.shape = shape def read(self, num): if num > self.size: num = self.size self.size -= num - return numpy.zeros((num, 3, 2)) + return numpy.zeros((num,)+self.shape) class DataIterator(object): @@ -84,6 +85,28 @@ Will fill across files in case the current one runs out. Test: + >>> d = DataIterator([DummyFile(20, (3,2))], 10, 10) + >>> d._fill_buf() + >>> d.curpos + 0 + >>> len(d.buffer) + 10 + >>> d = DataIterator([DummyFile(11, (3,2)), DummyFile(9, (3,2))], 10, 10) + >>> d._fill_buf() + >>> len(d.buffer) + 10 + >>> d._fill_buf() + Traceback (most recent call last): + ... + StopIteration + >>> d = DataIterator([DummyFile(10, (3,2)), DummyFile(9, (3,2))], 10, 10) + >>> d._fill_buf() + >>> len(d.buffer) + 9 + >>> d._fill_buf() + Traceback (most recent call last): + ... + StopIteration >>> d = DataIterator([DummyFile(20)], 10, 10) >>> d._fill_buf() >>> d.curpos @@ -121,7 +144,7 @@ raise break tmpbuf = self.curfile.read(self.bufsize - len(buf)) - buf = numpy.row_stack((buf, tmpbuf)) + buf = numpy.concatenate([buf, tmpbuf], axis=0) self.cursize = len(buf) self.buffer = buf