diff test_dataset.py @ 102:4537ac630348

modifed test to accomodate the last change in dataset.py. i.e. minibatch without a fixed number of batch return an incomplete minibatch at the end to stop at the end of the dataset.
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 16:03:17 -0400
parents 574f4db76022
children a90d85fef3d4
line wrap: on
line diff
--- a/test_dataset.py	Tue May 06 16:01:53 2008 -0400
+++ b/test_dataset.py	Tue May 06 16:03:17 2008 -0400
@@ -112,50 +112,68 @@
         assert i==len(ds)
         del x,y,i
 
+        def test_minibatch_size(minibatch,minibatch_size,len_ds,nb_field,nb_iter_finished):
+            ##full minibatch or the last minibatch
+            for idx in range(nb_field):
+                test_minibatch_field_size(minibatch[idx],minibatch_size,len_ds,nb_iter_finished)
+            del idx
+        def test_minibatch_field_size(minibatch_field,minibatch_size,len_ds,nb_iter_finished):
+            assert len(minibatch_field)==minibatch_size or ((nb_iter_finished*minibatch_size+len(minibatch_field))==len_ds and len(minibatch_field)<minibatch_size)
 
 #     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
         i=0
-        for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
+        mi=0
+        m=ds.minibatches(['x','z'], minibatch_size=3)
+        for minibatch in m:
             assert len(minibatch)==2
-            assert len(minibatch[0])==3
-            assert len(minibatch[1])==3
+            test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
             assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
-            i+=1
-        #assert i==#??? What shoud be the value? #option for the rest.
-        print i
-        del minibatch,i
+            mi+=1
+            i+=len(minibatch[0])
+        assert i==len(ds)
+        assert mi==4
+        del minibatch,i,m,mi
+
         i=0
-        for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
+        mi=0
+        m=ds.minibatches(['x','y'], minibatch_size=3)
+        for minibatch in m:
             assert len(minibatch)==2
-            assert len(minibatch[0])==3
-            assert len(minibatch[1])==3
-            for id in range(3):
+            test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+            mi+=1
+            for id in range(len(minibatch[0])):
                 assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
                 i+=1
-        #assert i==#??? What shoud be the value?
-        print i
-        del minibatch,i,id
+        assert i==len(ds)
+        assert mi==4
+        del minibatch,i,id,m,mi
 
 #     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
         i=0
-        for x,z in ds.minibatches(['x','z'], minibatch_size=3):
-            assert len(x)==3
-            assert len(z)==3
+        mi=0
+        m=ds.minibatches(['x','z'], minibatch_size=3)
+        for x,z in m:
+            test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
+            test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
             assert (x[:,0:3:2]==z).all()
-            i+=1
-        #assert i==#??? What shoud be the value?
-        print i 
-        del x,z,i
+            i+=len(x)
+            mi+=1
+        assert i==len(ds)
+        assert mi==4
+        del x,z,i,m,mi
         i=0
-        for x,y in ds.minibatches(['x','y'], minibatch_size=3):
-            assert len(x)==3
-            assert len(y)==3
-            for id in range(3):
+        mi=0
+        m=ds.minibatches(['x','y'], minibatch_size=3)
+        for x,y in m:
+            test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
+            test_minibatch_field_size(y,m.minibatch_size,len(ds),mi)
+            mi+=1
+            for id in range(len(x)):
                 assert (numpy.append(x[id],y[id])==a[i]).all()
                 i+=1
-        #assert i==#??? What shoud be the value?
-        print i 
-        del x,y,i,id
+        assert i==len(ds)
+        assert mi==4
+        del x,y,i,id,m,mi
 
 #not in doc
         i=0