changeset 97:05cfe011ca20

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 06 May 2008 10:53:28 -0400
parents a62c79ec7c8a (current diff) 352910e0dbf5 (diff)
children 7186e4f502d1 c4916445e025
files
diffstat 2 files changed, 182 insertions(+), 112 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Mon May 05 18:14:44 2008 -0400
+++ b/dataset.py	Tue May 06 10:53:28 2008 -0400
@@ -33,9 +33,6 @@
      - for val1,val2,val3 in dataset(field1, field2,field3):
      - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
      - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
-     - for example in dataset::
-        print example['x']
-     - for x,y,z in dataset:
      Each of these is documented below. All of these iterators are expected
      to provide, in addition to the usual 'next()' method, a 'next_index()' method
      which returns a non-negative integer pointing to the position of the next
--- a/test_dataset.py	Mon May 05 18:14:44 2008 -0400
+++ b/test_dataset.py	Tue May 06 10:53:28 2008 -0400
@@ -38,120 +38,193 @@
     #test with y too
     #test missing value
 
+    def test_iterate_over_examples(array,ds):
+#not in doc!!!
+        i=0
+        for example in range(len(ds)):
+            assert (ds[example]['x']==a[example][:3]).all()
+            assert ds[example]['y']==a[example][3]
+            assert (ds[example]['z']==a[example][[0,2]]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+
+#     - for example in dataset:
+        i=0
+        for example in ds:
+            assert len(example)==3
+            assert (example['x']==array[i][:3]).all()
+            assert example['y']==array[i][3]
+            assert (example['z']==array[i][0:3:2]).all()
+            assert (numpy.append(example['x'],example['y'])==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+
+#     - for val1,val2,... in dataset:
+        i=0
+        for x,y,z in ds:
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del x,y,z,i
+
+#     - for example in dataset(field1, field2,field3, ...):
+        i=0
+        for example in ds('x','y','z'):
+            assert len(example)==3
+            assert (example['x']==array[i][:3]).all()
+            assert example['y']==array[i][3]
+            assert (example['z']==array[i][0:3:2]).all()
+            assert (numpy.append(example['x'],example['y'])==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+        i=0
+        for example in ds('y','x'):
+            assert len(example)==2
+            assert (example['x']==array[i][:3]).all()
+            assert example['y']==array[i][3]
+            assert (numpy.append(example['x'],example['y'])==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del example,i
+
+#     - for val1,val2,val3 in dataset(field1, field2,field3):
+        i=0
+        for x,y,z in ds('x','y','z'):
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del x,y,z,i
+        i=0
+        for y,x in ds('y','x',):
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        del x,y,i
+
+
+#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
+        i=0
+        for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
+            assert len(minibatch)==2
+            assert len(minibatch[0])==3
+            assert len(minibatch[1])==3
+            assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
+            i+=1
+        #assert i==#??? What shoud be the value?
+        print i
+        del minibatch,i
+        i=0
+        for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
+            assert len(minibatch)==2
+            assert len(minibatch[0])==3
+            assert len(minibatch[1])==3
+            for id in range(3):
+                assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
+                i+=1
+        #assert i==#??? What shoud be the value?
+        print i
+        del minibatch,i,id
+
+#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
+        i=0
+        for x,z in ds.minibatches(['x','z'], minibatch_size=3):
+            assert len(x)==3
+            assert len(z)==3
+            assert (x[:,0:3:2]==z).all()
+            i+=1
+        #assert i==#??? What shoud be the value?
+        print i 
+        del x,z,i
+        i=0
+        for x,y in ds.minibatches(['x','y'], minibatch_size=3):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i]).all()
+                i+=1
+        #assert i==#??? What shoud be the value?
+        print i 
+        del x,y,i,id
+
+#not in doc
+        i=0
+        for x,y in ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i+4]).all()
+                i+=1
+        assert i==3
+        del x,y,i,id
+
+        i=0
+        for x,y in ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i+4]).all()
+                i+=1
+        assert i==6
+        del x,y,i,id
+
+        i=0
+        for x,y in ds.minibatches(['x','y'],n_batches=10,minibatch_size=3,offset=4):
+            assert len(x)==3
+            assert len(y)==3
+            for id in range(3):
+                assert (numpy.append(x[id],y[id])==a[i+4]).all()
+                i+=1
+        assert i==6
+        del x,y,i,id
+
+
+    def test_ds_iterator(array,iterator1,iterator2,iterator3):
+        i=0
+        for x,y in iterator1:
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for y,z in iterator2:
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for x,y,z in iterator3:
+            assert (x==array[i][:3]).all()
+            assert y==array[i][3]
+            assert (z==array[i][0:3:2]).all()
+            assert (numpy.append(x,y)==array[i]).all()
+            i+=1
+        assert i==len(ds)
+
     print "test_ArrayDataSet"
     a = numpy.random.rand(10,4)
     ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
     ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     assert len(ds)==10
     #assert ds==a? should this work?
-
-#not in doc!!!
-    for example in range(len(ds)):
-        assert ds[example]['x'].all()==a[example][:2].all()
-        assert ds[example]['y']==a[example][3]
-        assert ds[example]['z'].all()==a[example][0:3:2].all()
+    
+    test_iterate_over_examples(a, ds)
 
-#     - for example in dataset:
-    i=0
-    for example in ds:
-        assert example['x'].all()==a[i][:2].all()
-        assert example['y']==a[i][3]
-        assert example['z'].all()==a[i][0:3:2].all()
-        assert numpy.append(example['x'],example['y']).all()==a[i].all()
-        i+=1
-    assert i==len(ds)
-#     - for val1,val2,... in dataset:
-    i=0
-    for x,y,z in ds:
-        assert x.all()==a[i][:2].all()
-        assert y==a[i][3]
-        assert z.all()==a[i][0:3:2].all()
-        assert numpy.append(example['x'],example['y']).all()==a[i].all()
-        i+=1
-    assert i==len(ds)
-#     - for example in dataset(field1, field2,field3, ...):
-    i=0
-    for example in ds('x','y','z'):
-        assert example['x'].all()==a[i][:2].all()
-        assert example['y']==a[i][3]
-        assert example['z'].all()==a[i][0:3:2].all()
-        assert numpy.append(example['x'],example['y']).all()==a[i].all()
-        i+=1
-    assert i==len(ds)
 
 #     - for val1,val2,val3 in dataset(field1, field2,field3):
-
-#     - for example in dataset(field1, field2,field3, ...):
-
-    def test_ds_iterator(iterator1,iterator2,iterator3):
-        i=0
-        for x,y in iterator1:
-            assert x.all()==a[i][:2].all()
-            assert y==a[i][3]
-            assert numpy.append(x,y).all()==a[i].all()
-            i+=1
-        assert i==len(ds)
-        i=0
-        for y,z in iterator2:
-            assert y==a[i][3]
-            assert z.all()==a[i][0:3:2].all()
-            i+=1
-        assert i==len(ds)
-        i=0
-        for x,y,z in iterator3:
-            assert x.all()==a[i][:2].all()
-            assert y==a[i][3]
-            assert z.all()==a[i][0:3:2].all()
-            assert numpy.append(x,y).all()==a[i].all()
-            i+=1
-        assert i==len(ds)
-
-#not in doc!!!     - for val1,val2,val3 in dataset(field1, field2,field3):
-    test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
+    test_ds_iterator(a,ds('x','y'),ds('y','z'),ds('x','y','z'))
 
-#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
-    for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
-        assert len(minibatch)==2
-        assert len(minibatch[0])==3
-        assert len(minibatch[1])==3
-        assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
-    i=0
-    for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
-        assert len(minibatch)==2
-        assert len(minibatch[0])==3
-        assert len(minibatch[1])==3
-        for id in range(3):
-            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
-            i+=1
-
-#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
-    for x,z in ds.minibatches(['x','z'], minibatch_size=3):
-        assert len(x)==3
-        assert len(z)==3
-        assert x[:,0:3:2].all()==z.all()
-    i=0
-    for x,y in ds.minibatches(['x','y'], minibatch_size=3):
-        assert len(x)==3
-        assert len(y)==3
-        for id in range(3):
-            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
-            i+=1
-#     - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
-#    for x,y,z in ds:
-#        assert x.all()==a[i][:2].all()
-#        assert y==a[i][3]
-#        assert z.all()==a[i][0:3:2].all()
-#        assert numpy.append(x,y).all()==a[i].all()
-#        i+=1
-
-#    for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
-#        print minibatch
-#    minibatch_iterator = ds.minibatches(fieldnames=['z','y'],n_batches=1,minibatch_size=3,offset=4)
-#    minibatch = minibatch_iterator.__iter__().next()
-#    print "minibatch=",minibatch
-#    for var in minibatch:
-#        print "var=",var
-#    print "take a slice and look at field y",ds[1:6:2]["y"]
     assert have_raised("ds['h']")  # h is not defined...
     assert have_raised("ds[['h']]")  # h is not defined...
 
@@ -173,12 +246,12 @@
         i=0
         assert len(ds)==len(index)
         for x,z,y in ds('x','z','y'):
-            assert orig[index[i]]['x'].all()==a[index[i]][:3].all()
-            assert orig[index[i]]['x'].all()==x.all()
+            assert (orig[index[i]]['x']==a[index[i]][:3]).all()
+            assert (orig[index[i]]['x']==x).all()
             assert orig[index[i]]['y']==a[index[i]][3]
             assert orig[index[i]]['y']==y
-            assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all()
-            assert orig[index[i]]['z'].all()==z.all()
+            assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all()
+            assert (orig[index[i]]['z']==z).all()
             i+=1
         del i
         ds[0]
@@ -215,7 +288,7 @@
     #* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])
     #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])
 
-#    for (x,y) in (ds('x','y'),a): #don't work # haven't found a variant that work.
+#    for (x,y) in (ds('x','y'),a): #???don't work # haven't found a variant that work.
 #        assert numpy.append(x,y)==z
 
 def test_LookupList():