comparison datasets/dsetiter.py @ 179:defd388aba0c

Do not yield theano shared variables. They can only be used by theano.function().
author Arnaud Bergeron <abergeron@gmail.com>
date Sat, 27 Feb 2010 16:07:09 -0500
parents 938bd350dbf0
children 76bc047df5ee
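
The change described above can be illustrated with a minimal sketch; it is not the code from `dsetiter.py`. The class name `ArrayBatches`, its constructor arguments, and the fixed `"float32"` dtype (standing in for `theano.config.floatX` so the sketch runs without Theano) are all assumptions made for illustration. It keeps the buffer as an ordinary NumPy array, so each yielded batch is a concrete array rather than a symbolic expression such as `Subtensor{0:10:}.0`.

```python
import numpy as np


class ArrayBatches:
    """Hypothetical stand-in for the iterator in this diff: yields plain ndarrays."""

    def __init__(self, data, batchsize, dtype="float32"):
        # Keep the buffer as an ordinary NumPy array, not a Theano shared variable.
        # "float32" is an assumed replacement for theano.config.floatX.
        self.buffer = np.asarray(data, dtype=dtype)
        self.batchsize = batchsize
        self.curpos = 0

    def __iter__(self):
        return self

    def __next__(self):
        # Stop once every row of the buffer has been handed out.
        if self.curpos >= len(self.buffer):
            raise StopIteration
        start = self.curpos
        self.curpos += self.batchsize
        # Slicing an ndarray returns another ndarray, so len(), indexing and
        # arithmetic work on the batch without compiling a function first.
        return self.buffer[start:start + self.batchsize]


batches = ArrayBatches(np.zeros((20, 3)), batchsize=10)
sizes = [len(batch) for batch in batches]
```

Here `sizes` ends up as `[10, 10]`, which mirrors the `len(d.next())` checks in the updated doctests. A caller that does want the data in a shared variable can still wrap a yielded batch itself before building a `theano.function()`.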
178:938bd350dbf0 | 179:defd388aba0c
86 Test: | 86 Test:
87 >>> d = DataIterator([DummyFile(20)], 10, 10) | 87 >>> d = DataIterator([DummyFile(20)], 10, 10)
88 >>> d._fill_buf() | 88 >>> d._fill_buf()
89 >>> d.curpos | 89 >>> d.curpos
90 0 | 90 0
91 >>> len(d.buffer.value) | 91 >>> len(d.buffer)
92 10 | 92 10
93 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10) | 93 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10)
94 >>> d._fill_buf() | 94 >>> d._fill_buf()
95 >>> len(d.buffer.value) | 95 >>> len(d.buffer)
96 10 | 96 10
97 >>> d._fill_buf() | 97 >>> d._fill_buf()
98 Traceback (most recent call last): | 98 Traceback (most recent call last):
99 ... | 99 ...
100 StopIteration | 100 StopIteration
101 >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10) | 101 >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10)
102 >>> d._fill_buf() | 102 >>> d._fill_buf()
103 >>> len(d.buffer.value) | 103 >>> len(d.buffer)
104 9 | 104 9
105 >>> d._fill_buf() | 105 >>> d._fill_buf()
106 Traceback (most recent call last): | 106 Traceback (most recent call last):
107 ... | 107 ...
108 StopIteration | 108 StopIteration
121 raise | 121 raise
122 break | 122 break
123 tmpbuf = self.curfile.read(self.bufsize - len(buf)) | 123 tmpbuf = self.curfile.read(self.bufsize - len(buf))
124 buf = numpy.row_stack((buf, tmpbuf)) | 124 buf = numpy.row_stack((buf, tmpbuf))
125 | 125
126 self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX)) | 126 self.buffer = numpy.asarray(buf, dtype=theano.config.floatX)
127 self.curpos = 0 | 127 self.curpos = 0
128 | 128
129 def __next__(self): | 129 def __next__(self):
130 r""" | 130 r"""
131 Returns the next portion of the dataset. | 131 Returns the next portion of the dataset.
132 | 132
133 Test: | 133 Test:
134 >>> d = DataIterator([DummyFile(20)], 10, 20) | 134 >>> d = DataIterator([DummyFile(20)], 10, 20)
135 >>> d.next() | 135 >>> len(d.next())
136 Subtensor{0:10:}.0 | 136 10
137 >>> d.next() | 137 >>> len(d.next())
138 Subtensor{10:20:}.0 | 138 10
139 >>> d.next() | 139 >>> d.next()
140 Traceback (most recent call last): | 140 Traceback (most recent call last):
141 ... | 141 ...
142 StopIteration | 142 StopIteration
143 >>> d.next() | 143 >>> d.next()