changeset 93:a62c79ec7c8a

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 05 May 2008 18:14:44 -0400
parents eee739fefdff (diff) c4726e19b8ec (current diff)
children 05cfe011ca20
files
diffstat 2 files changed, 136 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Mon May 05 18:14:32 2008 -0400
+++ b/dataset.py	Mon May 05 18:14:44 2008 -0400
@@ -27,8 +27,10 @@
     feasible or not recommanded on streams).
 
     To iterate over examples, there are several possibilities:
-     - for example in dataset([field1, field2,field3, ...]):
-     - for val1,val2,val3 in dataset([field1, field2,field3]):
+     - for example in dataset:
+     - for val1,val2,... in dataset:
+     - for example in dataset(field1, field2,field3, ...):
+     - for val1,val2,val3 in dataset(field1, field2,field3):
      - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
      - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
      - for example in dataset::
@@ -610,7 +612,7 @@
     
 class MinibatchDataSet(DataSet):
     """
-    Turn a LookupList of same-length fields into an example-iterable dataset.
+    Turn a LookupList of same-length (iterable) fields into an example-iterable dataset.
     Each element of the lookup-list should be an iterable and sliceable, all of the same length.
     """
     def __init__(self,fields_lookuplist,values_vstack=DataSet().valuesVStack,
@@ -636,9 +638,7 @@
             return DataSetFields(MinibatchDataSet(
                 Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames())
         if type(i) is int:
-            return DataSetFields(MinibatchDataSet(
-                Example(self._fields.keys(),[[field[i]] for field in self._fields])),self.fieldNames())
-
+            return Example(self._fields.keys(),[field[i] for field in self._fields])
         if self.hasFields(i):
             return self._fields[i]
         assert i in self.__dict__ # else it means we are trying to access a non-existing property
@@ -954,7 +954,13 @@
                 if self.hasFields(key[i]):
                     key[i]=self.fields_columns[key[i]]
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[key,self.fields_columns[f]] for f in fieldnames]),
+                                            #we must separate differently for list as numpy
+                                            # don't support self.data[[i1,...],[i2,...]]
+                                            # when their is more then two i1 and i2
+                                            [self.data[key,:][:,self.fields_columns[f]]
+                                             if isinstance(self.fields_columns[f],list) else
+                                             self.data[key,self.fields_columns[f]]
+                                             for f in fieldnames]),
                                     self.valuesVStack,self.valuesHStack)
 
         # else check for a fieldname
--- a/test_dataset.py	Mon May 05 18:14:32 2008 -0400
+++ b/test_dataset.py	Mon May 05 18:14:44 2008 -0400
@@ -40,26 +40,109 @@
 
     print "test_ArrayDataSet"
     a = numpy.random.rand(10,4)
-    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})
+    ds = ArrayDataSet(a,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
+    ds = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     assert len(ds)==10
     #assert ds==a? should this work?
-    for i in range(len(ds)):
-        assert ds[i]['x'].all()==a[i][:2].all()
-        assert ds[i]['y']==a[i][3]
-        assert ds[i]['z'].all()==a[i][0:3:2].all()
+
+#not in doc!!!
+    for example in range(len(ds)):
+        assert ds[example]['x'].all()==a[example][:2].all()
+        assert ds[example]['y']==a[example][3]
+        assert ds[example]['z'].all()==a[example][0:3:2].all()
+
+#     - for example in dataset:
+    i=0
+    for example in ds:
+        assert example['x'].all()==a[i][:2].all()
+        assert example['y']==a[i][3]
+        assert example['z'].all()==a[i][0:3:2].all()
+        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        i+=1
+    assert i==len(ds)
+#     - for val1,val2,... in dataset:
     i=0
-    for x in ds('x','y'):
-        assert numpy.append(x['x'],x['y']).all()==a[i].all()
+    for x,y,z in ds:
+        assert x.all()==a[i][:2].all()
+        assert y==a[i][3]
+        assert z.all()==a[i][0:3:2].all()
+        assert numpy.append(example['x'],example['y']).all()==a[i].all()
         i+=1
+    assert i==len(ds)
+#     - for example in dataset(field1, field2,field3, ...):
+    i=0
+    for example in ds('x','y','z'):
+        assert example['x'].all()==a[i][:2].all()
+        assert example['y']==a[i][3]
+        assert example['z'].all()==a[i][0:3:2].all()
+        assert numpy.append(example['x'],example['y']).all()==a[i].all()
+        i+=1
+    assert i==len(ds)
+
+#     - for val1,val2,val3 in dataset(field1, field2,field3):
+
+#     - for example in dataset(field1, field2,field3, ...):
 
+    def test_ds_iterator(iterator1,iterator2,iterator3):
+        i=0
+        for x,y in iterator1:
+            assert x.all()==a[i][:2].all()
+            assert y==a[i][3]
+            assert numpy.append(x,y).all()==a[i].all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for y,z in iterator2:
+            assert y==a[i][3]
+            assert z.all()==a[i][0:3:2].all()
+            i+=1
+        assert i==len(ds)
+        i=0
+        for x,y,z in iterator3:
+            assert x.all()==a[i][:2].all()
+            assert y==a[i][3]
+            assert z.all()==a[i][0:3:2].all()
+            assert numpy.append(x,y).all()==a[i].all()
+            i+=1
+        assert i==len(ds)
+
+#not in doc!!!     - for val1,val2,val3 in dataset(field1, field2,field3):
+    test_ds_iterator(ds('x','y'),ds('y','z'),ds('x','y','z'))
+
+#     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N):
+    for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
+        assert len(minibatch)==2
+        assert len(minibatch[0])==3
+        assert len(minibatch[1])==3
+        assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
     i=0
-    for x,y in ds('x','y'):
-        assert numpy.append(x,y).all()==a[i].all()
-        i+=1
-    for minibatch in ds.minibatches(['x','z'], minibatch_size=3):
-        assert minibatch[0][:,0:3:2].all()==minibatch[1].all()
+    for minibatch in ds.minibatches(['x','y'], minibatch_size=3):
+        assert len(minibatch)==2
+        assert len(minibatch[0])==3
+        assert len(minibatch[1])==3
+        for id in range(3):
+            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
+            i+=1
+
+#     - for mini1,mini2,mini3 in dataset.minibatches([field1, field2, field3], minibatch_size=N):
     for x,z in ds.minibatches(['x','z'], minibatch_size=3):
+        assert len(x)==3
+        assert len(z)==3
         assert x[:,0:3:2].all()==z.all()
+    i=0
+    for x,y in ds.minibatches(['x','y'], minibatch_size=3):
+        assert len(x)==3
+        assert len(y)==3
+        for id in range(3):
+            assert numpy.append(minibatch[0][id],minibatch[1][id]).all()==a[i].all()
+            i+=1
+#     - for x,y,z in dataset: # fail x,y,z order not fixed as it is a dict.
+#    for x,y,z in ds:
+#        assert x.all()==a[i][:2].all()
+#        assert y==a[i][3]
+#        assert z.all()==a[i][0:3:2].all()
+#        assert numpy.append(x,y).all()==a[i].all()
+#        i+=1
 
 #    for minibatch in ds.minibatches(['z','y'], minibatch_size=3):
 #        print minibatch
@@ -86,43 +169,41 @@
 
     assert ds == ds.fields().examples()
 
+    def test_ds(orig,ds,index):
+        i=0
+        assert len(ds)==len(index)
+        for x,z,y in ds('x','z','y'):
+            assert orig[index[i]]['x'].all()==a[index[i]][:3].all()
+            assert orig[index[i]]['x'].all()==x.all()
+            assert orig[index[i]]['y']==a[index[i]][3]
+            assert orig[index[i]]['y']==y
+            assert orig[index[i]]['z'].all()==a[index[i]][0:3:2].all()
+            assert orig[index[i]]['z'].all()==z.all()
+            i+=1
+        del i
+        ds[0]
+        if len(ds)>2:
+            ds[:1]
+            ds[1:1]
+            ds[1:1:1]
+        if len(ds)>5:
+            ds[[1,2,3]]
+        for x in ds:
+            pass
 
     #ds[:n] returns a dataset with the n first examples.
-    assert len(ds[:3])==3
-    i=0
-    for x,z,y in ds[:3]('x','z','y'):
-        assert ds[i]['x'].all()==a[i][:3].all()
-        assert ds[i]['x'].all()==x.all()
-        assert ds[i]['y']==a[i][3]
-        assert ds[i]['y']==y
-        assert ds[i]['z'].all()==a[i][0:3:2].all()
-        assert ds[i]['z'].all()==z.all()
-        i+=1
-    i=0
-    for x,z in ds[:3]('x','z'):
-        assert ds[i]['z'].all()==a[i][0:3:2].all()
-        i+=1
+    ds2=ds[:3]
+    test_ds(ds,ds2,index=[0,1,2])
 
     #ds[i1:i2:s]# returns a ds with the examples i1,i1+s,...i2-s.
-    ds[1:7:2][1]
-    assert len(ds[1:7:2])==3 # should be number example 1,3 and 5
-    i=0
-    index=[1,3,5]
-    for z,y,x in ds[1:7:2]('z','y','x'):
-        assert ds[index[i]]['x'].all()==a[index[i]][:3].all()
-        assert ds[index[i]]['x'].all()==x.all()
-        assert ds[index[i]]['y']==a[index[i]][3]
-        assert ds[index[i]]['y']==y
-        assert ds[index[i]]['z'].all()==a[index[i]][0:3:2].all()
-        assert ds[index[i]]['z'].all()==z.all()
-        i+=1
+    ds2=ds[1:7:2]
+    test_ds(ds,ds2,[1,3,5])
 
     #ds[[i1,i2,...in]]# returns a ds with examples i1,i2,...in.
-    i=0
-    for x in ds[[1,2]]:
-        assert numpy.append(x['x'],x['y']).all()==a[i].all()
-        i+=1
+    ds2=ds[[4,7,2,8]]
+    test_ds(ds,ds2,[4,7,2,8])
     #ds[i1,i2,...]# should we accept????
+
     #ds[fieldname]# an iterable over the values of the field fieldname across
       #the ds (the iterable is obtained by default by calling valuesVStack
       #over the values for individual examples).
@@ -184,5 +265,3 @@
 test_LookupList()
 test_ArrayDataSet()
 
-
-