diff dataset.py @ 29:46c5c90019c2

Changed apply_function so that it propagates methods of the source.
author bengioy@grenat.iro.umontreal.ca
date Fri, 11 Apr 2008 15:46:18 -0400
parents 541a273bc89f
children 438440ba0627
line wrap: on
line diff
--- a/dataset.py	Fri Apr 11 13:08:51 2008 -0400
+++ b/dataset.py	Fri Apr 11 15:46:18 2008 -0400
@@ -150,6 +150,7 @@
         of the iterators).
         """
         raise AbstractFunction()
+
         
     def merge_fields(self,*specifications):
         """
@@ -182,7 +183,7 @@
     
     def rename(self,rename_dict):
         """
-        Return a new dataset that renames fields, using a dictionnary that maps old field
+        Changes a dataset into one that renames fields, using a dictionary that maps old field
         names to new field names. The only fields visible by the returned dataset are those
         whose names are keys of the rename_dict.
         """
@@ -194,9 +195,9 @@
         SelfRenamingDataSet.__init__(self,self,rename_dict)
         return self
         
-    def applyFunction(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
+    def apply_function(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
         """
-        Return a dataset that contains as fields the results of applying
+        Changes a dataset into one that contains as fields the results of applying
         the given function (example-wise) to the specified input_fields. The
         function should return a sequence whose elements will be stored in
         fields whose names are given in the output_fields list. If copy_inputs
@@ -209,7 +210,13 @@
         are cached (to avoid recomputation if the same examples are again
         requested).
         """
-        return ApplyFunctionDataSet(function, input_fields, output_fields, copy_inputs, accept_minibatches, cache)
+        self_class = self.__class__
+        class SelfApplyFunctionDataSet(ApplyFunctionDataSet,self_class):
+            pass
+        self.__class__ = SelfApplyFunctionDataSet
+        # set the required additional fields
+        ApplyFunctionDataSet.__init__(self,self,function, input_fields, output_fields, copy_inputs, accept_minibatches, cache)
+        return self
 
 
 class FiniteLengthDataSet(DataSet):
@@ -223,7 +230,21 @@
     def __len__(self):
         """len(dataset) returns the number of examples in the dataset."""
         raise AbstractFunction()
-    
+
+    def __call__(self,fieldname_or_fieldnames):
+        """
+        Extract one or more fields. This may be an expensive operation when the
+        dataset is large. It is not the recommended way to access individual values
+        (use the iterators instead). If the argument is a string fieldname, then the result
+        is a sequence (iterable object) of values for that field, for the whole dataset. If the
+        argument is a list of field names, then the result is a 'batch', i.e., an Example with keys
+        corresponding to the given field names and values being iterable objects over the
+        individual example values.
+        """
+        if isinstance(fieldname_or_fieldnames,str): # single field name: unwrap its values from a one-field batch
+            minibatch = self.minibatches([fieldname_or_fieldnames],len(self)).next()
+            return minibatch[fieldname_or_fieldnames]
+        return self.minibatches(fieldname_or_fieldnames,len(self)).next() # one minibatch spanning the whole dataset
                  
 class SliceableDataSet(DataSet):
     """
@@ -473,20 +494,6 @@
         if n_batches is None: n_batches = len(self) / minibatch_size
         return ArrayDataSet.Iterator(self, fieldnames, minibatch_size, n_batches)
 
-    def __getattr__(self,fieldname):
-        """
-        Return a numpy array with the content associated with the given field name.
-        If this is a one-example dataset, then a row, i.e., numpy array (of one less dimension
-        than the dataset itself) is returned.
-        """
-        if len(self.data)==1:
-            return self.data[0,self.fields[fieldname]]
-        return self.data[:,self.fields[fieldname]]
-
-    def __call__(self,*fieldnames):
-        """Return a sub-dataset containing only the given fieldnames as fields."""
-        return ArrayDataSet(self.data,fields=LookupList(fieldnames,[self.fields[fieldname] for fieldname in fieldnames]))
-
     def fieldNames(self):
         """Return the list of field names that are supported by getattr and hasField."""
         return self.fields.keys()
@@ -560,7 +567,7 @@
                 i=j
         return slice(start,stop,step)
     
-class ApplyFunctionDataSet(DataSet):
+class ApplyFunctionDataSet(FiniteWidthDataSet):
     """
     A dataset that contains as fields the results of applying
     a given function (example-wise) to specified input_fields of a source
@@ -603,6 +610,11 @@
             # in the case where src is FiniteDataSet. -YB
             self.cached_examples = []
 
+    def fieldNames(self): # names of the fields this dataset reports
+        if not self.copy_inputs: # inputs are not copied: output fields only
+            return self.output_fields
+        return self.output_fields + self.src.fieldNames() # output fields first, then the source's own
+    
     def minibatches(self,
                     fieldnames = DataSet.minibatches_fieldnames,
                     minibatch_size = DataSet.minibatches_minibatch_size,