# HG changeset patch # User Frederic Bastien # Date 1213723997 14400 # Node ID 9c08e3af975e76c445947be8b989f5d78858d87a # Parent d96be0eba3ccd7ae83a2f5cf7cbc6d8ccfd94434 corrected test for dataset.minibatches() diff -r d96be0eba3cc -r 9c08e3af975e _test_dataset.py --- a/_test_dataset.py Tue Jun 17 11:41:01 2008 -0400 +++ b/_test_dataset.py Tue Jun 17 13:33:17 2008 -0400 @@ -134,12 +134,13 @@ # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): i=0 mi=0 - m=ds.minibatches(['x','z'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','z'], minibatch_size=size) + assert hasattr(m,'__iter__') for minibatch in m: - assert isinstance(minibatch,DataSetFields) + assert isinstance(minibatch,LookupList) assert len(minibatch)==2 - test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + test_minibatch_size(minibatch,size,len(ds),2,mi) if type(ds)==ArrayDataSet: assert (minibatch[0][:,::2]==minibatch[1]).all() else: @@ -147,92 +148,103 @@ (minibatch[0][j][::2]==minibatch[1][j]).all() mi+=1 i+=len(minibatch[0]) - assert i==len(ds) - assert mi==4 - del minibatch,i,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del minibatch,i,m,mi,size i=0 mi=0 - m=ds.minibatches(['x','y'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'], minibatch_size=size) + assert hasattr(m,'__iter__') for minibatch in m: + assert isinstance(minibatch,LookupList) assert len(minibatch)==2 - test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + test_minibatch_size(minibatch,size,len(ds),2,mi) mi+=1 for id in range(len(minibatch[0])): assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all() i+=1 - assert i==len(ds) - assert mi==4 - del minibatch,i,id,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del minibatch,i,id,m,mi,size # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): i=0 mi=0 - m=ds.minibatches(['x','z'], minibatch_size=3) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','z'], minibatch_size=size) + assert hasattr(m,'__iter__') for x,z in m: - test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) - test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(x,size,len(ds),mi) + test_minibatch_field_size(z,size,len(ds),mi) for id in range(len(x)): assert (x[id][::2]==z[id]).all() i+=1 mi+=1 - assert i==len(ds) - assert mi==4 - del x,z,i,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del x,z,i,m,mi,size + i=0 mi=0 + size=3 m=ds.minibatches(['x','y'], minibatch_size=3) + assert hasattr(m,'__iter__') for x,y in m: - test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) - test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) + assert len(x)==size + assert len(y)==size + test_minibatch_field_size(x,size,len(ds),mi) + test_minibatch_field_size(y,size,len(ds),mi) mi+=1 for id in range(len(x)): assert (numpy.append(x[id],y[id])==array[i]).all() i+=1 - assert i==len(ds) - assert mi==4 - del x,y,i,id,m,mi + assert i==(len(ds)/size)*size + assert mi==(len(ds)/size) + del x,y,i,id,m,mi,size #not in doc i=0 - m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id,m + assert i==size + del x,y,i,id,m,size i=0 - m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[i+4]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id,m + assert i==2*size + del x,y,i,id,m,size i=0 - m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4) - assert isinstance(m,DataSet.MinibatchWrapAroundIterator) + size=3 + m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4) + assert hasattr(m,'__iter__') for x,y in m: - assert len(x)==m.minibatch_size - assert len(y)==m.minibatch_size - for id in range(m.minibatch_size): + assert len(x)==size + assert len(y)==size + for id in range(size): assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all() i+=1 - assert i==m.n_batches*m.minibatch_size - del x,y,i,id + assert i==2*size # should not wrap + del x,y,i,id,size - assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) + assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0) assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0) def test_ds_iterator(array,iterator1,iterator2,iterator3):