Mercurial > ift6266
comparison datasets/dsetiter.py @ 302:1adfafdc3d57
Fix concatenation of 1-dim datasets (such as int target vectors).
author | Arnaud Bergeron <abergeron@gmail.com> |
---|---|
date | Tue, 30 Mar 2010 14:40:54 -0400 |
parents | 0d0677773533 |
children |
comparison
equal
deleted
inserted
replaced
301:be45e7db7cd4 | 302:1adfafdc3d57 |
---|---|
1 import numpy | 1 import numpy |
2 | 2 |
3 class DummyFile(object): | 3 class DummyFile(object): |
4 def __init__(self, size): | 4 def __init__(self, size, shape=()): |
5 self.size = size | 5 self.size = size |
| 6 self.shape = shape |
6 | 7 |
7 def read(self, num): | 8 def read(self, num): |
8 if num > self.size: | 9 if num > self.size: |
9 num = self.size | 10 num = self.size |
10 self.size -= num | 11 self.size -= num |
11 return numpy.zeros((num, 3, 2)) | 12 return numpy.zeros((num,)+self.shape) |
12 | 13 |
13 class DataIterator(object): | 14 class DataIterator(object): |
14 | 15 |
15 def __init__(self, files, batchsize, bufsize=None): | 16 def __init__(self, files, batchsize, bufsize=None): |
16 r""" | 17 r""" |
82 Fill the internal buffer. | 83 Fill the internal buffer. |
83 | 84 |
84 Will fill across files in case the current one runs out. | 85 Will fill across files in case the current one runs out. |
85 | 86 |
86 Test: | 87 Test: |
| 88 >>> d = DataIterator([DummyFile(20, (3,2))], 10, 10) |
| 89 >>> d._fill_buf() |
| 90 >>> d.curpos |
| 91 0 |
| 92 >>> len(d.buffer) |
| 93 10 |
| 94 >>> d = DataIterator([DummyFile(11, (3,2)), DummyFile(9, (3,2))], 10, 10) |
| 95 >>> d._fill_buf() |
| 96 >>> len(d.buffer) |
| 97 10 |
| 98 >>> d._fill_buf() |
| 99 Traceback (most recent call last): |
| 100 ... |
| 101 StopIteration |
| 102 >>> d = DataIterator([DummyFile(10, (3,2)), DummyFile(9, (3,2))], 10, 10) |
| 103 >>> d._fill_buf() |
| 104 >>> len(d.buffer) |
| 105 9 |
| 106 >>> d._fill_buf() |
| 107 Traceback (most recent call last): |
| 108 ... |
| 109 StopIteration |
87 >>> d = DataIterator([DummyFile(20)], 10, 10) | 110 >>> d = DataIterator([DummyFile(20)], 10, 10) |
88 >>> d._fill_buf() | 111 >>> d._fill_buf() |
89 >>> d.curpos | 112 >>> d.curpos |
90 0 | 113 0 |
91 >>> len(d.buffer) | 114 >>> len(d.buffer) |
119 self.empty = True | 142 self.empty = True |
120 if len(buf) == 0: | 143 if len(buf) == 0: |
121 raise | 144 raise |
122 break | 145 break |
123 tmpbuf = self.curfile.read(self.bufsize - len(buf)) | 146 tmpbuf = self.curfile.read(self.bufsize - len(buf)) |
124 buf = numpy.row_stack((buf, tmpbuf)) | 147 buf = numpy.concatenate([buf, tmpbuf], axis=0) |
125 | 148 |
126 self.cursize = len(buf) | 149 self.cursize = len(buf) |
127 self.buffer = buf | 150 self.buffer = buf |
128 self.curpos = 0 | 151 self.curpos = 0 |
129 | 152 |