diff test_dataset.py @ 96:352910e0dbf5

added test and some restructuring for future use
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 10:53:21 -0400
parents 9c8f3c9c247b
children 574f4db76022
line wrap: on
line diff
--- a/test_dataset.py	Tue May 06 10:52:45 2008 -0400
+++ b/test_dataset.py	Tue May 06 10:53:21 2008 -0400
@@ -38,120 +38,193 @@
     #test with y too
     #test missing value
 
+    def test_iterate_over_examples(array,ds):
+#not in doc!!!
+        i=0
+        for example in range(len(ds)):
+            assert (ds[example]['x']==a[example][:3]).all()
+            assert ds[example]['y']==a[example][3]
+            assert (ds[example]['z']==a[example][[0,2]]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+
+#     - for example in dataset:
+        i=0
+        for example in ds:
+            assert len(example)==3
+            assert (example['x']==array[i][:3]).all()
+            assert example['y']==array[i][3]
+            assert (example['z']==array[i][0:3:2]).all()
+            assert (numpy.append(example['x'],example['y'])==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+
+#     - for val1,val2,... in dataset:
+        i=0
+        for x,y,z in ds:
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del x,y,z,i
+
+#     - for example in dataset(field1, field2,field3, ...):
+        i=0
+        for example in ds('x','y','z'):
+            assert len(example)==3
+            assert (example['x']==array[i][:3]).all()
+            assert example['y']==array[i][3]
+            assert (example['z']==array[i][0:3:2]).all()
+            assert (numpy.append(example['x'],example['y'])==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+        i=0
+        for example in ds('y','x'):
+            assert len(example)==2
+            assert (example['x']==array[i][:3]).all()
+            assert example['y']==array[i][3]
+            assert (numpy.append(example['x'],example['y'])==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+
+#     - for val1,val2,val3 in dataset(field1, field2,field3):
+        i=0
+        for x,y,z in ds('x','y','z'):
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del x,y,z,i
+        i=0
+        for y,x in ds('y','x',):
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del x,y,i
+
+
+#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
+        i=0
+        for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
+            assert len(minibatch)==2
+            assert len(minibatch[0])==3
+            assert len(minibatch[1])==3
+            assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
+            i+=1
+        #assert i==#??? What should the value be?
+        print i
+        del minibatch,i
+        i=0
+        for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
+            assert len(minibatch)==2
+            assert len(minibatch[0])==3
+            assert len(minibatch[1])==3
+            for id in range(3):
+                assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
+                i+=1
+        #assert i==#??? What should the value be?
+        print i
+        del minibatch,i,id
+
+#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
+        i=0
+        for x,z in ds.minibatches(['x','z'], minibatch_size=3):
+            assert len(x)==3
+            assert len(z)==3
+            assert (x[:,0:3:2]==z).all()
+            i+=1
+        #assert i==#??? What should the value be?
+        print i 
+        del x,z,i
+        i=0
+        for x,y in ds.minibatches(['x','y'], minibatch_size=3):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i]).all()
+                i+=1
+        #assert i==#??? What should the value be?
+        print i 
+        del x,y,i,id
+
+#not in doc
+        i=0
+        for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i+4]).all()
+                i+=1
+        assert i==3
+        del x,y,i,id
+
+        i=0
+        for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i+4]).all()
+                i+=1
+        assert i==6
+        del x,y,i,id
+
+        i=0
+        for x,y in ds.minibatches(['x','y'],n_batches=10,minibatch_size=3,offset=4):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i+4]).all()
+                i+=1
+        assert i==6
+        del x,y,i,id
+
+
+    def test_ds_iterator(array,iterator1,iterator2,iterator3):
+        i=0
+        for x,y in iterator1:
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for y,z in iterator2:
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for x,y,z in iterator3:
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+
     print "test_ArrayDataSet"
     a = numpy.random.rand(10,4)
     ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
     ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     assert len(ds)==10
     #assert ds==a? should this work?
-
-#not in doc!!!
-    for example in range(len(ds)):
-        assert (ds[example]['x']==a[example][:3]).all()
-        assert ds[example]['y']==a[example][3]
-        assert (ds[example]['z']==a[example][[0,2]]).all()
+    
+    test_iterate_over_examples(a, ds)
 
-#     - for example in dataset:
-    i=0
-    for example in ds:
-        assert (example['x']==a[i][:3]).all()
-        assert example['y']==a[i][3]
-        assert (example['z']==a[i][0:3:2]).all()
-        assert (numpy.append(example['x'],example['y'])==a[i]).all()
-        i+=1
-    assert i==len(ds)
-#     - for val1,val2,... in dataset:
-    i=0
-    for x,y,z in ds:
-        assert (x==a[i][:3]).all()
-        assert y==a[i][3]
-        assert (z==a[i][0:3:2]).all()
-        assert (numpy.append(x,y)==a[i]).all()
-        i+=1
-    assert i==len(ds)
-#     - for example in dataset(field1, field2,field3, ...):
-    i=0
-    for example in ds('x','y','z'):
-        assert (example['x']==a[i][:3]).all()
-        assert example['y']==a[i][3]
-        assert (example['z']==a[i][0:3:2]).all()
-        assert (numpy.append(example['x'],example['y'])==a[i]).all()
-        i+=1
-    assert i==len(ds)
 
 #     - for val1,val2,val3 in dataset(field1, field2,field3):
-
-#     - for example in dataset(field1, field2,field3, ...):
-
-    def test_ds_iterator(iterator1,iterator2,iterator3):
-        i=0
-        for x,y in iterator1:
-            assert (x==a[i][:3]).all()
-            assert y==a[i][3]
-            assert (numpy.append(x,y)==a[i]).all()
-            i+=1
-        assert i==len(ds)
-        i=0
-        for y,z in iterator2:
-            assert y==a[i][3]
-            assert (z==a[i][0:3:2]).all()
-            i+=1
-        assert i==len(ds)
-        i=0
-        for x,y,z in iterator3:
-            assert (x==a[i][:3]).all()
-            assert y==a[i][3]
-            assert (z==a[i][0:3:2]).all()
-            assert (numpy.append(x,y)==a[i]).all()
-            i+=1
-        assert i==len(ds)
-
-#not in doc!!!     - for val1,val2,val3 in dataset(field1, field2,field3):
-    test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
+    test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z'))
 
-#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
-    for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
-        assert len(minibatch)==2
-        assert len(minibatch[0])==3
-        assert len(minibatch[1])==3
-        assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
-    i=0
-    for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
-        assert len(minibatch)==2
-        assert len(minibatch[0])==3
-        assert len(minibatch[1])==3
-        for id in range(3):
-            assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
-            i+=1
-
-#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
-    for x,z in ds.minibatches(['x','z'], minibatch_size=3):
-        assert len(x)==3
-        assert len(z)==3
-        assert (x[:,0:3:2]==z).all()
-    i=0
-    for x,y in ds.minibatches(['x','y'], minibatch_size=3):
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
-            assert (numpy.append(x[id],y[id])==a[i]).all()
-            i+=1
-#     - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
-#    for x,y,z in ds:
-#        assert (x==a[i][:2]).all()
-#        assert y==a[i][3]
-#        assert (z==a[i][0:3:2]).all()
-#        assert (numpy.append(x,y)==a[i]).all()
-#        i+=1
-
-#    for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
-#        print minibatch
-#    minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)
-#    minibatch = minibatch_iterator.__iter__().next()
-#    print "minibatch=",minibatch
-#    for var in minibatch:
-#        print "var=",var
-#    print "take a slice and look at field y",ds[1:6:2]["y"]
     assert have_raised("ds['h']")  # h is not defined...
     assert have_raised("ds[['h']]")  # h is not defined...
 
@@ -215,7 +288,7 @@
     #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])
     #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])
 
-#    for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work.
+#    for (x,y) in (ds('x','y'),a): #???doesn't work # haven't found a variant that works.
 #        assert numpy.append(x,y)==z
 
 def test_LookupList():