diff datasets/dsetiter.py @ 178:938bd350dbf0

Make the datasets iterators return theano shared slices with the appropriate types.
author Arnaud Bergeron <abergeron@gmail.com>
date Sat, 27 Feb 2010 15:09:02 -0500
parents 4b28d7382dbf
children defd388aba0c
line wrap: on
line diff
--- a/datasets/dsetiter.py	Sat Feb 27 14:15:11 2010 -0500
+++ b/datasets/dsetiter.py	Sat Feb 27 15:09:02 2010 -0500
@@ -1,4 +1,4 @@
-import numpy
+import numpy, theano
 
 class DummyFile(object):
     def __init__(self, size):
@@ -88,11 +88,11 @@
             >>> d._fill_buf()
             >>> d.curpos
             0
-            >>> len(d.buffer)
+            >>> len(d.buffer.value)
             10
             >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
             >>> d._fill_buf()
-            >>> len(d.buffer)
+            >>> len(d.buffer.value)
             10
             >>> d._fill_buf()
             Traceback (most recent call last):
@@ -100,28 +100,30 @@
             StopIteration
             >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
             >>> d._fill_buf()
-            >>> len(d.buffer)
+            >>> len(d.buffer.value)
             9
             >>> d._fill_buf()
             Traceback (most recent call last):
               ...
             StopIteration
         """
+        self.buffer = None
         if self.empty:
             raise StopIteration
-        self.buffer = self.curfile.read(self.bufsize)
+        buf = self.curfile.read(self.bufsize)
         
-        while len(self.buffer) < self.bufsize:
+        while len(buf) < self.bufsize:
             try:
                 self.curfile = self.files.next()
             except StopIteration:
                 self.empty = True
-                if len(self.buffer) == 0:
-                    raise StopIteration
-                self.curpos = 0
-                return
-            tmpbuf = self.curfile.read(self.bufsize - len(self.buffer))
-            self.buffer = numpy.row_stack((self.buffer, tmpbuf))
+                if len(buf) == 0:
+                    raise
+                break
+            tmpbuf = self.curfile.read(self.bufsize - len(buf))
+            buf = numpy.row_stack((buf, tmpbuf))
+
+        self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX))
         self.curpos = 0
 
     def __next__(self):
@@ -130,10 +132,10 @@
 
         Test:
             >>> d = DataIterator([DummyFile(20)], 10, 20)
-            >>> len(d.next())
-            10
-            >>> len(d.next())
-            10
+            >>> d.next()
+            Subtensor{0:10:}.0
+            >>> d.next()
+            Subtensor{10:20:}.0
             >>> d.next()
             Traceback (most recent call last):
               ...