changeset 296:f5d33f9c0b9c

ApplyFunctionDataSet passing
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:50:29 -0400
parents f7924e13e426
children d08b71d186c8
files dataset.py
diffstat 1 files changed, 15 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Fri Jun 06 16:15:47 2008 -0400
+++ b/dataset.py	Fri Jun 06 17:50:29 2008 -0400
@@ -1204,7 +1204,7 @@
     A L{DataSet} that contains as fields the results of applying a
     given function example-wise or minibatch-wise to all the fields of
     an input dataset.  The output of the function should be an iterable
-    (e.g. a list or a Example) over the resulting values.
+    (e.g. a list or a LookupList) over the resulting values.
     
     The function take as input the fields of the dataset, not the examples.
 
@@ -1221,7 +1221,9 @@
 
     If the values_{h,v}stack functions are not provided, then
     the input_dataset.values{H,V}Stack functions are used by default.
+
     """
+
     def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
                  values_hstack=None,values_vstack=None,
                  description=None,fieldtypes=None):
@@ -1253,27 +1255,27 @@
         return self.output_names
 
     def minibatches_nowrap(self, fieldnames, *args, **kwargs):
-        for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
+        all_input_fieldNames = self.input_dataset.fieldNames()
+        mbnw = self.input_dataset.minibatches_nowrap
 
-            #function_inputs = self.input_iterator.next()
+        for input_fields in mbnw(all_input_fieldNames, *args, **kwargs):
             if self.minibatch_mode:
-                function_outputs = self.function(*input_fields)
+                all_output_fields = self.function(*input_fields)
             else:
-                input_examples = zip(*input_fields)
+                input_examples = zip(*input_fields) #makes so that [i] means example i
                 output_examples = [self.function(*input_example)
                                     for input_example in input_examples]
-                function_outputs = [self.valuesVStack(name,values)
-                                    for name,values in zip(self.output_names,
-                                                           zip(*output_examples))]
-            all_outputs = Example(self.output_names, function_outputs)
-            print 'input_fields', input_fields
-            print 'all_outputs', all_outputs
+                all_output_fields = zip(*output_examples)
+
+            all_outputs = Example(self.output_names, all_output_fields)
+            #print 'input_fields', input_fields
+            #print 'all_outputs', all_outputs
             if fieldnames==self.output_names:
                 rval = all_outputs
             else:
                 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
-            print 'rval', rval
-            print '--------'
+            #print 'rval', rval
+            #print '--------'
             yield rval
 
     def untested__iter__(self): # only implemented for increased efficiency
@@ -1295,7 +1297,6 @@
                 return Example(self.output_dataset.output_names,function_outputs)
         return ApplyFunctionSingleExampleIterator(self)
     
-
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary L{DataSet} into one for supervised learning tasks