changeset 297:d08b71d186c8

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:52:00 -0400
parents f5d33f9c0b9c (diff) 7380376816e5 (current diff)
children 5987415496df
files _test_dataset.py
diffstat 2 files changed, 17 insertions(+), 16 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Fri Jun 06 17:11:25 2008 -0400
+++ b/_test_dataset.py	Fri Jun 06 17:52:00 2008 -0400
@@ -431,7 +431,7 @@
 
     def test_FieldsSubsetDataSet(self):
         a = numpy.random.rand(10,4)
-        ds = ArrayDataSet(a,LookupList(['x','y','z','w'],[slice(3),3,[0,2],0]))
+        ds = ArrayDataSet(a,Example(['x','y','z','w'],[slice(3),3,[0,2],0]))
         ds = FieldsSubsetDataSet(ds,['x','y','z'])
 
         test_all(a,ds)
@@ -449,7 +449,7 @@
                 class MultiLengthDataSetIterator(object):
                     def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                         if fieldnames is None: fieldnames = dataset.fieldNames()
-                        self.minibatch = LookupList(fieldnames,range(len(fieldnames)))
+                        self.minibatch = Example(fieldnames,range(len(fieldnames)))
                         self.dataset, self.minibatch_size, self.current = dataset, minibatch_size, offset
                     def __iter__(self):
                             return self
--- a/dataset.py	Fri Jun 06 17:11:25 2008 -0400
+++ b/dataset.py	Fri Jun 06 17:52:00 2008 -0400
@@ -1204,7 +1204,7 @@
     A L{DataSet} that contains as fields the results of applying a
     given function example-wise or minibatch-wise to all the fields of
     an input dataset.  The output of the function should be an iterable
-    (e.g. a list or a Example) over the resulting values.
+    (e.g. a list or a LookupList) over the resulting values.
     
     The function takes as input the fields of the dataset, not the examples.
 
@@ -1221,7 +1221,9 @@
 
     If the values_{h,v}stack functions are not provided, then
     the input_dataset.values{H,V}Stack functions are used by default.
+
     """
+
     def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
                  values_hstack=None,values_vstack=None,
                  description=None,fieldtypes=None):
@@ -1253,27 +1255,27 @@
         return self.output_names
 
     def minibatches_nowrap(self, fieldnames, *args, **kwargs):
-        for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs):
+        all_input_fieldNames = self.input_dataset.fieldNames()
+        mbnw = self.input_dataset.minibatches_nowrap
 
-            #function_inputs = self.input_iterator.next()
+        for input_fields in mbnw(all_input_fieldNames, *args, **kwargs):
             if self.minibatch_mode:
-                function_outputs = self.function(*input_fields)
+                all_output_fields = self.function(*input_fields)
             else:
-                input_examples = zip(*input_fields)
+                input_examples = zip(*input_fields) #makes so that [i] means example i
                 output_examples = [self.function(*input_example)
                                     for input_example in input_examples]
-                function_outputs = [self.valuesVStack(name,values)
-                                    for name,values in zip(self.output_names,
-                                                           zip(*output_examples))]
-            all_outputs = Example(self.output_names, function_outputs)
-            print 'input_fields', input_fields
-            print 'all_outputs', all_outputs
+                all_output_fields = zip(*output_examples)
+
+            all_outputs = Example(self.output_names, all_output_fields)
+            #print 'input_fields', input_fields
+            #print 'all_outputs', all_outputs
             if fieldnames==self.output_names:
                 rval = all_outputs
             else:
                 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
-            print 'rval', rval
-            print '--------'
+            #print 'rval', rval
+            #print '--------'
             yield rval
 
     def untested__iter__(self): # only implemented for increased efficiency
@@ -1295,7 +1297,6 @@
                 return Example(self.output_dataset.output_names,function_outputs)
         return ApplyFunctionSingleExampleIterator(self)
     
-
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary L{DataSet} into one for supervised learning tasks