changeset 94:9c8f3c9c247b

corrected the use of .all()
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 17:44:49 -0400
parents eee739fefdff
children 6fe972a7393c
files test_dataset.py
diffstat 1 files changed, 28 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/test_dataset.py	Mon May 05 17:13:53 2008 -0400
+++ b/test_dataset.py	Mon May 05 17:44:49 2008 -0400
@@ -47,35 +47,35 @@
 
 #not in doc!!!
     for example in range(len(ds)):
-        assert ds[example]['x'].all()==a[example][:2].all()
+        assert (ds[example]['x']==a[example][:3]).all()
         assert ds[example]['y']==a[example][3]
-        assert ds[example]['z'].all()==a[example][0:3:2].all()
+        assert (ds[example]['z']==a[example][[0,2]]).all()
 
 #     - for example in dataset:
     i=0
     for example in ds:
-        assert example['x'].all()==a[i][:2].all()
+        assert (example['x']==a[i][:3]).all()
         assert example['y']==a[i][3]
-        assert example['z'].all()==a[i][0:3:2].all()
-        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        assert (example['z']==a[i][0:3:2]).all()
+        assert (numpy.append(example['x'],example['y'])==a[i]).all()
         i+=1
     assert i==len(ds)
 #     - for val1,val2,... in dataset:
     i=0
     for x,y,z in ds:
-        assert x.all()==a[i][:2].all()
+        assert (x==a[i][:3]).all()
         assert y==a[i][3]
-        assert z.all()==a[i][0:3:2].all()
-        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        assert (z==a[i][0:3:2]).all()
+        assert (numpy.append(x,y)==a[i]).all()
         i+=1
     assert i==len(ds)
 #     - for example in dataset(field1, field2,field3, ...):
     i=0
     for example in ds('x','y','z'):
-        assert example['x'].all()==a[i][:2].all()
+        assert (example['x']==a[i][:3]).all()
         assert example['y']==a[i][3]
-        assert example['z'].all()==a[i][0:3:2].all()
-        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        assert (example['z']==a[i][0:3:2]).all()
+        assert (numpy.append(example['x'],example['y'])==a[i]).all()
         i+=1
     assert i==len(ds)
 
@@ -86,23 +86,23 @@
     def test_ds_iterator(iterator1,iterator2,iterator3):
         i=0
         for x,y in iterator1:
-            assert x.all()==a[i][:2].all()
+            assert (x==a[i][:3]).all()
             assert y==a[i][3]
-            assert numpy.append(x,y).all()==a[i].all()
+            assert (numpy.append(x,y)==a[i]).all()
             i+=1
         assert i==len(ds)
         i=0
         for y,z in iterator2:
             assert y==a[i][3]
-            assert z.all()==a[i][0:3:2].all()
+            assert (z==a[i][0:3:2]).all()
             i+=1
         assert i==len(ds)
         i=0
         for x,y,z in iterator3:
-            assert x.all()==a[i][:2].all()
+            assert (x==a[i][:3]).all()
             assert y==a[i][3]
-            assert z.all()==a[i][0:3:2].all()
-            assert numpy.append(x,y).all()==a[i].all()
+            assert (z==a[i][0:3:2]).all()
+            assert (numpy.append(x,y)==a[i]).all()
             i+=1
         assert i==len(ds)
 
@@ -114,34 +114,34 @@
         assert len(minibatch)==2
         assert len(minibatch[0])==3
         assert len(minibatch[1])==3
-        assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
+        assert (minibatch[0][:,0:3:2]==minibatch[1]).all()
     i=0
     for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
         assert len(minibatch)==2
         assert len(minibatch[0])==3
         assert len(minibatch[1])==3
         for id in range(3):
-            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
+            assert (numpy.append(minibatch[0][id],minibatch[1][id])==a[i]).all()
             i+=1
 
 #     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     for x,z in ds.minibatches(['x','z'], minibatch_size=3):
         assert len(x)==3
         assert len(z)==3
-        assert x[:,0:3:2].all()==z.all()
+        assert (x[:,0:3:2]==z).all()
     i=0
     for x,y in ds.minibatches(['x','y'], minibatch_size=3):
         assert len(x)==3
         assert len(y)==3
         for id in range(3):
-            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
+            assert (numpy.append(x[id],y[id])==a[i]).all()
             i+=1
 #     - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
 #    for x,y,z in ds:
-#        assert x.all()==a[i][:2].all()
+#        assert (x==a[i][:2]).all()
 #        assert y==a[i][3]
-#        assert z.all()==a[i][0:3:2].all()
-#        assert numpy.append(x,y).all()==a[i].all()
+#        assert (z==a[i][0:3:2]).all()
+#        assert (numpy.append(x,y)==a[i]).all()
 #        i+=1
 
 #    for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
@@ -173,12 +173,12 @@
         i=0
         assert len(ds)==len(index)
         for x,z,y in ds('x','z','y'):
-            assert orig[index[i]]['x'].all()==a[index[i]][:3].all()
-            assert orig[index[i]]['x'].all()==x.all()
+            assert (orig[index[i]]['x']==a[index[i]][:3]).all()
+            assert (orig[index[i]]['x']==x).all()
             assert orig[index[i]]['y']==a[index[i]][3]
             assert orig[index[i]]['y']==y
-            assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all()
-            assert orig[index[i]]['z'].all()==z.all()
+            assert (orig[index[i]]['z']==a[index[i]][0:3:2]).all()
+            assert (orig[index[i]]['z']==z).all()
             i+=1
         del i
         ds[0]