changeset 302:1adfafdc3d57

Fix concatenation of 1-dim datasets (such as int target vectors).
author Arnaud Bergeron <abergeron@gmail.com>
date Tue, 30 Mar 2010 14:40:54 -0400
parents be45e7db7cd4
children ef28cbb5f464 a21a174c1c18
files datasets/dsetiter.py
diffstat 1 files changed, 26 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/datasets/dsetiter.py	Mon Mar 29 18:14:30 2010 -0400
+++ b/datasets/dsetiter.py	Tue Mar 30 14:40:54 2010 -0400
@@ -1,14 +1,15 @@
 import numpy
 
 class DummyFile(object):
-    def __init__(self, size):
+    def __init__(self, size, shape=()):
         self.size = size
+        self.shape = shape
 
     def read(self, num):
         if num > self.size:
             num = self.size
         self.size -= num
-        return numpy.zeros((num, 3, 2))
+        return numpy.zeros((num,)+self.shape)
 
 class DataIterator(object):
     
@@ -84,6 +85,28 @@
         Will fill across files in case the current one runs out.
 
         Test:
+            >>> d = DataIterator([DummyFile(20, (3,2))], 10, 10)
+            >>> d._fill_buf()
+            >>> d.curpos
+            0
+            >>> len(d.buffer)
+            10
+            >>> d = DataIterator([DummyFile(11, (3,2)), DummyFile(9, (3,2))], 10, 10)
+            >>> d._fill_buf()
+            >>> len(d.buffer)
+            10
+            >>> d._fill_buf()
+            Traceback (most recent call last):
+              ...
+            StopIteration
+            >>> d = DataIterator([DummyFile(10, (3,2)), DummyFile(9, (3,2))], 10, 10)
+            >>> d._fill_buf()
+            >>> len(d.buffer)
+            9
+            >>> d._fill_buf()
+            Traceback (most recent call last):
+              ...
+            StopIteration
             >>> d = DataIterator([DummyFile(20)], 10, 10)
             >>> d._fill_buf()
             >>> d.curpos
@@ -121,7 +144,7 @@
                     raise
                 break
             tmpbuf = self.curfile.read(self.bufsize - len(buf))
-            buf = numpy.row_stack((buf, tmpbuf))
+            buf = numpy.concatenate([buf, tmpbuf], axis=0)
 
         self.cursize = len(buf)
         self.buffer = buf