changeset 37:73c4212ba5b3

Factored the minibatch-writing code into an iterator class inside DataSet
author bengioy@esprit.iro.umontreal.ca
date Thu, 24 Apr 2008 12:03:06 -0400
parents 438440ba0627
children d637ad8f7352
files dataset.py
diffstat 1 files changed, 168 insertions(+), 49 deletions(-) [+]
--- a/dataset.py	Tue Apr 22 18:03:11 2008 -0400
+++ b/dataset.py	Thu Apr 24 12:03:06 2008 -0400
@@ -98,7 +98,7 @@
       * __len__ if it is not a stream
       * __getitem__ may not be feasible with some streams
       * fieldNames
-      * minibatches
+      * minibatches_nowrap (called by DataSet.minibatches())
       * valuesHStack
       * valuesVStack
     For efficiency of implementation, a sub-class might also want to redefine
@@ -142,13 +142,66 @@
         """
         return DataSet.MinibatchToSingleExampleIterator(self.minibatches(None, minibatch_size = 1))
 
+
+    class MinibatchWrapAroundIterator(object):
+        """
+        An iterator for minibatches that handles the case where we need to wrap around the
+        dataset because n_batches*minibatch_size > len(dataset). It is constructed from
+        a dataset that provides a minibatch iterator that does not need to handle that problem.
+        This class is a utility for dataset subclass writers, so that each subclass does not
+        have to reimplement this wrap-around logic.
+        """
+        def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+            self.dataset=dataset
+            self.fieldnames=fieldnames
+            self.minibatch_size=minibatch_size
+            self.n_batches=n_batches
+            self.n_batches_done=0
+            self.next_row=offset
+            self.L=len(dataset)
+            assert offset+minibatch_size<=self.L
+            ds_nbatches = (self.L-offset)/minibatch_size
+            if n_batches is not None:
+                ds_nbatches = min(n_batches,ds_nbatches)
+            self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)
+
+        def __iter__(self):
+            return self
+
+        def next_index(self):
+            return self.next_row
+
+        def next(self):
+            if self.n_batches and self.n_batches_done==self.n_batches:
+                raise StopIteration
+            upper = self.next_row+self.minibatch_size
+            if upper <= self.L:
+                minibatch = self.iterator.next()
+            else:
+                if not self.n_batches:
+                    raise StopIteration
+                # we must concatenate (vstack) the bottom and top parts of our minibatch:
+                # the first part is the tail of the dataset, the second part is its head
+                first_part = self.dataset.minibatches_nowrap(self.fieldnames,self.L-self.next_row,1,self.next_row).next()
+                second_part = self.dataset.minibatches_nowrap(self.fieldnames,upper-self.L,1,0).next()
+                minibatch = Example(self.fieldnames,
+                                    [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
+                                     for name in self.fieldnames])
+            self.next_row=upper
+            self.n_batches_done+=1
+            if upper >= self.L:
+                self.next_row -= self.L
+            return minibatch
+
+
     minibatches_fieldnames = None
     minibatches_minibatch_size = 1
     minibatches_n_batches = None
     def minibatches(self,
-            fieldnames = minibatches_fieldnames,
-            minibatch_size = minibatches_minibatch_size,
-            n_batches = minibatches_n_batches):
+                    fieldnames = minibatches_fieldnames,
+                    minibatch_size = minibatches_minibatch_size,
+                    n_batches = minibatches_n_batches,
+                    offset = 0):
         """
         Return an iterator that supports three forms of syntax:
 
@@ -193,13 +246,29 @@
         the derived class can choose a default.  If (-1), then the returned
         iterator should support looping indefinitely.
 
+        - offset (integer, default 0)
+        The iterator will start at example 'offset' in the dataset, rather than at example 0.
+
         Note: A list-like container is something like a tuple, list, numpy.ndarray or
         any other object that supports integer indexing and slicing.
 
         """
+        return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)
+
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        """
+        This is the minibatch iterator constructor that sub-classes must define.
+        It does not need to worry about wrapping around multiple times across the dataset,
+        as this is handled by MinibatchWrapAroundIterator when DataSet.minibatches() is called.
+        The next() method of the returned iterator does not even need to worry about
+        the termination condition (as StopIteration will be raised by DataSet.minibatches
+        before an improper call to minibatches_nowrap's next() is made).
+        That next() method can assert that its next row will always be within [0,len(dataset)).
+        The iterator returned by minibatches_nowrap does not need to implement
+        a next_index() method either, as this will be provided by MinibatchWrapAroundIterator.
+        """
         raise AbstractFunction()
 
-
     def __len__(self):
         """
         len(dataset) returns the number of examples in the dataset.
@@ -358,9 +427,10 @@
     """
     def __init__(self,dataset,*fieldnames):
         self.dataset=dataset
-        assert dataset.hasField(*fieldnames)
+        assert dataset.hasFields(*fieldnames)
         LookupList.__init__(self,dataset.fieldNames(),
-                            dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),minibatch_size=len(dataset)).next()
+                            dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),
+                                                minibatch_size=len(dataset)).next())
     def examples(self):
         return self.dataset
     
@@ -378,6 +448,7 @@
         """
         return (self.examples() | other.examples()).fields()
 
+
 class MinibatchDataSet(DataSet):
     """
     Turn a LookupList of same-length fields into an example-iterable dataset.
@@ -407,52 +478,32 @@
     def fieldNames(self):
         return self.fields.keys()
 
-    def hasField(self,*fieldnames):
+    def hasFields(self,*fieldnames):
         for fieldname in fieldnames:
             if fieldname not in self.fields:
                 return False
         return True
 
-    def minibatches(self,
-                    fieldnames = minibatches_fieldnames,
-                    minibatch_size = minibatches_minibatch_size,
-                    n_batches = minibatches_n_batches):
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class Iterator(object):
             def __init__(self,ds):
                 self.ds=ds
-                self.next_example=0
-                self.n_batches_done=0
+                self.next_example=offset
                 assert minibatch_size > 0
-                if minibatch_size > ds.length
+                if offset+minibatch_size > ds.length:
                     raise NotImplementedError()
             def __iter__(self):
                 return self
-            def next_index(self):
-                return self.next_example
             def next(self):
                upper = self.next_example+minibatch_size
-                if upper<=self.ds.length:
-                    minibatch = Example(self.ds.fields.keys(),
-                                        [field[next_example:upper]
-                                         for field in self.ds.fields])
-                else: # we must concatenate (vstack) the bottom and top parts of our minibatch
-                    minibatch = Example(self.ds.fields.keys(),
-                                        [self.ds.valuesVStack(name,[value[next_example:],
-                                                                     value[0:upper-self.ds.length]])
-                                         for name,value in self.ds.fields.items()])
+                assert upper<=self.ds.length
+                minibatch = Example(self.ds.fields.keys(),
+                                    [field[self.next_example:upper]
+                                     for field in self.ds.fields])
                 self.next_example+=minibatch_size
-                self.n_batches_done+=1
-                if n_batches:
-                    if self.n_batches_done==n_batches:
-                        raise StopIteration
-                    if self.next_example>=self.ds.length:
-                        self.next_example-=self.ds.length
-                else:
-                    if self.next_example>=self.ds.length:
-                        raise StopIteration
                 return DataSetFields(MinibatchDataSet(minibatch),fieldnames)
 
         return Iterator(self)
 
     def valuesVStack(self,fieldname,fieldvalues):
         return self.values_vstack(fieldname,fieldvalues)
@@ -504,7 +555,7 @@
             del self.fieldname2dataset[fieldname]
             self.fieldname2dataset[rename_field(fieldname,self.datasets[i],i)]=i
             
-    def hasField(self,*fieldnames):
+    def hasFields(self,*fieldnames):
         for fieldname in fieldnames:
             if not fieldname in self.fieldname2dataset:
                 return False
@@ -513,10 +564,11 @@
     def fieldNames(self):
         return self.fieldname2dataset.keys()
             
-    def minibatches(self,
-            fieldnames = minibatches_fieldnames,
-            minibatch_size = minibatches_minibatch_size,
-            n_batches = minibatches_n_batches):
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
 
         class Iterator(object):
             def __init__(self,hsds,iterators):
@@ -524,8 +576,6 @@
                 self.iterators=iterators
             def __iter__(self):
                 return self
-            def next_index(self):
-                return self.iterators[0].next_index()
             def next(self):
                 # concatenate all the fields of the minibatches
                 minibatch = reduce(LookupList.__add__,[iterator.next() for iterator in self.iterators])
@@ -545,11 +595,11 @@
                 datasets.add(dataset)
                 fields_in_dataset[dataset].append(fieldname)
             datasets=list(datasets)
-            iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches)
+            iterators=[dataset.minibatches(fields_in_dataset[dataset],minibatch_size,n_batches,offset)
                        for dataset in datasets]
         else:
             datasets=self.datasets
-            iterators=[dataset.minibatches(None,minibatch_size,n_batches) for dataset in datasets]
+            iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets]
         return Iterator(self,iterators)
 
 
@@ -577,18 +627,87 @@
         self.datasets=datasets
         self.length=0
         self.index2dataset={}
-        # we use this map from row index to dataset index for constant-time random access of examples,
-        # to avoid having to search for the appropriate dataset each time and slice is asked for
+        assert len(datasets)>0
+        fieldnames = datasets[-1].fieldNames()
+        # We use this map from row index to dataset index for constant-time random access of examples,
+        # to avoid having to search for the appropriate dataset each time an example or slice is asked for.
         for k,dataset in enumerate(datasets[0:-1]):
             L=len(dataset)
             assert L<DataSet.infinity
             for i in xrange(L):
                 self.index2dataset[self.length+i]=k
             self.length+=L
+            assert dataset.fieldNames()==fieldnames
         self.last_start=self.length
         self.length+=len(datasets[-1])
-        
-            
+        # If length is very large, we should use a more memory-efficient mechanism
+        # that does not store all indices
+        if self.length>1000000:
+            # 1 million entries would require about 60 meg for the index2dataset map
+            # TODO
+            print "A more efficient mechanism for index2dataset should be implemented"
+
+    def __len__(self):
+        return self.length
+    
+    def fieldNames(self):
+        return self.datasets[0].fieldNames()
+
+    def hasFields(self,*fieldnames):
+        return self.datasets[0].hasFields(*fieldnames)
+
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        class Iterator(object):
+            def __init__(self,vsds):
+                self.vsds=vsds
+                self.next_row=offset
+                self.next_dataset_index=0
+                self.next_dataset_row=0
+                self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
+                  self.next_iterator(vsds.datasets[0],offset,n_batches)
+
+            def next_iterator(self,dataset,starting_offset,batches_left):
+                L=len(dataset)
+                ds_nbatches = (L-starting_offset)/minibatch_size
+                if batches_left is not None:
+                    ds_nbatches = min(batches_left,ds_nbatches)
+                if minibatch_size>L:
+                    ds_minibatch_size=L
+                    n_left_in_mb=minibatch_size-L
+                else:
+                    ds_minibatch_size=minibatch_size
+                    n_left_in_mb=0
+                return dataset.minibatches(fieldnames,ds_minibatch_size,ds_nbatches,starting_offset), \
+                       L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb
+
+            def move_to_next_dataset(self):
+                self.next_dataset_index +=1
+                if self.next_dataset_index==len(self.vsds.datasets):
+                    self.next_dataset_index = 0
+                self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
+                   self.next_iterator(self.vsds.datasets[self.next_dataset_index],0,n_batches)
+
+            def __iter__(self):
+                return self
+
+            def next(self):
+                dataset=self.vsds.datasets[self.next_dataset_index]
+                mb = self.current_iterator.next()
+                if self.n_left_in_mb:
+                    # the minibatch overflows the current dataset: fetch the
+                    # missing rows from the following dataset(s) and vstack them
+                    names=self.vsds.fieldNames()
+                    extra_mb = []
+                    while self.n_left_in_mb>0:
+                        self.move_to_next_dataset()
+                        extra_mb.append(self.current_iterator.next())
+                    mb = Example(names,
+                                 [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb])
+                                  for name in names])
+                self.next_row+=minibatch_size
+                return mb
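+
+        # Hypothetical illustration of the stitching above: vstacking datasets of
+        # lengths 5 and 3, a minibatch of size 4 starting at row 3 of the first
+        # dataset takes its rows [3,5) plus rows [0,2) of the second dataset.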
+        return Iterator(self)
+
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """
     Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the