changeset 91:eee739fefdff

corrected test from discution about the syntax with Yoshua
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 05 May 2008 17:13:53 -0400
parents a289b8bed64c
children a62c79ec7c8a 9c8f3c9c247b
files test_dataset.py
diffstat 1 files changed, 25 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/test_dataset.py	Mon May 05 17:13:07 2008 -0400
+++ b/test_dataset.py	Mon May 05 17:13:53 2008 -0400
@@ -41,7 +41,7 @@
     print "test_ArrayDataSet"
     a = numpy.random.rand(10,4)
     ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
-
+    ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     assert len(ds)==10
     #assert ds==a? should this work?
 
@@ -50,7 +50,8 @@
         assert ds[example]['x'].all()==a[example][:2].all()
         assert ds[example]['y']==a[example][3]
         assert ds[example]['z'].all()==a[example][0:3:2].all()
-#     - for example in dataset::
+
+#     - for example in dataset:
     i=0
     for example in ds:
         assert example['x'].all()==a[i][:2].all()
@@ -59,6 +60,28 @@
         assert numpy.append(example['x'],example['y']).all()==a[i].all()
         i+=1
     assert i==len(ds)
+#     - for val1,val2,... in dataset:
+    i=0
+    for x,y,z in ds:
+        assert x.all()==a[i][:2].all()
+        assert y==a[i][3]
+        assert z.all()==a[i][0:3:2].all()
+        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        i+=1
+    assert i==len(ds)
+#     - for example in dataset(field1, field2,field3, ...):
+    i=0
+    for example in ds('x','y','z'):
+        assert example['x'].all()==a[i][:2].all()
+        assert example['y']==a[i][3]
+        assert example['z'].all()==a[i][0:3:2].all()
+        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        i+=1
+    assert i==len(ds)
+
+#     - for val1,val2,val3 in dataset(field1, field2,field3):
+
+#     - for example in dataset(field1, field2,field3, ...):
 
     def test_ds_iterator(iterator1,iterator2,iterator3):
         i=0
@@ -86,12 +109,6 @@
 #not in doc!!!     - for val1,val2,val3 in dataset(field1, field2,field3):
     test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
 
-#not in doc!!!     - for val1,val2,val3 in dataset((field1, field2,field3)):
-    test_ds_iterator(ds(('x','y')),ds(('y','z')),ds(('x','y','z')))
-
-#     - for val1,val2,val3 in dataset([field1, field2,field3]): #was bugged
-    test_ds_iterator(ds(['x','y']),ds(['y','z']),ds(['x','y','z']))
-    
 #     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
     for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
         assert len(minibatch)==2