changeset 203:80731832c62b

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 15 May 2008 15:21:00 -0400
parents cb6b945acf5a (current diff) b9950ae5e54b (diff)
children c5a7105fa40b 6f55e301c687
files dataset.py
diffstat 2 files changed, 32 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Thu May 15 12:55:21 2008 -0400
+++ b/dataset.py	Thu May 15 15:21:00 2008 -0400
@@ -1111,6 +1111,8 @@
   given function example-wise or minibatch-wise to all the fields of
   an input dataset.  The output of the function should be an iterable
   (e.g. a list or a LookupList) over the resulting values.
+  
+  The function takes as input the fields of the dataset, not the examples.
 
   In minibatch mode, the function is expected to work on minibatches
   (takes a minibatch in input and returns a minibatch in output). More
@@ -1170,7 +1172,7 @@
                   function_outputs = self.output_dataset.function(*function_inputs)
               else:
                   input_examples = zip(*function_inputs)
-                  output_examples = [self.output_dataset.function(input_example)
+                  output_examples = [self.output_dataset.function(*input_example)
                                      for input_example in input_examples]
                   function_outputs = [self.output_dataset.valuesVStack(name,values)
                                       for name,values in zip(all_output_names,
@@ -1190,11 +1192,14 @@
               self.input_iterator=output_dataset.input_dataset.__iter__()
           def __iter__(self): return self
           def next(self):
-              function_inputs = self.input_iterator.next()
               if self.output_dataset.minibatch_mode:
-                  function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)]
+                  function_inputs = [[input] for input in self.input_iterator.next()]
+                  outputs = self.output_dataset.function(*function_inputs)
+                  assert all([hasattr(output,'__iter__') for output in outputs])
+                  function_outputs = [output[0] for output in outputs]
               else:
-                  function_outputs = self.output_dataset.function(function_inputs)
+                  function_inputs = self.input_iterator.next()
+                  function_outputs = self.output_dataset.function(*function_inputs)
               return Example(self.output_dataset.output_names,function_outputs)
       return ApplyFunctionSingleExampleIterator(self)
   
--- a/test_dataset.py	Thu May 15 12:55:21 2008 -0400
+++ b/test_dataset.py	Thu May 15 15:21:00 2008 -0400
@@ -394,6 +394,13 @@
     assert len(ds('y').fields()) == 1
 
     del field
+def test_all(array,ds):
+    assert len(ds)==10
+
+    test_iterate_over_examples(array, ds)
+    test_getitem(array, ds)
+    test_ds_iterator(array,ds('x','y'),ds('y','z'),ds('x','y','z'))
+    test_fields_fct(ds)
 
 def test_ArrayDataSet():
     #don't test stream
@@ -406,13 +413,9 @@
     a2 = numpy.random.rand(10,4)
     ds = ArrayDataSet(a2,{'x':slice(3),'y':3,'z':[0,2]})###???tuple not tested
     ds = ArrayDataSet(a2,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
-    assert len(ds)==10
     #assert ds==a? should this work?
 
-    test_iterate_over_examples(a2, ds)
-    test_getitem(a2, ds)
-    test_ds_iterator(a2,ds('x','y'),ds('y','z'),ds('x','y','z'))
-    test_fields_fct(ds)
+    test_all(a2,ds)
 
     del a2, ds
 
@@ -442,18 +445,9 @@
     ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
     ds2 = CachedDataSet(ds1)
     ds3 = CachedDataSet(ds1,cache_all_upon_construction=True)
-    assert len(ds2)==10
-    assert len(ds3)==10
 
-    test_iterate_over_examples(a, ds2)
-    test_getitem(a, ds2)
-    test_ds_iterator(a,ds2('x','y'),ds2('y','z'),ds2('x','y','z'))
-    test_fields_fct(ds2)
-
-    test_iterate_over_examples(a, ds3)
-    test_getitem(a, ds3)
-    test_ds_iterator(a,ds3('x','y'),ds3('y','z'),ds3('x','y','z'))
-    test_fields_fct(ds3)
+    test_all(a,ds2)
+    test_all(a,ds3)
 
     del a,ds1,ds2,ds3
 
@@ -464,7 +458,18 @@
 
 def test_ApplyFunctionDataSet():
     print "test_ApplyFunctionDataSet"
-    raise NotImplementedError()
+    a = numpy.random.rand(10,4)
+    a2 = a+1
+    ds1 = ArrayDataSet(a,LookupList(['x','y','z'],[slice(3),3,[0,2]]))###???tuple not tested
+
+    ds2 = ApplyFunctionDataSet(ds1,lambda x,y,z: (x+1,y+1,z+1), ['x','y','z'],minibatch_mode=False)
+    ds3 = ApplyFunctionDataSet(ds1,lambda x,y,z: (numpy.array(x)+1,numpy.array(y)+1,numpy.array(z)+1), ['x','y','z'],minibatch_mode=True)
+
+    test_all(a2,ds2)
+    test_all(a2,ds3)
+
+    del a,ds1,ds2,ds3
+
 def test_FieldsSubsetDataSet():
     print "test_FieldsSubsetDataSet"
     raise NotImplementedError()
@@ -485,4 +490,5 @@
     test_LookupList()
     test_ArrayDataSet()
     test_CachedDataSet()
+    test_ApplyFunctionDataSet()
 #test pmat.py