comparison _test_dataset.py @ 341:9c08e3af975e

corrected test for dataset.minibatches()
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 17 Jun 2008 13:33:17 -0400
parents b48cf8dce2bf
children 2259f6fa4959
comparison
legend: equal | deleted | inserted | replaced
columns: left = 340:d96be0eba3cc, right = 341:9c08e3af975e
132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) 132 assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size)
133 133
134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): 134 # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
135 i=0 135 i=0
136 mi=0 136 mi=0
137 m=ds.minibatches(['x','z'], minibatch_size=3) 137 size=3
138 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 138 m=ds.minibatches(['x','z'], minibatch_size=size)
139 assert hasattr(m,'__iter__')
139 for minibatch in m: 140 for minibatch in m:
140 assert isinstance(minibatch,DataSetFields) 141 assert isinstance(minibatch,LookupList)
141 assert len(minibatch)==2 142 assert len(minibatch)==2
142 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 143 test_minibatch_size(minibatch,size,len(ds),2,mi)
143 if type(ds)==ArrayDataSet: 144 if type(ds)==ArrayDataSet:
144 assert (minibatch[0][:,::2]==minibatch[1]).all() 145 assert (minibatch[0][:,::2]==minibatch[1]).all()
145 else: 146 else:
146 for j in xrange(len(minibatch[0])): 147 for j in xrange(len(minibatch[0])):
147 (minibatch[0][j][::2]==minibatch[1][j]).all() 148 (minibatch[0][j][::2]==minibatch[1][j]).all()
148 mi+=1 149 mi+=1
149 i+=len(minibatch[0]) 150 i+=len(minibatch[0])
150 assert i==len(ds) 151 assert i==(len(ds)/size)*size
151 assert mi==4 152 assert mi==(len(ds)/size)
152 del minibatch,i,m,mi 153 del minibatch,i,m,mi,size
153 154
154 i=0 155 i=0
155 mi=0 156 mi=0
156 m=ds.minibatches(['x','y'], minibatch_size=3) 157 size=3
157 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 158 m=ds.minibatches(['x','y'], minibatch_size=size)
159 assert hasattr(m,'__iter__')
158 for minibatch in m: 160 for minibatch in m:
161 assert isinstance(minibatch,LookupList)
159 assert len(minibatch)==2 162 assert len(minibatch)==2
160 test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) 163 test_minibatch_size(minibatch,size,len(ds),2,mi)
161 mi+=1 164 mi+=1
162 for id in range(len(minibatch[0])): 165 for id in range(len(minibatch[0])):
163 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() 166 assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all()
164 i+=1 167 i+=1
165 assert i==len(ds) 168 assert i==(len(ds)/size)*size
166 assert mi==4 169 assert mi==(len(ds)/size)
167 del minibatch,i,id,m,mi 170 del minibatch,i,id,m,mi,size
168 171
169 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): 172 # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
170 i=0 173 i=0
171 mi=0 174 mi=0
172 m=ds.minibatches(['x','z'], minibatch_size=3) 175 size=3
173 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 176 m=ds.minibatches(['x','z'], minibatch_size=size)
177 assert hasattr(m,'__iter__')
174 for x,z in m: 178 for x,z in m:
175 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) 179 test_minibatch_field_size(x,size,len(ds),mi)
176 test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) 180 test_minibatch_field_size(z,size,len(ds),mi)
177 for id in range(len(x)): 181 for id in range(len(x)):
178 assert (x[id][::2]==z[id]).all() 182 assert (x[id][::2]==z[id]).all()
179 i+=1 183 i+=1
180 mi+=1 184 mi+=1
181 assert i==len(ds) 185 assert i==(len(ds)/size)*size
182 assert mi==4 186 assert mi==(len(ds)/size)
183 del x,z,i,m,mi 187 del x,z,i,m,mi,size
188
184 i=0 189 i=0
185 mi=0 190 mi=0
191 size=3
186 m=ds.minibatches(['x','y'], minibatch_size=3) 192 m=ds.minibatches(['x','y'], minibatch_size=3)
193 assert hasattr(m,'__iter__')
187 for x,y in m: 194 for x,y in m:
188 test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) 195 assert len(x)==size
189 test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) 196 assert len(y)==size
197 test_minibatch_field_size(x,size,len(ds),mi)
198 test_minibatch_field_size(y,size,len(ds),mi)
190 mi+=1 199 mi+=1
191 for id in range(len(x)): 200 for id in range(len(x)):
192 assert (numpy.append(x[id],y[id])==array[i]).all() 201 assert (numpy.append(x[id],y[id])==array[i]).all()
193 i+=1 202 i+=1
194 assert i==len(ds) 203 assert i==(len(ds)/size)*size
195 assert mi==4 204 assert mi==(len(ds)/size)
196 del x,y,i,id,m,mi 205 del x,y,i,id,m,mi,size
197 206
198 #not in doc 207 #not in doc
199 i=0 208 i=0
200 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) 209 size=3
201 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 210 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4)
211 assert hasattr(m,'__iter__')
202 for x,y in m: 212 for x,y in m:
203 assert len(x)==m.minibatch_size 213 assert len(x)==size
204 assert len(y)==m.minibatch_size 214 assert len(y)==size
205 for id in range(m.minibatch_size): 215 for id in range(size):
206 assert (numpy.append(x[id],y[id])==array[i+4]).all() 216 assert (numpy.append(x[id],y[id])==array[i+4]).all()
207 i+=1 217 i+=1
208 assert i==m.n_batches*m.minibatch_size 218 assert i==size
209 del x,y,i,id,m 219 del x,y,i,id,m,size
210 220
211 i=0 221 i=0
212 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) 222 size=3
213 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 223 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4)
224 assert hasattr(m,'__iter__')
214 for x,y in m: 225 for x,y in m:
215 assert len(x)==m.minibatch_size 226 assert len(x)==size
216 assert len(y)==m.minibatch_size 227 assert len(y)==size
217 for id in range(m.minibatch_size): 228 for id in range(size):
218 assert (numpy.append(x[id],y[id])==array[i+4]).all() 229 assert (numpy.append(x[id],y[id])==array[i+4]).all()
219 i+=1 230 i+=1
220 assert i==m.n_batches*m.minibatch_size 231 assert i==2*size
221 del x,y,i,id,m 232 del x,y,i,id,m,size
222 233
223 i=0 234 i=0
224 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) 235 size=3
225 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 236 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4)
237 assert hasattr(m,'__iter__')
226 for x,y in m: 238 for x,y in m:
227 assert len(x)==m.minibatch_size 239 assert len(x)==size
228 assert len(y)==m.minibatch_size 240 assert len(y)==size
229 for id in range(m.minibatch_size): 241 for id in range(size):
230 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() 242 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
231 i+=1 243 i+=1
232 assert i==m.n_batches*m.minibatch_size 244 assert i==2*size # should not wrap
233 del x,y,i,id 245 del x,y,i,id,size
234 246
235 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) 247 assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
236 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) 248 assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
237 249
238 def test_ds_iterator(array,iterator1,iterator2,iterator3): 250 def test_ds_iterator(array,iterator1,iterator2,iterator3):
239 l=len(iterator1) 251 l=len(iterator1)
240 i=0 252 i=0