changeset 341:9c08e3af975e

corrected test for dataset.minibatches()
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 17 Jun 2008 13:33:17 -0400
parents d96be0eba3cc
children 2259f6fa4959
files _test_dataset.py
diffstat 1 files changed, 59 insertions(+), 47 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Tue Jun 17 11:41:01 2008 -0400
+++ b/_test_dataset.py	Tue Jun 17 13:33:17 2008 -0400
@@ -134,12 +134,13 @@
 #     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
     i=0
     mi=0
-    m=ds.minibatches(['x','z'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','z'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for minibatch in m:
-        assert isinstance(minibatch,DataSetFields)
+        assert isinstance(minibatch,LookupList)
         assert len(minibatch)==2
-        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+        test_minibatch_size(minibatch,size,len(ds),2,mi)
         if type(ds)==ArrayDataSet:
             assert (minibatch[0][:,::2]==minibatch[1]).all()
         else:
@@ -147,92 +148,103 @@
                 (minibatch[0][j][::2]==minibatch[1][j]).all()
         mi+=1
         i+=len(minibatch[0])
-    assert i==len(ds)
-    assert mi==4
-    del minibatch,i,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del minibatch,i,m,mi,size
 
     i=0
     mi=0
-    m=ds.minibatches(['x','y'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for minibatch in m:
+        assert isinstance(minibatch,LookupList)
         assert len(minibatch)==2
-        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+        test_minibatch_size(minibatch,size,len(ds),2,mi)
         mi+=1
         for id in range(len(minibatch[0])):
             assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all()
             i+=1
-    assert i==len(ds)
-    assert mi==4
-    del minibatch,i,id,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del minibatch,i,id,m,mi,size
 
 #     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     i=0
     mi=0
-    m=ds.minibatches(['x','z'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','z'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for x,z in m:
-        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
-        test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
+        test_minibatch_field_size(x,size,len(ds),mi)
+        test_minibatch_field_size(z,size,len(ds),mi)
         for id in range(len(x)):
             assert (x[id][::2]==z[id]).all()
             i+=1
         mi+=1
-    assert i==len(ds)
-    assert mi==4
-    del x,z,i,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del x,z,i,m,mi,size
+
     i=0
     mi=0
+    size=3
     m=ds.minibatches(['x','y'], minibatch_size=3)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
-        test_minibatch_field_size(y,m.minibatch_size,len(ds),mi)
+        assert len(x)==size
+        assert len(y)==size
+        test_minibatch_field_size(x,size,len(ds),mi)
+        test_minibatch_field_size(y,size,len(ds),mi)
         mi+=1
         for id in range(len(x)):
             assert (numpy.append(x[id],y[id])==array[i]).all()
             i+=1
-    assert i==len(ds)
-    assert mi==4
-    del x,y,i,id,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del x,y,i,id,m,mi,size
 
 #not in doc
     i=0
-    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id,m
+    assert i==size
+    del x,y,i,id,m,size
 
     i=0
-    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id,m
+    assert i==2*size
+    del x,y,i,id,m,size
 
     i=0
-    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id
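+    assert i==2*size # n_batches=20 asks for more examples than remain after offset=4: iteration should stop after 2 full minibatches instead of wrapping around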
+    del x,y,i,id,size
 
-    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
     assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
 
 def test_ds_iterator(array,iterator1,iterator2,iterator3):