Mercurial > ift6266
comparison datasets/dsetiter.py @ 178:938bd350dbf0
Make the datasets iterators return theano shared slices with the appropriate types.
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Sat, 27 Feb 2010 15:09:02 -0500 |
parents | 4b28d7382dbf |
children | defd388aba0c |
comparison
equal
deleted
inserted
replaced
177:be714ac9bcbd | 178:938bd350dbf0 |
---|---|
1 import numpy | 1 import numpy, theano |
2 | 2 |
3 class DummyFile(object): | 3 class DummyFile(object): |
4 def __init__(self, size): | 4 def __init__(self, size): |
5 self.size = size | 5 self.size = size |
6 | 6 |
86 Test: | 86 Test: |
87 >>> d = DataIterator([DummyFile(20)], 10, 10) | 87 >>> d = DataIterator([DummyFile(20)], 10, 10) |
88 >>> d._fill_buf() | 88 >>> d._fill_buf() |
89 >>> d.curpos | 89 >>> d.curpos |
90 0 | 90 0 |
91 >>> len(d.buffer) | 91 >>> len(d.buffer.value) |
92 10 | 92 10 |
93 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10) | 93 >>> d = DataIterator([DummyFile(11), DummyFile(9)], 10, 10) |
94 >>> d._fill_buf() | 94 >>> d._fill_buf() |
95 >>> len(d.buffer) | 95 >>> len(d.buffer.value) |
96 10 | 96 10 |
97 >>> d._fill_buf() | 97 >>> d._fill_buf() |
98 Traceback (most recent call last): | 98 Traceback (most recent call last): |
99 ... | 99 ... |
100 StopIteration | 100 StopIteration |
101 >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10) | 101 >>> d = DataIterator([DummyFile(10), DummyFile(9)], 10, 10) |
102 >>> d._fill_buf() | 102 >>> d._fill_buf() |
103 >>> len(d.buffer) | 103 >>> len(d.buffer.value) |
104 9 | 104 9 |
105 >>> d._fill_buf() | 105 >>> d._fill_buf() |
106 Traceback (most recent call last): | 106 Traceback (most recent call last): |
107 ... | 107 ... |
108 StopIteration | 108 StopIteration |
109 """ | 109 """ |
| | 110 self.buffer = None |
110 if self.empty: | 111 if self.empty: |
111 raise StopIteration | 112 raise StopIteration |
112 self.buffer = self.curfile.read(self.bufsize) | 113 buf = self.curfile.read(self.bufsize) |
113 | 114 |
114 while len(self.buffer) < self.bufsize: | 115 while len(buf) < self.bufsize: |
115 try: | 116 try: |
116 self.curfile = self.files.next() | 117 self.curfile = self.files.next() |
117 except StopIteration: | 118 except StopIteration: |
118 self.empty = True | 119 self.empty = True |
119 if len(self.buffer) == 0: | 120 if len(buf) == 0: |
120 raise StopIteration | 121 raise |
121 self.curpos = 0 | 122 break |
122 return | 123 tmpbuf = self.curfile.read(self.bufsize - len(buf)) |
123 tmpbuf = self.curfile.read(self.bufsize - len(self.buffer)) | 124 buf = numpy.row_stack((buf, tmpbuf)) |
124 self.buffer = numpy.row_stack((self.buffer, tmpbuf)) | 125 |
| | 126 self.buffer = theano.shared(numpy.asarray(buf, dtype=theano.config.floatX)) |
125 self.curpos = 0 | 127 self.curpos = 0 |
126 | 128 |
127 def __next__(self): | 129 def __next__(self): |
128 r""" | 130 r""" |
129 Returns the next portion of the dataset. | 131 Returns the next portion of the dataset. |
130 | 132 |
131 Test: | 133 Test: |
132 >>> d = DataIterator([DummyFile(20)], 10, 20) | 134 >>> d = DataIterator([DummyFile(20)], 10, 20) |
133 >>> len(d.next()) | 135 >>> d.next() |
134 10 | 136 Subtensor{0:10:}.0 |
135 >>> len(d.next()) | 137 >>> d.next() |
136 10 | 138 Subtensor{10:20:}.0 |
137 >>> d.next() | 139 >>> d.next() |
138 Traceback (most recent call last): | 140 Traceback (most recent call last): |
139 ... | 141 ... |
140 StopIteration | 142 StopIteration |
141 >>> d.next() | 143 >>> d.next() |