diff dataset.py @ 73:69f97aad3faf

Coded untested ApplyFunctionDataSet and CachedDataSet
author bengioy@bengiomac.local
date Sat, 03 May 2008 14:29:56 -0400
parents 2b6656b2ef52
children b4159cbdc06b
line wrap: on
line diff
--- a/dataset.py	Fri May 02 18:36:47 2008 -0400
+++ b/dataset.py	Sat May 03 14:29:56 2008 -0400
@@ -227,7 +227,9 @@
             self.n_batches_done+=1
             if upper >= self.L and self.n_batches:
                 self.next_row -= self.L
-            return minibatch
+            return DataSetFields(MinibatchDataSet(minibatch,self.dataset.valuesVStack,
+                                                  self.dataset.valuesHStack),
+                                 minibatch.keys())
 
 
     minibatches_fieldnames = None
@@ -392,7 +394,8 @@
             return MinibatchDataSet(
                 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                             for fieldname,field_values
-                                            in zip(self.fieldNames(),fields_values)]))
+                                            in zip(self.fieldNames(),fields_values)]),
+                self.valuesVStack,self.valuesHStack)
         # else check for a fieldname
         if self.hasFields(i):
             return self.minibatches(fieldnames=[i],minibatch_size=len(self),n_batches=1,offset=0).next()[0]
@@ -609,9 +612,12 @@
         return self.length
 
     def __getitem__(self,i):
-        if type(i) in (int,slice,list):
-            return DataSetFields(MinibatchDataSet(
-                Example(self._fields.keys(),[field[i] for field in self._fields])),self._fields)
+        if type(i) is int:
+            return Example(self._fields.keys(),[field[i] for field in self._fields])
+        if type(i) in (slice,list):
+            return MinibatchDataSet(Example(self._fields.keys(),
+                                            [field[i] for field in self._fields]),
+                                    self.valuesVStack,self.valuesHStack)
         if self.hasFields(i):
             return self._fields[i]
         assert i in self.__dict__ # else it means we are trying to access a non-existing property
@@ -643,7 +649,7 @@
                                     [field[self.next_example:upper]
                                      for field in self.ds._fields])
                 self.next_example+=minibatch_size
-                return DataSetFields(MinibatchDataSet(minibatch),*fieldnames)
+                return minibatch
 
         return Iterator(self)
 
@@ -716,11 +722,7 @@
                 return self
             def next(self):
                 # concatenate all the fields of the minibatches
-                minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators])
-                # and return a DataSetFields whose dataset is the transpose (=examples()) of this minibatch
-                return DataSetFields(MinibatchDataSet(minibatch,self.hsds.valuesVStack,
-                                                      self.hsds.valuesHStack),
-                                     fieldnames if fieldnames else hsds.fieldNames())
+                return reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators])
                                      
         assert self.hasfields(fieldnames)
         # find out which underlying datasets are necessary to service the required fields
@@ -849,11 +851,10 @@
                     while self.n_left_in_mb>0:
                         self.move_to_next_dataset()
                         extra_mb.append(self.next_iterator.next())
-                    examples = Example(names,
+                    mb = Example(fieldnames,
                                        [dataset.valuesVStack(name,
                                                              [mb[name]]+[b[name] for b in extra_mb])
                                             for name in fieldnames])
-                    mb = DataSetFields(MinibatchDataSet(examples),fieldnames)
                     
                 self.next_row+=minibatch_size
                 self.next_dataset_row+=minibatch_size
@@ -926,7 +927,8 @@
                            [self.data[i,self.fields_columns[f]] for f in fieldnames])
         if type(i) in (slice,list):
             return MinibatchDataSet(Example(fieldnames,
-                                            [self.data[i,self.fields_columns[f]] for f in fieldnames]))
+                                            [self.data[i,self.fields_columns[f]] for f in fieldnames]),
+                                    self.valuesVStack,self.valuesHStack)
         # else check for a fieldname
         if self.hasFields(i):
             return Example([i],[self.data[self.fields_columns[i],:]])
@@ -967,20 +969,140 @@
  Optionally, for a finite-length dataset, all the values can be computed
  (and cached) upon construction of the CachedDataSet, rather than at the
  first access.
+
+  TODO: add disk-buffering capability, so that when the cache becomes too
+  big for memory, we cache things on disk, trying to keep in memory only
+  the record most likely to be accessed next.
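+
+  A minimal usage sketch (untested; base_dataset stands for any
+  finite-length DataSet, and do_something is a hypothetical callback):
+
+      cached = CachedDataSet(base_dataset, cache_all_upon_construction=True)
+      for example in cached:
+          do_something(example)  # values now come from the in-memory cache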
   """
+  def __init__(self,source_dataset,cache_all_upon_construction=False):
+      self.source_dataset=source_dataset
+      self.cache_all_upon_construction=cache_all_upon_construction
+      if cache_all_upon_construction:
+          self.cached_examples = zip(*source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next())
+      else:
+          self.cached_examples = []
 
+      self.fieldNames = source_dataset.fieldNames
+      self.hasFields = source_dataset.hasFields
+      self.valuesHStack = source_dataset.valuesHStack
+      self.valuesVStack = source_dataset.valuesVStack
+      
+  def __len__(self):
+      return len(self.source_dataset)
+
+  def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+      class CacheIterator(object):
+          def __init__(self,dataset):
+              self.dataset=dataset
+              self.current=offset
+          def __iter__(self): return self
+          def next(self):
+              upper = self.current+minibatch_size
+              cache_len = len(self.dataset.cached_examples)
+              if upper>cache_len: # whole minibatch is not already in cache
+                  # cache everything from current length to upper
+                  for example in self.dataset.source_dataset[cache_len:upper]:
+                      self.dataset.cached_examples.append(example)
+              # cached_examples holds one tuple per example; transpose back
+              # into per-field columns for the returned Example
+              all_fields_minibatch = Example(self.dataset.fieldNames(),
+                                             zip(*self.dataset.cached_examples[self.current:upper]))
+              self.current=upper
+              if self.dataset.fieldNames()==fieldnames:
+                  return all_fields_minibatch
+              return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
+      return CacheIterator(self)
+
+                      
 class ApplyFunctionDataSet(DataSet):
   """
   A dataset that contains as fields the results of applying a given function
   example-wise or minibatch-wise to all the fields of an input dataset.
   The output of the function should be an iterable (e.g. a list or a LookupList)
-  over the resulting values. In minibatch mode, the function is expected
-  to work on minibatches (takes a minibatch in input and returns a minibatch
-  in output).
+  over the resulting values.
+
+  In minibatch mode, the function is expected to work on minibatches: it
+  takes a minibatch as input and returns a minibatch as output. More
+  precisely, each element of the input or output list should be iterable
+  and indexable over the individual example values (typically these
+  elements will be numpy arrays), and all elements of the input and
+  output lists should have the same length, namely the minibatch length.
 
   The function is applied each time an example or a minibatch is accessed.
   To avoid re-doing computation, wrap this dataset inside a CachedDataSet.
+
+  If the values_{h,v}stack functions are not provided, then
+  the input_dataset.values{H,V}Stack functions are used by default.
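+
+  A minimal usage sketch (untested; raw_data stands for any DataSet with a
+  single numeric field stored as numpy arrays, and the lambda receives the
+  whole minibatch, indexed here by field position):
+
+      doubled = ApplyFunctionDataSet(raw_data,
+                                     lambda mb: [2.0*mb[0]],
+                                     ['2x'], minibatch_mode=True)
+      cached_doubled = CachedDataSet(doubled)  # avoid recomputing on re-access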
   """
+  def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
+               values_hstack=None,values_vstack=None,
+               description=None,fieldtypes=None):
+      """
+      The constructor takes an input dataset with as many fields as the function
+      expects as inputs. The resulting dataset has as many fields as the function
+      produces as outputs, one per entry in the list of output names.
+
+      Note that the expected semantics of the function differs in minibatch mode
+      (it takes minibatches of inputs and produces minibatches of outputs, as
+      documented in the class comment).
+      """
+      self.input_dataset=input_dataset
+      self.function=function
+      self.output_names=output_names
+      self.minibatch_mode=minibatch_mode
+      DataSet.__init__(self,description,fieldtypes)
+      self.valuesHStack = values_hstack if values_hstack else input_dataset.valuesHStack
+      self.valuesVStack = values_vstack if values_vstack else input_dataset.valuesVStack
+
+  def __len__(self):
+      return len(self.input_dataset)
+
+  def fieldNames(self):
+      return self.output_names
+
+  def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+      class ApplyFunctionIterator(object):
+          def __init__(self,output_dataset):
+              self.input_dataset=output_dataset.input_dataset
+              self.output_dataset=output_dataset
+              self.input_iterator=self.input_dataset.minibatches(minibatch_size=minibatch_size,
+                                                                 n_batches=n_batches,offset=offset).__iter__()
+
+          def __iter__(self): return self
+
+          def next(self):
+              function_inputs = self.input_iterator.next()
+              all_output_names = self.output_dataset.output_names
+              if self.output_dataset.minibatch_mode:
+                  function_outputs = self.output_dataset.function(function_inputs)
+              else:
+                  input_examples = zip(*function_inputs)
+                  output_examples = [self.output_dataset.function(input_example)
+                                     for input_example in input_examples]
+                  function_outputs = [self.output_dataset.valuesVStack(name,values)
+                                      for name,values in zip(all_output_names,
+                                                             zip(*output_examples))]
+              all_outputs = Example(all_output_names,function_outputs)
+              if fieldnames==all_output_names:
+                  return all_outputs
+              return Example(fieldnames,[all_outputs[name] for name in fieldnames])
+
+      return ApplyFunctionIterator(self)
+
+  def __iter__(self): # only implemented for increased efficiency
+      class ApplyFunctionSingleExampleIterator(object):
+          def __init__(self,output_dataset):
+              self.current=0
+              self.output_dataset=output_dataset
+              self.input_iterator=output_dataset.input_dataset.__iter__()
+          def __iter__(self): return self
+          def next(self):
+              function_inputs = self.input_iterator.next()
+              if self.output_dataset.minibatch_mode:
+                  function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)]
+              else:
+                  function_outputs = self.output_dataset.function(function_inputs)
+              return Example(self.output_dataset.output_names,function_outputs)
+      return ApplyFunctionSingleExampleIterator(self)
   
 
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):