changeset 346:9de4274ad5ba

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 17 Jun 2008 14:33:15 -0400
parents 2259f6fa4959 (diff) 4efb503fd0da (current diff)
children 3aa9e5a5802a
files _test_dataset.py
diffstat 2 files changed, 87 insertions(+), 64 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Tue Jun 17 14:32:54 2008 -0400
+++ b/_test_dataset.py	Tue Jun 17 14:33:15 2008 -0400
@@ -134,12 +134,13 @@
 #     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
     i=0
     mi=0
-    m=ds.minibatches(['x','z'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','z'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for minibatch in m:
-        assert isinstance(minibatch,DataSetFields)
+        assert isinstance(minibatch,LookupList)
         assert len(minibatch)==2
-        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+        test_minibatch_size(minibatch,size,len(ds),2,mi)
         if type(ds)==ArrayDataSet:
             assert (minibatch[0][:,::2]==minibatch[1]).all()
         else:
@@ -147,92 +148,103 @@
                 (minibatch[0][j][::2]==minibatch[1][j]).all()
         mi+=1
         i+=len(minibatch[0])
-    assert i==len(ds)
-    assert mi==4
-    del minibatch,i,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del minibatch,i,m,mi,size
 
     i=0
     mi=0
-    m=ds.minibatches(['x','y'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for minibatch in m:
+        assert isinstance(minibatch,LookupList)
         assert len(minibatch)==2
-        test_minibatch_size(minibatch,m.minibatch_size,len(ds),2,mi)
+        test_minibatch_size(minibatch,size,len(ds),2,mi)
         mi+=1
         for id in range(len(minibatch[0])):
             assert (numpy.append(minibatch[0][id],minibatch[1][id])==array[i]).all()
             i+=1
-    assert i==len(ds)
-    assert mi==4
-    del minibatch,i,id,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del minibatch,i,id,m,mi,size
 
 #     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     i=0
     mi=0
-    m=ds.minibatches(['x','z'], minibatch_size=3)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','z'], minibatch_size=size)
+    assert hasattr(m,'__iter__')
     for x,z in m:
-        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
-        test_minibatch_field_size(z,m.minibatch_size,len(ds),mi)
+        test_minibatch_field_size(x,size,len(ds),mi)
+        test_minibatch_field_size(z,size,len(ds),mi)
         for id in range(len(x)):
             assert (x[id][::2]==z[id]).all()
             i+=1
         mi+=1
-    assert i==len(ds)
-    assert mi==4
-    del x,z,i,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del x,z,i,m,mi,size
+
     i=0
     mi=0
+    size=3
     m=ds.minibatches(['x','y'], minibatch_size=3)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        test_minibatch_field_size(x,m.minibatch_size,len(ds),mi)
-        test_minibatch_field_size(y,m.minibatch_size,len(ds),mi)
+        assert len(x)==size
+        assert len(y)==size
+        test_minibatch_field_size(x,size,len(ds),mi)
+        test_minibatch_field_size(y,size,len(ds),mi)
         mi+=1
         for id in range(len(x)):
             assert (numpy.append(x[id],y[id])==array[i]).all()
             i+=1
-    assert i==len(ds)
-    assert mi==4
-    del x,y,i,id,m,mi
+    assert i==(len(ds)/size)*size
+    assert mi==(len(ds)/size)
+    del x,y,i,id,m,mi,size
 
 #not in doc
     i=0
-    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=1,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id,m
+    assert i==size
+    del x,y,i,id,m,size
 
     i=0
-    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=2,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[i+4]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id,m
+    assert i==2*size
+    del x,y,i,id,m,size
 
     i=0
-    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=3,offset=4)
-    assert isinstance(m,DataSet.MinibatchWrapAroundIterator)
+    size=3
+    m=ds.minibatches(['x','y'],n_batches=20,minibatch_size=size,offset=4)
+    assert hasattr(m,'__iter__')
     for x,y in m:
-        assert len(x)==m.minibatch_size
-        assert len(y)==m.minibatch_size
-        for id in range(m.minibatch_size):
+        assert len(x)==size
+        assert len(y)==size
+        for id in range(size):
             assert (numpy.append(x[id],y[id])==array[(i+4)%array.shape[0]]).all()
             i+=1
-    assert i==m.n_batches*m.minibatch_size
-    del x,y,i,id
+    assert i==2*size # should not wrap
+    del x,y,i,id,size
 
-    assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
+    assert have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array)+1,offset=0)
     assert not have_raised2(ds.minibatches,['x','y'],n_batches=1,minibatch_size=len(array),offset=0)
 
 def test_ds_iterator(array,iterator1,iterator2,iterator3):
@@ -262,14 +274,17 @@
 def test_getitem(array,ds):
     def test_ds(orig,ds,index):
         i=0
-        assert len(ds)==len(index)
-        for x,z,y in ds('x','z','y'):
-            assert (orig[index[i]]['x']==array[index[i]][:3]).all()
-            assert (orig[index[i]]['x']==x).all()
-            assert orig[index[i]]['y']==array[index[i]][3]
-            assert (orig[index[i]]['y']==y).all() # why does it crash sometimes?
-            assert (orig[index[i]]['z']==array[index[i]][0:3:2]).all()
-            assert (orig[index[i]]['z']==z).all()
+        assert isinstance(ds,LookupList)
+        assert len(ds)==3
+        assert len(ds[0])==len(index)
+#        for x,z,y in ds('x','z','y'):
+        for idx in index:
+            assert (orig[idx]['x']==array[idx][:3]).all()
+            assert (orig[idx]['x']==ds['x'][i]).all()
+            assert orig[idx]['y']==array[idx][3]
+            assert (orig[idx]['y']==ds['y'][i]).all() # why does it crash sometimes?
+            assert (orig[idx]['z']==array[idx][0:3:2]).all()
+            assert (orig[idx]['z']==ds['z'][i]).all()
             i+=1
         del i
         ds[0]
@@ -282,19 +297,22 @@
         for x in ds:
             pass
 
-#ds[:n] returns a dataset with the n first examples.
+#ds[:n] returns a LookupList with the n first examples.
     ds2=ds[:3]
-    assert isinstance(ds2,LookupList)
     test_ds(ds,ds2,index=[0,1,2])
     del ds2
 
-#ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
-    ds2=ds.subset[1:7:2]
-    assert isinstance(ds2,DataSet)
+#ds[i:j] returns a LookupList with examples i,i+1,...,j-1.
+    ds2=ds[1:3]
+    test_ds(ds,ds2,index=[1,2])
+    del ds2
+
+#ds[i1:i2:s] returns a LookupList with the examples i1,i1+s,...i2-s.
+    ds2=ds[1:7:2]
     test_ds(ds,ds2,[1,3,5])
     del ds2
 
-#ds[i]
+#ds[i] returns the (i+1)-th example of the dataset.
     ds2=ds[5]
     assert isinstance(ds2,Example)
     assert have_raised("var['ds']["+str(len(ds))+"]",ds=ds)  # index not defined
@@ -302,8 +320,8 @@
     del ds2
 
 #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
-    ds2=ds.subset[[4,7,2,8]]
-    assert isinstance(ds2,DataSet)
+    ds2=ds[[4,7,2,8]]
+#    assert isinstance(ds2,DataSet)
     test_ds(ds,ds2,[4,7,2,8])
     del ds2
 
--- a/dataset.py	Tue Jun 17 14:32:54 2008 -0400
+++ b/dataset.py	Tue Jun 17 14:33:15 2008 -0400
@@ -721,7 +721,12 @@
         assert self.hasFields(*fieldnames)
         return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
     def __getitem__(self,i):
-        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+#        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+        complete_example = self.src[i]
+        return Example(self.new_fieldnames,
+                             [complete_example[field]
+                              for field in self.src_fieldnames])
+
 
 
 class DataSetFields(Example):