diff dataset.py @ 23:526e192b0699

Working on ApplyFunctionDataSet, added constraint that DataSet iterators must have a next_index() method.
author bengioy@esprit.iro.umontreal.ca
date Wed, 09 Apr 2008 18:27:13 -0400
parents b6b36f65664f
children 672fe4b23032
--- a/dataset.py	Mon Apr 07 20:44:37 2008 -0400
+++ b/dataset.py	Wed Apr 09 18:27:13 2008 -0400
@@ -17,7 +17,14 @@
     - for val1,val2,val3 in dataset.zip([field1, field2,field3])
     - for minibatch in dataset.minibatches([field1, field2, ...],minibatch_size=N)
     - for example in dataset
-    Each of these is documented below.
+    Each of these is documented below. All of these iterators are expected
+    to provide, in addition to the usual 'next()' method, a 'next_index()' method
+    which returns a non-negative integer pointing to the position of the next
+    example that will be returned by 'next()' (or of the first example in the
+    next minibatch returned). This is important because these iterators
+    can wrap around the dataset in order to do multiple passes through it,
+    possibly in irregular ways if the minibatch size is not a divisor of the
+    dataset length.
 
     Note: Fields are not mutually exclusive, i.e. two fields can overlap in their actual content.
 
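For illustration only (not part of this changeset), a minimal sketch of the next()/next_index() contract described above, for a hypothetical iterator over a dataset of length n, in Python 2 as in the rest of the file:

    class MinimalMinibatchIterator(object):
        # hypothetical example; real iterators are defined per DataSet subclass
        def __init__(self, n, minibatch_size):
            self.n = n
            self.minibatch_size = minibatch_size
            self.current = -minibatch_size   # so that the first next_index() is 0
        def next_index(self):
            # index of the first example of the minibatch that next() will return,
            # wrapping around when the minibatch size does not divide the length
            return (self.current + self.minibatch_size) % self.n
        def next(self):
            self.current = self.next_index()
            return [(self.current + i) % self.n for i in xrange(self.minibatch_size)]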
@@ -40,7 +47,7 @@
     def __init__(self):
         pass
     
-    class Iter(LookupList):
+    class Iterator(LookupList):
         def __init__(self, ll):
             LookupList.__init__(self, ll.keys(), ll.values())
             self.ll = ll
@@ -50,6 +57,8 @@
             self.ll.next()
             self._values = [v[0] for v in self.ll._values]
             return self
+        def next_index(self):
+            return self.ll.next_index()
 
     def __iter__(self):
         """Supports the syntax "for i in dataset: ..."
@@ -61,7 +70,7 @@
         Example returned by this iterator), but the derived class is free
         to accept any type of identifier, and add extra functionality to the iterator.
         """
-        return DataSet.Iter(self.minibatches(None, minibatch_size = 1))
+        return DataSet.Iterator(self.minibatches(None, minibatch_size = 1))
 
     def zip(self, *fieldnames):
         """
@@ -81,7 +90,7 @@
         The derived class may accept fieldname arguments of any type.
 
         """
-        return DataSet.Iter(self.minibatches(fieldnames, minibatch_size = 1))
+        return DataSet.Iterator(self.minibatches(fieldnames, minibatch_size = 1))
 
     minibatches_fieldnames = None
     minibatches_minibatch_size = 1
@@ -141,15 +150,7 @@
         """
         raise AbstractFunction()
         
-    def rename(*new_field_specifications):
-        #Yoshua- 
-        # Do you mean for this to be a virtual method?
-        # Wouldn't this functionality be easier to provide via a
-        # RenamingDataSet, such as the one I've written below?
-        # -JB
-        # You are right. Whichever implementation, however, we need a generic way to
-        # 'concatenate' fields, to handle the ([old_field1, old_field2, ...], new_field) semantics.
-        # -YB
+    def merge_fields(*specifications):
         """
         Return a new dataset that maps old fields (of self) to new fields (of the returned 
         dataset). The minimal syntax that should be supported is the following:
@@ -161,6 +162,30 @@
         """
         raise AbstractFunction()
 
+    def merge_field_values(*field_value_pairs):
+        """
+        Return the value that corresponds to merging the values of several fields,
+        given as arguments (field_name, field_value) pairs with self.hasField(field_name).
+        This may be used by implementations of merge_fields.
+        Raise a ValueError if the operation is not possible.
+        """
+        fieldnames,fieldvalues = zip(*field_value_pairs)
+        raise ValueError("Unable to merge values of these fields:"+repr(fieldnames))
+
+    def examples2minibatch(examples):
+        """
+        Combine a list of Examples into a minibatch. A minibatch is an Example whose fields
+        are iterable over the examples of the minibatch.
+        """
+        raise AbstractFunction()
+    
+    def rename(self,rename_dict):
+        """
+        Return a new dataset that renames fields, using a dictionary that maps old field
+        names to new field names. The only fields visible by the returned dataset are those
+        whose names are keys of the rename_dict.
+        """
+        return RenamingDataSet(self,rename_dict)
 
     def applyFunction(function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
         """
@@ -278,7 +303,7 @@
     if hasattr(dataset, 'as_array_dataset'):
         return dataset.as_array_dataset()
 
-    raise NotImplementedError()
+    raise NotImplementedError
 
     # Make ONE big minibatch with all the examples, to separate the fields.
     n_examples = len(dataset)
@@ -343,6 +368,13 @@
             rval[:a0,:] = a
             rval[a0:,:] = b
             return rval
+
+        def next_index(self):
+            n_rows = self.dataset.data.shape[0]
+            next_i = self.current+self.minibatch_size
+            if next_i >= n_rows:
+                next_i -= n_rows
+            return next_i
         
         def next(self):
 
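A quick way to see the wrap-around pattern computed by next_index() above (illustration only, assuming the iterator starts at index 0):

    # with n_rows=10 and minibatch_size=3, the successive starting indices are
    #   [ (3*k) % 10 for k in range(8) ]  ==  [0, 3, 6, 9, 2, 5, 8, 1]
    # so the minibatch starting at index 9 wraps around and also includes examples 0 and 1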
@@ -352,21 +384,19 @@
                 raise StopIteration
 
             #determine the first and last elements of the slice we'll return
-            rows = self.dataset.data.shape[0]
-            self.current += self.minibatch_size
-            if self.current >= rows:
-                self.current -= rows
+            n_rows = self.dataset.data.shape[0]
+            self.current = self.next_index()
             upper = self.current + self.minibatch_size
 
             data = self.dataset.data
 
-            if upper <= rows:
+            if upper <= n_rows:
                #this is the easy case, we only need one slice
                 dataview = data[self.current:upper]
             else:
                 # the minibatch wraps around the end of the dataset
                 dataview = data[self.current:]
-                upper -= rows
+                upper -= n_rows
                 assert upper > 0
                 dataview = self.matcat(dataview, data[:upper])
 
@@ -518,6 +548,19 @@
             c+=slice_width
         return result
 
+    def rename(*new_field_specifications):
+        """
+        Return a new dataset that maps old fields (of self) to new fields (of the returned 
+        dataset). The minimal syntax that should be supported is the following:
+           new_field_specifications = [new_field_spec1, new_field_spec2, ...]
+           new_field_spec = ([old_field1, old_field2, ...], new_field)
+        In general both old_field and new_field should be strings, but some datasets may also
+        support additional indexing schemes within each field (e.g. column slice
+        of a matrix-like field).
+        """
+        # if all old fields of each spec are 
+        raise NotImplementedError()
+
 class ApplyFunctionDataSet(DataSet):
     """
     A dataset that contains as fields the results of applying
@@ -532,31 +575,35 @@
     once the output fields for some examples have been computed, then
     are cached (to avoid recomputation if the same examples are again requested).
     """
-    def __init__(src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True):
+    def __init__(self,src,function, input_fields, output_fields, copy_inputs=True, accept_minibatches=True, cache=True, compute_now=False):
         DataSet.__init__(self)
         self.src=src
         self.function=function
+        assert src.hasFields(input_fields)
         self.input_fields=input_fields
         self.output_fields=output_fields
+        assert not (copy_inputs and compute_now and not hasattr(src,'fieldNames'))
         self.copy_inputs=copy_inputs
         self.accept_minibatches=accept_minibatches
-        src_fieldnames = src.fieldNames()
-        if copy_inputs:
-            for src_field in src_fieldnames:
-                assert src_field not in output_fields
-            self.fieldnames=src_fieldnames+output_fields
-        else:
-            self.fieldnames=output_fields
-        for input_field in input_fields:
-            assert input_field in src_fieldnames
         self.cache=cache
-        if cache:
+        self.compute_now=compute_now
+        if compute_now:
+            assert hasattr(src,'__len__') and len(src)>=0
+            fieldnames = output_fields
+            if copy_inputs: fieldnames = src.fieldNames() + output_fields
+            if accept_minibatches:
+                # make a single minibatch with all the inputs
+                inputs = src.minibatches(input_fields,len(src)).next()
+                # and apply the function to it, and transpose into a list of examples (field values, actually)
+                self.cached_examples = zip(*Example(output_fields,function(*inputs)))
+            else:
+                # compute a list with one tuple per example, with the function outputs
+                self.cached_examples = [ function(*input) for input in src.zip(*input_fields) ]
+        elif cache:
             # maybe a fixed-size array kind of structure would be more efficient than a list
             # in the case where src is FiniteDataSet. -YB
-            self.cached_examples = [] 
+            self.cached_examples = []
 
-    def fieldNames(self): return self.fieldnames
-    
     def minibatches(self,
                     fieldnames = DataSet.minibatches_fieldnames,
                     minibatch_size = DataSet.minibatches_minibatch_size,
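The compute_now path above relies on zip(*...) to transpose between the two minibatch layouts: one sequence of values per field, versus one tuple of field values per example. A self-contained illustration (with made-up values):

    field_values = [[0.1, 0.2, 0.3],   # field 'a', one value per example
                    [ 10,  20,  30]]   # field 'b', one value per example
    examples = zip(*field_values)      # [(0.1, 10), (0.2, 20), (0.3, 30)]
    fields_again = zip(*examples)      # [(0.1, 0.2, 0.3), (10, 20, 30)]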
@@ -566,30 +613,69 @@
 
             def __init__(self,dataset):
                 if fieldnames is None:
-                    LookupList.__init__(self, [],[])
-                else:
-                    LookupList.__init__(self, fieldnames, [0]*len(fieldnames))
+                    assert hasattr(dataset,"fieldNames")
+                    fieldnames = dataset.fieldNames()
+                self.example_index=0
+                LookupList.__init__(self, fieldnames, [0]*len(fieldnames))
                 self.dataset=dataset
-                self.src_iterator=self.src.minibatches(list(set.union(set(fieldnames),set(self.dataset.input_fields))),
+                self.src_iterator=dataset.src.minibatches(list(set.union(set(fieldnames),set(dataset.input_fields))),
                                                        minibatch_size,n_batches)
+                self.fieldnames_not_in_input = []
+                if dataset.copy_inputs:
+                    self.fieldnames_not_in_input = filter(lambda x: not x in dataset.input_fields, fieldnames)
                                                        
             def __iter__(self):
                 return self
 
+            def next_index(self):
+                return self.src_iterator.next_index()
+            
             def next(self):
+                example_index = self.src_iterator.next_index()
                 src_examples = self.src_iterator.next()
                 if self.dataset.copy_inputs:
-                    function_inputs = src_examples
+                    function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields]
                 else:
-                    function_inputs = [src_examples[field_name] for field_name in self.dataset.input_fields]
-                outputs = Example(self.dataset.output_fields,self.dataset.function(*function_inputs))
-                if self.dataset.copy_inputs:
-                    return src_examples + outputs
+                    function_inputs = src_examples
+                if self.dataset.cached_examples:
+                    cache_len=len(self.dataset.cached_examples)
+                    if example_index<cache_len+minibatch_size:
+                        outputs_list = self.dataset.cached_examples[example_index:example_index+minibatch_size]
+                        # convert the minibatch list of examples 
+                        # into a list of fields each of which iterate over the minibatch
+                        outputs = zip(*outputs_list)
+                    else:
+                        outputs = self.dataset.function(*function_inputs)
+                        if self.dataset.cache:
+                            # convert the list of fields, each of which can iterate over the minibatch
+                            # into a list of examples in the minibatch (each of which is a list of field values)
+                            outputs_list = zip(*outputs)
+                            # copy the outputs_list into the cache
+                            for i in xrange(cache_len,example_index):
+                                self.dataset.cached_examples.append(None)
+                            self.dataset.cached_examples += outputs_list
                 else:
-                    return outputs
+                    outputs = self.dataset.function(*function_inputs)
+                
+                return Example(self.fieldnames_not_in_input+self.dataset.output_fields,
+                               [src_examples[field_name] for field_name in self.fieldnames_not_in_input]+outputs)
+                               
 
         for fieldname in fieldnames:
             assert fieldname in self.output_fields or self.src.hasFields(fieldname)
         return Iterator(self)
 
     
+def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
+    """
+    Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the
+    user to define a set of fields as the 'input' field and a set of fields
+    as the 'target' field. Optionally, a single weight_field can also be defined.
+    """
+    args = ((input_fields,'input'),(target_fields,'target'))
+    if weight_field: args+=(([weight_field],'weight'),)
+    return src_dataset.rename(*args)
+
+        
+
+
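A usage sketch for the new helper, with hypothetical field names (and assuming a rename() that accepts the ([old_fields], new_field) specifications built here):

    # raw_data is any DataSet with fields 'pixels', 'label' and 'importance'
    train = supervised_learning_dataset(raw_data,
                                        input_fields=['pixels'],
                                        target_fields=['label'],
                                        weight_field='importance')
    for minibatch in train.minibatches(['input','target','weight'], minibatch_size=32):
        pass  # e.g. feed minibatch['input'] and minibatch['target'] to a learner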