# HG changeset patch # User Arnaud Bergeron # Date 1267301342 18000 # Node ID 938bd350dbf0b7af957f7adb3582796c190ef6c8 # Parent be714ac9bcbd12b01bf5ef70f5811f52e64713f0 Make the datasets iterators return theano shared slices with the appropriate types. diff -r be714ac9bcbd -r 938bd350dbf0 datasets/__init__.py --- a/datasets/__init__.py Sat Feb 27 14:15:11 2010 -0500 +++ b/datasets/__init__.py Sat Feb 27 15:09:02 2010 -0500 @@ -1,2 +1,1 @@ from defs import * - diff -r be714ac9bcbd -r 938bd350dbf0 datasets/dsetiter.py --- a/datasets/dsetiter.py Sat Feb 27 14:15:11 2010 -0500 +++ b/datasets/dsetiter.py Sat Feb 27 15:09:02 2010 -0500 @@ -1,4 +1,4 @@ -import numpy +import numpy, theano class DummyFile(object): def __init__(self, size): @@ -88,11 +88,11 @@ >>> d._fill_buf() >>> d.curpos 0 - >>> len(d.buffer) + >>> len(d.buffer.value) 10 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10) >>> d._fill_buf() - >>> len(d.buffer) + >>> len(d.buffer.value) 10 >>> d._fill_buf() Traceback (most recent call last): @@ -100,28 +100,30 @@ StopIteration >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10) >>> d._fill_buf() - >>> len(d.buffer) + >>> len(d.buffer.value) 9 >>> d._fill_buf() Traceback (most recent call last): ... StopIteration """ + self.buffer = None if self.empty: raise StopIteration - self.buffer = self.curfile.read(self.bufsize) + buf = self.curfile.read(self.bufsize) - while len(self.buffer) < self.bufsize: + while len(buf) < self.bufsize: try: self.curfile = self.files.next() except StopIteration: self.empty = True - if len(self.buffer) == 0: - raise StopIteration - self.curpos = 0 - return - tmpbuf = self.curfile.read(self.bufsize - len(self.buffer)) - self.buffer = numpy.row_stack((self.buffer, tmpbuf)) + if len(buf) == 0: + raise + break + tmpbuf = self.curfile.read(self.bufsize - len(buf)) + buf = numpy.row_stack((buf, tmpbuf)) + + self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX)) self.curpos = 0 def __next__(self): @@ -130,10 +132,10 @@ Test: >>> d = DataIterator([DummyFile(20)], 10, 20) - >>> len(d.next()) - 10 - >>> len(d.next()) - 10 + >>> d.next() + Subtensor{0:10:}.0 + >>> d.next() + Subtensor{10:20:}.0 >>> d.next() Traceback (most recent call last): ... diff -r be714ac9bcbd -r 938bd350dbf0 datasets/ftfile.py --- a/datasets/ftfile.py Sat Feb 27 14:15:11 2010 -0500 +++ b/datasets/ftfile.py Sat Feb 27 15:09:02 2010 -0500 @@ -1,8 +1,8 @@ from pylearn.io.filetensor import _read_header, _prod -import numpy +import numpy, theano from dataset import DataSet from dsetiter import DataIterator -from itertools import izip +from itertools import izip, imap class FTFile(object): def __init__(self, fname): @@ -182,4 +182,5 @@ def _return_it(self, batchsize, bufsize, ftdata): return izip(DataIterator(ftdata.open_inputs(), batchsize, bufsize), - DataIterator(ftdata.open_outputs(), batchsize, bufsize)) + imap(lambda b: theano.tensor.cast(b, 'int32'), + DataIterator(ftdata.open_outputs(), batchsize, bufsize)))