comparison datasets/dsetiter.py @ 178:938bd350dbf0

Make the dataset iterators return slices of Theano shared variables with the appropriate types.
author Arnaud Bergeron <abergeron@gmail.com>
date Sat, 27 Feb 2010 15:09:02 -0500
parents 4b28d7382dbf
children defd388aba0c
comparison
equal deleted inserted replaced
177:be714ac9bcbd 178:938bd350dbf0
1 import numpy 1 import numpy, theano
2 2
3 class DummyFile(object): 3 class DummyFile(object):
4 def __init__(self, size): 4 def __init__(self, size):
5 self.size = size 5 self.size = size
6 6
86 Test: 86 Test:
87 >>> d = DataIterator([DummyFile(20)], 10, 10) 87 >>> d = DataIterator([DummyFile(20)], 10, 10)
88 >>> d._fill_buf() 88 >>> d._fill_buf()
89 >>> d.curpos 89 >>> d.curpos
90 0 90 0
91 >>> len(d.buffer) 91 >>> len(d.buffer.value)
92 10 92 10
93 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10) 93 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
94 >>> d._fill_buf() 94 >>> d._fill_buf()
95 >>> len(d.buffer) 95 >>> len(d.buffer.value)
96 10 96 10
97 >>> d._fill_buf() 97 >>> d._fill_buf()
98 Traceback (most recent call last): 98 Traceback (most recent call last):
99 ... 99 ...
100 StopIteration 100 StopIteration
101 >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10) 101 >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
102 >>> d._fill_buf() 102 >>> d._fill_buf()
103 >>> len(d.buffer) 103 >>> len(d.buffer.value)
104 9 104 9
105 >>> d._fill_buf() 105 >>> d._fill_buf()
106 Traceback (most recent call last): 106 Traceback (most recent call last):
107 ... 107 ...
108 StopIteration 108 StopIteration
109 """ 109 """
110 self.buffer = None
110 if self.empty: 111 if self.empty:
111 raise StopIteration 112 raise StopIteration
112 self.buffer = self.curfile.read(self.bufsize) 113 buf = self.curfile.read(self.bufsize)
113 114
114 while len(self.buffer) < self.bufsize: 115 while len(buf) < self.bufsize:
115 try: 116 try:
116 self.curfile = self.files.next() 117 self.curfile = self.files.next()
117 except StopIteration: 118 except StopIteration:
118 self.empty = True 119 self.empty = True
119 if len(self.buffer) == 0: 120 if len(buf) == 0:
120 raise StopIteration 121 raise
121 self.curpos = 0 122 break
122 return 123 tmpbuf = self.curfile.read(self.bufsize - len(buf))
123 tmpbuf = self.curfile.read(self.bufsize - len(self.buffer)) 124 buf = numpy.row_stack((buf, tmpbuf))
124 self.buffer = numpy.row_stack((self.buffer, tmpbuf)) 125
126 self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX))
125 self.curpos = 0 127 self.curpos = 0
126 128
127 def __next__(self): 129 def __next__(self):
128 r""" 130 r"""
129 Returns the next portion of the dataset. 131 Returns the next portion of the dataset.
130 132
131 Test: 133 Test:
132 >>> d = DataIterator([DummyFile(20)], 10, 20) 134 >>> d = DataIterator([DummyFile(20)], 10, 20)
133 >>> len(d.next()) 135 >>> d.next()
134 10 136 Subtensor{0:10:}.0
135 >>> len(d.next()) 137 >>> d.next()
136 10 138 Subtensor{10:20:}.0
137 >>> d.next() 139 >>> d.next()
138 Traceback (most recent call last): 140 Traceback (most recent call last):
139 ... 141 ...
140 StopIteration 142 StopIteration
141 >>> d.next() 143 >>> d.next()