Mercurial > pylearn
diff test_dataset.py @ 102:4537ac630348
modifed test to accomodate the last change in dataset.py.
i.e. minibatch without a fixed number of batch return an incomplete minibatch at the end to stop at the end of the dataset.
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 16:03:17 -0400 |
parents | 574f4db76022 |
children | a90d85fef3d4 |
line wrap: on
line diff
--- a/test_dataset.py Tue May 06 16:01:53 2008 -0400 +++ b/test_dataset.py Tue May 06 16:03:17 2008 -0400 @@ -112,50 +112,68 @@ assert i==len(ds) del x,y,i + def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished): + ##full minibatch or the last minibatch + for idx in range(nb_field): + test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished) + del idx + def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished): + assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size) # - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N): i=0 - for minibatch in ds.minibatches(['x','z'], minibatch_size=3): + mi=0 + m=ds.minibatches(['x','z'], minibatch_size=3) + for minibatch in m: assert len(minibatch)==2 - assert len(minibatch[0])==3 - assert len(minibatch[1])==3 + test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) assert (minibatch[0][:,0:3:2]==minibatch[1]).all() - i+=1 - #assert i==#??? What shoud be the value? #option for the rest. - print i - del minibatch,i + mi+=1 + i+=len(minibatch[0]) + assert i==len(ds) + assert mi==4 + del minibatch,i,m,mi + i=0 - for minibatch in ds.minibatches(['x','y'], minibatch_size=3): + mi=0 + m=ds.minibatches(['x','y'], minibatch_size=3) + for minibatch in m: assert len(minibatch)==2 - assert len(minibatch[0])==3 - assert len(minibatch[1])==3 - for id in range(3): + test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi) + mi+=1 + for id in range(len(minibatch[0])): assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all() i+=1 - #assert i==#??? What shoud be the value? - print i - del minibatch,i,id + assert i==len(ds) + assert mi==4 + del minibatch,i,id,m,mi # - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N): i=0 - for x,z in ds.minibatches(['x','z'], minibatch_size=3): - assert len(x)==3 - assert len(z)==3 + mi=0 + m=ds.minibatches(['x','z'], minibatch_size=3) + for x,z in m: + test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(z,m.minibatch_size,len(ds),mi) assert (x[:,0:3:2]==z).all() - i+=1 - #assert i==#??? What shoud be the value? - print i - del x,z,i + i+=len(x) + mi+=1 + assert i==len(ds) + assert mi==4 + del x,z,i,m,mi i=0 - for x,y in ds.minibatches(['x','y'], minibatch_size=3): - assert len(x)==3 - assert len(y)==3 - for id in range(3): + mi=0 + m=ds.minibatches(['x','y'], minibatch_size=3) + for x,y in m: + test_minibatch_field_size(x,m.minibatch_size,len(ds),mi) + test_minibatch_field_size(y,m.minibatch_size,len(ds),mi) + mi+=1 + for id in range(len(x)): assert (numpy.append(x[id],y[id])==a[i]).all() i+=1 - #assert i==#??? What shoud be the value? - print i - del x,y,i,id + assert i==len(ds) + assert mi==4 + del x,y,i,id,m,mi #not in doc i=0