comparison datasets/dsetiter.py @ 302:1adfafdc3d57

Fix concatenation of 1-dim datasets (such as int target vectors).
author Arnaud Bergeron <abergeron@gmail.com>
date Tue, 30 Mar 2010 14:40:54 -0400
parents 0d0677773533
children
comparison
equal deleted inserted replaced
301:be45e7db7cd4 302:1adfafdc3d57
1 import numpy 1 import numpy
2 2
3 class DummyFile(object): 3 class DummyFile(object):
4 def __init__(self, size): 4 def __init__(self, size, shape=()):
5 self.size = size 5 self.size = size
6 self.shape = shape
6 7
7 def read(self, num): 8 def read(self, num):
8 if num > self.size: 9 if num > self.size:
9 num = self.size 10 num = self.size
10 self.size -= num 11 self.size -= num
11 return numpy.zeros((num, 3, 2)) 12 return numpy.zeros((num,)+self.shape)
12 13
13 class DataIterator(object): 14 class DataIterator(object):
14 15
15 def __init__(self, files, batchsize, bufsize=None): 16 def __init__(self, files, batchsize, bufsize=None):
16 r""" 17 r"""
82 Fill the internal buffer. 83 Fill the internal buffer.
83 84
84 Will fill across files in case the current one runs out. 85 Will fill across files in case the current one runs out.
85 86
86 Test: 87 Test:
88 >>> d = DataIterator([DummyFile(20, (3,2))], 10, 10)
89 >>> d._fill_buf()
90 >>> d.curpos
91 0
92 >>> len(d.buffer)
93 10
94 >>> d = DataIterator([DummyFile(11, (3,2)), DummyFile(9, (3,2))], 10, 10)
95 >>> d._fill_buf()
96 >>> len(d.buffer)
97 10
98 >>> d._fill_buf()
99 Traceback (most recent call last):
100 ...
101 StopIteration
102 >>> d = DataIterator([DummyFile(10, (3,2)), DummyFile(9, (3,2))], 10, 10)
103 >>> d._fill_buf()
104 >>> len(d.buffer)
105 9
106 >>> d._fill_buf()
107 Traceback (most recent call last):
108 ...
109 StopIteration
87 >>> d = DataIterator([DummyFile(20)], 10, 10) 110 >>> d = DataIterator([DummyFile(20)], 10, 10)
88 >>> d._fill_buf() 111 >>> d._fill_buf()
89 >>> d.curpos 112 >>> d.curpos
90 0 113 0
91 >>> len(d.buffer) 114 >>> len(d.buffer)
119 self.empty = True 142 self.empty = True
120 if len(buf) == 0: 143 if len(buf) == 0:
121 raise 144 raise
122 break 145 break
123 tmpbuf = self.curfile.read(self.bufsize - len(buf)) 146 tmpbuf = self.curfile.read(self.bufsize - len(buf))
124 buf = numpy.row_stack((buf, tmpbuf)) 147 buf = numpy.concatenate([buf, tmpbuf], axis=0)
125 148
126 self.cursize = len(buf) 149 self.cursize = len(buf)
127 self.buffer = buf 150 self.buffer = buf
128 self.curpos = 0 151 self.curpos = 0
129 152