diff dataset.py @ 26:672fe4b23032

Fixed dataset errors so that _test_dataset.py works again.
author bengioy@grenat.iro.umontreal.ca
date Fri, 11 Apr 2008 11:14:54 -0400
parents 526e192b0699
children 541a273bc89f
line wrap: on
line diff
--- a/dataset.py	Wed Apr 09 18:27:13 2008 -0400
+++ b/dataset.py	Fri Apr 11 11:14:54 2008 -0400
@@ -1,6 +1,7 @@
 
 from lookup_list import LookupList
 Example = LookupList
+import copy
 
 class AbstractFunction (Exception): """Derived class must override this function"""
         
@@ -142,7 +143,7 @@
         """
         raise AbstractFunction()
 
-    def hasFields(*fieldnames):
+    def hasFields(self,*fieldnames):
         """
         Return true if the given field name (or field names, if multiple arguments are
         given) is recognized by the DataSet (i.e. can be used as a field name in one
@@ -150,7 +151,7 @@
         """
         raise AbstractFunction()
         
-    def merge_fields(*specifications):
+    def merge_fields(self,*specifications):
         """
         Return a new dataset that maps old fields (of self) to new fields (of the returned 
         dataset). The minimal syntax that should be supported is the following:
@@ -162,7 +163,7 @@
         """
         raise AbstractFunction()
 
-    def merge_field_values(*field_value_pairs)
+    def merge_field_values(self,*field_value_pairs):
         """
         Return the value that corresponds to merging the values of several fields,
         given as arguments (field_name, field_value) pairs with self.hasField(field_name).
@@ -172,22 +173,28 @@
         fieldnames,fieldvalues = zip(*field_value_pairs)
         raise ValueError("Unable to merge values of these fields:"+repr(fieldnames))
 
-    def examples2minibatch(examples):
+    def examples2minibatch(self,examples):
         """
         Combine a list of Examples into a minibatch. A minibatch is an Example whose fields
         are iterable over the examples of the minibatch.
         """
         raise AbstractFunction()
     
-    def rename(rename_dict):
+    def rename(self,rename_dict):
         """
         Return a new dataset that renames fields, using a dictionnary that maps old field
         names to new field names. The only fields visible by the returned dataset are those
         whose names are keys of the rename_dict.
         """
-        return RenamingDataSet(self,rename_dict)
-
-    def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
+        self_class = self.__class__
+        class SelfRenamingDataSet(RenamingDataSet,self_class):
+            pass
+        self.__class__ = SelfRenamingDataSet
+        # set the rename_dict and src fields
+        SelfRenamingDataSet.__init__(self,self,rename_dict)
+        return self
+        
+    def applyFunction(self,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
         """
         Return a dataset that contains as fields the results of applying
         the given function (example-wise) to the specified input_fields. The
@@ -204,25 +211,6 @@
         """
         return ApplyFunctionDataSet(function, input_fields, output_fields, copy_inputs, accept_minibatches, cache)
 
-class RenamingDataSet(DataSet):
-    """A DataSet that wraps another one, and makes it look like the field names
-    are different
-
-    Renaming is done by a dictionary that maps new names to the old ones used in
-    self.src.
-    """
-    def __init__(self, src, rename_dct):
-        DataSet.__init__(self)
-        self.src = src
-        self.rename_dct = copy.copy(rename_dct)
-
-    def minibatches(self,
-            fieldnames = DataSet.minibatches_fieldnames,
-            minibatch_size = DataSet.minibatches_minibatch_size,
-            n_batches = DataSet.minibatches_n_batches):
-        dct = self.rename_dct
-        new_fieldnames = [dct.get(f, f) for f in fieldnames]
-        return self.src.minibatches(new_fieldnames, minibatches_size, n_batches)
 
 class FiniteLengthDataSet(DataSet):
     """
@@ -278,10 +266,11 @@
     def __init__(self):
         DataSet.__init__(self)
 
-    def hasFields(*fieldnames):
+    def hasFields(self,*fields):
         has_fields=True
-        for fieldname in fieldnames:
-            if fieldname not in self.fields.keys():
+        fieldnames = self.fieldNames()
+        for name in fields:
+            if name not in fieldnames:
                 has_fields=False
         return has_fields
                 
@@ -291,6 +280,30 @@
         raise AbstractFunction()
 
 
+class RenamingDataSet(FiniteWidthDataSet):
+    """A DataSet that wraps another one, and makes it look like the field names
+    are different
+
+    Renaming is done by a dictionary that maps new names to the old ones used in
+    self.src.
+    """
+    def __init__(self, src, rename_dct):
+        DataSet.__init__(self)
+        self.src = src
+        self.rename_dct = copy.copy(rename_dct)
+
+    def fieldNames(self):
+        return self.rename_dct.keys()
+    
+    def minibatches(self,
+            fieldnames = DataSet.minibatches_fieldnames,
+            minibatch_size = DataSet.minibatches_minibatch_size,
+            n_batches = DataSet.minibatches_n_batches):
+        dct = self.rename_dct
+        new_fieldnames = [dct.get(f, f) for f in fieldnames]
+        return self.src.minibatches(new_fieldnames, minibatch_size, n_batches)
+
+
 # we may want ArrayDataSet defined in another python file
 
 import numpy
@@ -548,19 +561,6 @@
             c+=slice_width
         return result
 
-    def rename(*new_field_specifications):
-        """
-        Return a new dataset that maps old fields (of self) to new fields (of the returned 
-        dataset). The minimal syntax that should be supported is the following:
-           new_field_specifications = [new_field_spec1, new_field_spec2, ...]
-           new_field_spec = ([old_field1, old_field2, ...], new_field)
-        In general both old_field and new_field should be strings, but some datasets may also
-        support additional indexing schemes within each field (e.g. column slice
-        of a matrix-like field).
-        """
-        # if all old fields of each spec are 
-        raise NotImplementedError()
-
 class ApplyFunctionDataSet(DataSet):
     """
     A dataset that contains as fields the results of applying
@@ -599,7 +599,7 @@
             else:
                 # compute a list with one tuple per example, with the function outputs
                 self.cached_examples = [ function(input) for input in src.zip(input_fields) ]
-        else if cache:
+        elif cache:
             # maybe a fixed-size array kind of structure would be more efficient than a list
             # in the case where src is FiniteDataSet. -YB
             self.cached_examples = []