diff test_dataset.py @ 104:e1a004b21daa

more test
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 16:12:37 -0400
parents a90d85fef3d4
children cf9bdb1d9656
line wrap: on
line diff
--- a/test_dataset.py	Tue May 06 16:07:39 2008 -0400
+++ b/test_dataset.py	Tue May 06 16:12:37 2008 -0400
@@ -124,6 +124,7 @@
         i=0
         mi=0
         m=ds.minibatches(['x','z'], minibatch_size=3)
+        assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
         for minibatch in m:
             assert len(minibatch)==2
             test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
@@ -137,6 +138,7 @@
         i=0
         mi=0
         m=ds.minibatches(['x','y'], minibatch_size=3)
+        assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
         for minibatch in m:
             assert len(minibatch)==2
             test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
@@ -152,6 +154,7 @@
         i=0
         mi=0
         m=ds.minibatches(['x','z'], minibatch_size=3)
+        assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
         for x,z in m:
             test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
             test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
@@ -177,27 +180,32 @@
 
 #not in doc
         i=0
-        for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4):
+        m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
+        assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+        for x,y in m:
             assert len(x)==3
             assert len(y)==3
             for id in range(3):
                 assert (numpy.append(x[id],y[id])==a[i+4]).all()
                 i+=1
         assert i==3
-        del x,y,i,id
+        del x,y,i,id,m
 
         i=0
-        for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4):
+        m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
+        assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+        for x,y in m:
             assert len(x)==3
             assert len(y)==3
             for id in range(3):
                 assert (numpy.append(x[id],y[id])==a[i+4]).all()
                 i+=1
         assert i==6
-        del x,y,i,id
+        del x,y,i,id,m
 
         i=0
         m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
+        assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
         for x,y in m:
             assert len(x)==3
             assert len(y)==3