comparison test_dataset.py @ 239:77b362a23f8e

more general test
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 29 May 2008 10:41:35 -0400
parents 6aff510792dd
children 97f35d586727
comparison
equal deleted inserted replaced
238:ae1d85aca858 239:77b362a23f8e
192 #not in doc 192 #not in doc
193 i=0 193 i=0
194 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) 194 m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
195 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 195 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
196 for x,y in m: 196 for x,y in m:
197 assert len(x)==3 197 assert len(x)==m.minibatch_size
198 assert len(y)==3 198 assert len(y)==m.minibatch_size
199 for id in range(3): 199 for id in range(m.minibatch_size):
200 assert (numpy.append(x[id],y[id])==array[i+4]).all() 200 assert (numpy.append(x[id],y[id])==array[i+4]).all()
201 i+=1 201 i+=1
202 assert i==3 202 assert i==m.n_batches*m.minibatch_size
203 del x,y,i,id,m 203 del x,y,i,id,m
204 204
205 i=0 205 i=0
206 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) 206 m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
207 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 207 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
208 for x,y in m: 208 for x,y in m:
209 assert len(x)==3 209 assert len(x)==m.minibatch_size
210 assert len(y)==3 210 assert len(y)==m.minibatch_size
211 for id in range(3): 211 for id in range(m.minibatch_size):
212 assert (numpy.append(x[id],y[id])==array[i+4]).all() 212 assert (numpy.append(x[id],y[id])==array[i+4]).all()
213 i+=1 213 i+=1
214 assert i==6 214 assert i==m.n_batches*m.minibatch_size
215 del x,y,i,id,m 215 del x,y,i,id,m
216 216
217 i=0 217 i=0
218 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) 218 m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
219 assert isinstance(m,DataSet.MinibatchWrapAroundIterator) 219 assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
220 for x,y in m: 220 for x,y in m:
221 assert len(x)==3 221 assert len(x)==m.minibatch_size
222 assert len(y)==3 222 assert len(y)==m.minibatch_size
223 for id in range(3): 223 for id in range(m.minibatch_size):
224 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() 224 assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
225 i+=1 225 i+=1
226 assert i==m.n_batches*m.minibatch_size 226 assert i==m.n_batches*m.minibatch_size
227 del x,y,i,id 227 del x,y,i,id
228 228