diff _test_dataset.py @ 376:c9a89be5cb0a

Redesigning linear_regression
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 07 Jul 2008 10:08:35 -0400
parents 18702ceb2096
children 82da179d95b2
line wrap: on
line diff
--- a/_test_dataset.py	Mon Jun 16 17:47:36 2008 -0400
+++ b/_test_dataset.py	Mon Jul 07 10:08:35 2008 -0400
@@ -2,7 +2,7 @@
 from dataset import *
 from math import *
 import numpy, unittest, sys
-from misc import *
+#from misc import *
 from lookup_list import LookupList
 
 def have_raised(to_eval, **var):
@@ -134,12 +134,13 @@
 #     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
     i=0
     mi=0
-    m=ds.minibatches(['x','z'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','z'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for minibatch in m:
-        assert isinstance(minibatch,DataSetFields)
+        assert isinstance(minibatch,LookupList)
         assert len(minibatch)==2
-        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+        test_minibatch_size(minibatch,size,len(ds),2,mi)
         if type(ds)==ArrayDataSet:
             assert (minibatch[0][:,::2]==minibatch[1]).all()
         else:
@@ -147,92 +148,103 @@
                 (minibatch[0][j][::2]==minibatch[1][j]).all()
         mi+=1
         i+=len(minibatch[0])
-    assert i==len(ds)
-    assert mi==4
-    del minibatch,i,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del minibatch,i,m,mi,size
 
     i=0
     mi=0
-    m=ds.minibatches(['x','y'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for minibatch in m:
+        assert isinstance(minibatch,LookupList)
         assert len(minibatch)==2
-        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+        test_minibatch_size(minibatch,size,len(ds),2,mi)
         mi+=1
         for id in range(len(minibatch[0])):
             assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all()
             i+=1
-    assert i==len(ds)
-    assert mi==4
-    del minibatch,i,id,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del minibatch,i,id,m,mi,size
 
 #     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     i=0
     mi=0
-    m=ds.minibatches(['x','z'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','z'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for x,z in m:
-        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
-        test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
+        test_minibatch_field_size(x,size,len(ds),mi)
+        test_minibatch_field_size(z,size,len(ds),mi)
         for id in range(len(x)):
             assert (x[id][::2]==z[id]).all()
             i+=1
         mi+=1
-    assert i==len(ds)
-    assert mi==4
-    del x,z,i,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del x,z,i,m,mi,size
+
     i=0
     mi=0
+    size=3
     m=ds.minibatches(['x','y'], minibatch_size=3)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
-        test_minibatch_field_size(y,m.minibatch_size,len(ds),mi)
+        assert len(x)==size
+        assert len(y)==size
+        test_minibatch_field_size(x,size,len(ds),mi)
+        test_minibatch_field_size(y,size,len(ds),mi)
         mi+=1
         for id in range(len(x)):
             assert (numpy.append(x[id],y[id])==array[i]).all()
             i+=1
-    assert i==len(ds)
-    assert mi==4
-    del x,y,i,id,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del x,y,i,id,m,mi,size
 
 #not in doc
     i=0
-    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id,m
+    assert i==size
+    del x,y,i,id,m,size
 
     i=0
-    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id,m
+    assert i==2*size
+    del x,y,i,id,m,size
 
     i=0
-    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id
+    assert i==2*size # should not wrap
+    del x,y,i,id,size
 
-    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
     assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
 
 def test_ds_iterator(array,iterator1,iterator2,iterator3):
@@ -262,14 +274,17 @@
 def test_getitem(array,ds):
     def test_ds(orig,ds,index):
         i=0
-        assert len(ds)==len(index)
-        for x,z,y in ds('x','z','y'):
-            assert (orig[index[i]]['x']==array[index[i]][:3]).all()
-            assert (orig[index[i]]['x']==x).all()
-            assert orig[index[i]]['y']==array[index[i]][3]
-            assert (orig[index[i]]['y']==y).all() # why does it crash sometimes?
-            assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all()
-            assert (orig[index[i]]['z']==z).all()
+        assert isinstance(ds,LookupList)
+        assert len(ds)==3
+        assert len(ds[0])==len(index)
+#        for x,z,y in ds('x','z','y'):
+        for idx in index:
+            assert (orig[idx]['x']==array[idx][:3]).all()
+            assert (orig[idx]['x']==ds['x'][i]).all()
+            assert orig[idx]['y']==array[idx][3]
+            assert (orig[idx]['y']==ds['y'][i]).all() # why does it crash sometimes?
+            assert (orig[idx]['z']==array[idx][0:3:2]).all()
+            assert (orig[idx]['z']==ds['z'][i]).all()
             i+=1
         del i
         ds[0]
@@ -282,19 +297,22 @@
         for x in ds:
             pass
 
-#ds[:n] returns a dataset with the n first examples.
+#ds[:n] returns a LookupList with the n first examples.
     ds2=ds[:3]
-    assert isinstance(ds2,LookupList)
     test_ds(ds,ds2,index=[0,1,2])
     del ds2
 
-#ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
-    ds2=ds.subset[1:7:2]
-    assert isinstance(ds2,DataSet)
+#ds[i:j] returns a LookupList with examples i,i+1,...,j-1.
+    ds2=ds[1:3]
+    test_ds(ds,ds2,index=[1,2])
+    del ds2
+
+#ds[i1:i2:s] returns a LookupList with the examples i1,i1+s,...i2-s.
+    ds2=ds[1:7:2]
     test_ds(ds,ds2,[1,3,5])
     del ds2
 
-#ds[i]
+#ds[i] returns the (i+1)-th example of the dataset.
     ds2=ds[5]
     assert isinstance(ds2,Example)
     assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds)  # index not defined
@@ -302,8 +320,8 @@
     del ds2
 
 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
-    ds2=ds.subset[[4,7,2,8]]
-    assert isinstance(ds2,DataSet)
+    ds2=ds[[4,7,2,8]]
+#    assert isinstance(ds2,DataSet)
     test_ds(ds,ds2,[4,7,2,8])
     del ds2
 
@@ -326,6 +344,71 @@
     #        del i,example
     #* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????
 
+def test_subset(array,ds):
+    def test_ds(orig,ds,index):
+        i=0
+        assert isinstance(ds2,DataSet)
+        assert len(ds)==len(index)
+        for x,z,y in ds('x','z','y'):
+            assert (orig[index[i]]['x']==array[index[i]][:3]).all()
+            assert (orig[index[i]]['x']==x).all()
+            assert orig[index[i]]['y']==array[index[i]][3]
+            assert orig[index[i]]['y']==y
+            assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all()
+            assert (orig[index[i]]['z']==z).all()
+            i+=1
+        del i
+        ds[0]
+        if len(ds)>2:
+            ds[:1]
+            ds[1:1]
+            ds[1:1:1]
+        if len(ds)>5:
+            ds[[1,2,3]]
+        for x in ds:
+            pass
+
+#ds[:n] returns a dataset with the n first examples.
+    ds2=ds.subset[:3]
+    test_ds(ds,ds2,index=[0,1,2])
+#    del ds2
+
+#ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
+    ds2=ds.subset[1:7:2]
+    test_ds(ds,ds2,[1,3,5])
+#     del ds2
+
+# #ds[i]
+#     ds2=ds.subset[5]
+#     assert isinstance(ds2,Example)
+#     assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds)  # index not defined
+#     assert not have_raised("var['ds']["+str(len(ds)-1)+"]",ds=ds)
+#     del ds2
+
+#ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
+    ds2=ds.subset[[4,7,2,8]]
+    test_ds(ds,ds2,[4,7,2,8])
+#     del ds2
+
+#ds.<property># returns the value of a property associated with
+  #the name <property>. The following properties should be supported:
+  #    - 'description': a textual description or name for the ds
+  #    - 'fieldtypes': a list of types (one per field)
+
+#* ds1 | ds2 | ds3 == ds.hstack([ds1,ds2,ds3])#????
+    #assert hstack([ds('x','y'),ds('z')])==ds
+    #hstack([ds('z','y'),ds('x')])==ds
+    assert have_raised2(hstack,[ds('x'),ds('x')])
+    assert have_raised2(hstack,[ds('y','x'),ds('x')])
+    assert not have_raised2(hstack,[ds('x'),ds('y')])
+    
+#        i=0
+#        for example in hstack([ds('x'),ds('y'),ds('z')]):
+#            example==ds[i]
+#            i+=1 
+#        del i,example
+#* ds1 & ds2 & ds3 == ds.vstack([ds1,ds2,ds3])#????
+
 def test_fields_fct(ds):
     #@todo, fill correctly
     assert len(ds.fields())==3
@@ -455,6 +538,7 @@
     test_iterate_over_examples(array, ds)
     test_overrides(ds)
     test_getitem(array, ds)
+    test_subset(array, ds)
     test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z'))
     test_fields_fct(ds)
 
@@ -515,6 +599,15 @@
 
         del a, ds
 
+    def test_RenamedFieldsDataSet(self):
+        a = numpy.random.rand(10,4)
+        ds = ArrayDataSet(a,Example(['x1','y1','z1','w1'],[slice(3),3,[0,2],0]))
+        ds = RenamedFieldsDataSet(ds,['x1','y1','z1'],['x','y','z'])
+
+        test_all(a,ds)
+
+        del a, ds
+
     def test_MinibatchDataSet(self):
         raise NotImplementedError()
     def test_HStackedDataSet(self):
@@ -570,14 +663,17 @@
         res = dsc[:]
 
 if __name__=='__main__':
-    if len(sys.argv)==2:
-        if sys.argv[1]=="--debug":
+    tests = []
+    debug=False
+    if len(sys.argv)==1:
+        unittest.main()
+    else:
+        assert sys.argv[1]=="--debug"
+        for arg in sys.argv[2:]:
+            tests.append(arg)
+        if tests:
+            unittest.TestSuite(map(T_DataSet, tests)).debug()
+        else:
             module = __import__("_test_dataset")
             tests = unittest.TestLoader().loadTestsFromModule(module)
             tests.debug()
-        print "bad argument: only --debug is accepted"
-    elif len(sys.argv)==1:
-        unittest.main()
-    else:
-        print "bad argument: only --debug is accepted"
-