changeset 178:938bd350dbf0

Make the datasets iterators return theano shared slices with the appropriate types.
author Arnaud Bergeron <abergeron@gmail.com>
date Sat, 27 Feb 2010 15:09:02 -0500
parents be714ac9bcbd
children defd388aba0c
files datasets/__init__.py datasets/dsetiter.py datasets/ftfile.py
diffstat 3 files changed, 22 insertions(+), 20 deletions(-) [+]
line wrap: on
line diff
--- a/datasets/__init__.py	Sat Feb 27 14:15:11 2010 -0500
+++ b/datasets/__init__.py	Sat Feb 27 15:09:02 2010 -0500
@@ -1,2 +1,1 @@
 from defs import *
-
--- a/datasets/dsetiter.py	Sat Feb 27 14:15:11 2010 -0500
+++ b/datasets/dsetiter.py	Sat Feb 27 15:09:02 2010 -0500
@@ -1,4 +1,4 @@
-import numpy
+import numpy, theano
 
 class DummyFile(object):
     def __init__(self, size):
@@ -88,11 +88,11 @@
             >>> d._fill_buf()
             >>> d.curpos
             0
-            >>> len(d.buffer)
+            >>> len(d.buffer.value)
             10
             >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
             >>> d._fill_buf()
-            >>> len(d.buffer)
+            >>> len(d.buffer.value)
             10
             >>> d._fill_buf()
             Traceback (most recent call last):
@@ -100,28 +100,30 @@
             StopIteration
             >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
             >>> d._fill_buf()
-            >>> len(d.buffer)
+            >>> len(d.buffer.value)
             9
             >>> d._fill_buf()
             Traceback (most recent call last):
               ...
             StopIteration
         """
+        self.buffer = None
         if self.empty:
             raise StopIteration
-        self.buffer = self.curfile.read(self.bufsize)
+        buf = self.curfile.read(self.bufsize)
         
-        while len(self.buffer) < self.bufsize:
+        while len(buf) < self.bufsize:
             try:
                 self.curfile = self.files.next()
             except StopIteration:
                 self.empty = True
-                if len(self.buffer) == 0:
-                    raise StopIteration
-                self.curpos = 0
-                return
-            tmpbuf = self.curfile.read(self.bufsize - len(self.buffer))
-            self.buffer = numpy.row_stack((self.buffer, tmpbuf))
+                if len(buf) == 0:
+                    raise
+                break
+            tmpbuf = self.curfile.read(self.bufsize - len(buf))
+            buf = numpy.row_stack((buf, tmpbuf))
+
+        self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX))
         self.curpos = 0
 
     def __next__(self):
@@ -130,10 +132,10 @@
 
         Test:
             >>> d = DataIterator([DummyFile(20)], 10, 20)
-            >>> len(d.next())
-            10
-            >>> len(d.next())
-            10
+            >>> d.next()
+            Subtensor{0:10:}.0
+            >>> d.next()
+            Subtensor{10:20:}.0
             >>> d.next()
             Traceback (most recent call last):
               ...
--- a/datasets/ftfile.py	Sat Feb 27 14:15:11 2010 -0500
+++ b/datasets/ftfile.py	Sat Feb 27 15:09:02 2010 -0500
@@ -1,8 +1,8 @@
 from pylearn.io.filetensor import _read_header, _prod
-import numpy
+import numpy, theano
 from dataset import DataSet
 from dsetiter import DataIterator
-from itertools import izip
+from itertools import izip, imap
 
 class FTFile(object):
     def __init__(self, fname):
@@ -182,4 +182,5 @@
 
     def _return_it(self, batchsize, bufsize, ftdata):
         return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize),
-                   DataIterator(ftdata.open_outputs(), batchsize, bufsize))
+                   imap(lambda b: theano.tensor.cast(b, 'int32'), 
+                        DataIterator(ftdata.open_outputs(), batchsize, bufsize)))