diff test_dataset.py @ 89:05dc4804357b

more test and refactoring
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 16:54:16 -0400
parents 3fd6879e0f76
children eee739fefdff
line wrap: on
line diff
--- a/test_dataset.py	Mon May 05 14:51:41 2008 -0400
+++ b/test_dataset.py	Mon May 05 16:54:16 2008 -0400
@@ -44,23 +44,88 @@
 
     assert len(ds)==10
     #assert ds==a? should this work?
-    for i in range(len(ds)):
-        assert ds[i]['x'].all()==a[i][:2].all()
-        assert ds[i]['y']==a[i][3]
-        assert ds[i]['z'].all()==a[i][0:3:2].all()
+
+#not in doc!!!
+    for example in range(len(ds)):
+        assert ds[example]['x'].all()==a[example][:2].all()
+        assert ds[example]['y']==a[example][3]
+        assert ds[example]['z'].all()==a[example][0:3:2].all()
+#     - for example in dataset::
     i=0
-    for x in ds('x','y'):
-        assert numpy.append(x['x'],x['y']).all()==a[i].all()
+    for example in ds:
+        assert example['x'].all()==a[i][:2].all()
+        assert example['y']==a[i][3]
+        assert example['z'].all()==a[i][0:3:2].all()
+        assert numpy.append(example['x'],example['y']).all()==a[i].all()
         i+=1
+    assert i==len(ds)
+
+    def test_ds_iterator(iterator1,iterator2,iterator3):
+        i=0
+        for x,y in iterator1:
+            assert x.all()==a[i][:2].all()
+            assert y==a[i][3]
+            assert numpy.append(x,y).all()==a[i].all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for y,z in iterator2:
+            assert y==a[i][3]
+            assert z.all()==a[i][0:3:2].all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for x,y,z in iterator3:
+            assert x.all()==a[i][:2].all()
+            assert y==a[i][3]
+            assert z.all()==a[i][0:3:2].all()
+            assert numpy.append(x,y).all()==a[i].all()
+            i+=1
+        assert i==len(ds)
 
+#not in doc!!!     - for val1,val2,val3 in dataset(field1, field2,field3):
+    test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
+
+#not in doc!!!     - for val1,val2,val3 in dataset((field1, field2,field3)):
+    test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z')))
+
+#     - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged
+    test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z']))
+    
+#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
+    for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
+        assert len(minibatch)==2
+        assert len(minibatch[0])==3
+        assert len(minibatch[1])==3
+        assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
     i=0
-    for x,y in ds('x','y'):
-        assert numpy.append(x,y).all()==a[i].all()
-        i+=1
-    for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
-        assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
+    for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
+        assert len(minibatch)==2
+        assert len(minibatch[0])==3
+        assert len(minibatch[1])==3
+        for id in range(3):
+            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
+            i+=1
+
+#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     for x,z in ds.minibatches(['x','z'], minibatch_size=3):
+        assert len(x)==3
+        assert len(z)==3
         assert x[:,0:3:2].all()==z.all()
+    i=0
+    for x,y in ds.minibatches(['x','y'], minibatch_size=3):
+        assert len(x)==3
+        assert len(y)==3
+        for id in range(3):
+            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
+            i+=1
+#     - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
+#    for x,y,z in ds:
+#        assert x.all()==a[i][:2].all()
+#        assert y==a[i][3]
+#        assert z.all()==a[i][0:3:2].all()
+#        assert numpy.append(x,y).all()==a[i].all()
+#        i+=1
 
 #    for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
 #        print minibatch
@@ -99,6 +164,15 @@
             assert orig[index[i]]['z'].all()==z.all()
             i+=1
         del i
+        ds[0]
+        if len(ds)>2:
+            ds[:1]
+            ds[1:1]
+            ds[1:1:1]
+        if len(ds)>5:
+            ds[[1,2,3]]
+        for x in ds:
+            pass
 
     #ds[:n] returns a dataset with the n first examples.
     ds2=ds[:3]
@@ -106,22 +180,13 @@
 
     #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
     ds2=ds[1:7:2]
-    ds2[1]
     test_ds(ds,ds2,[1,3,5])
+
     #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
-#   ds2=ds[[4,7,2,8]]# fail???
-#   assert len(ds2)==4
-#   i=0
-#   index=[4,7,2,8]
-#    for x in ds2:
-#        assert ds[index[i]]['x'].all()==a[index[i]][:3].all()
-#        assert ds[index[i]]['x'].all()==x.all()
-#        assert ds[index[i]]['y']==a[index[i]][3]
-#        assert ds[index[i]]['y']==y
-#        assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all()
-#        assert ds[index[i]]['z'].all()==z.all()
-#        i+=1
+    ds2=ds[[4,7,2,8]]
+    test_ds(ds,ds2,[4,7,2,8])
     #ds[i1,i2,...]# should we accept????
+
     #ds[fieldname]# an iterable over the values of the field fieldname across
       #the ds (the iterable is obtained by default by calling valuesVStack
       #over the values for individual examples).
@@ -183,5 +248,3 @@
 test_LookupList()
 test_ArrayDataSet()
 
-
-