diff dataset.py @ 38:d637ad8f7352

Finished first untested version of VStackedDataset
author bengioy@esprit.iro.umontreal.ca
date Thu, 24 Apr 2008 13:25:39 -0400
parents 73c4212ba5b3
children c682c6e9bf93
line wrap: on
line diff
--- a/dataset.py	Thu Apr 24 12:03:06 2008 -0400
+++ b/dataset.py	Thu Apr 24 13:25:39 2008 -0400
@@ -149,7 +149,8 @@
         dataset because n_batches*minibatch_size > len(dataset). It is constructed from
         a dataset that provides a minibatch iterator that does not need to handle that problem.
         This class is a utility for dataset subclass writers, so that they do not have to handle
-        this issue multiple times.
+        this issue multiple times, nor check that fieldnames are valid, nor handle the
+        empty fieldnames (meaning 'use all the fields').
         """
         def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
             self.dataset=dataset
@@ -163,6 +164,10 @@
             ds_nbatches = (self.L-offset)/minibatch_size
             if n_batches is not None:
                 ds_nbatches = max(n_batches,ds_nbatches)
+            if fieldnames:
+                assert dataset.hasFields(*fieldnames)
+            else:
+                fieldnames=dataset.fieldNames()
             self.iterator = dataset.minibatches_nowrap(fieldnames,minibatch_size,ds_nbatches,offset)
 
         def __iter__(self):
@@ -629,6 +634,7 @@
         self.index2dataset={}
         assert len(datasets)>0
         fieldnames = datasets[-1].fieldNames()
+        self.datasets_start_row=[]
         # We use this map from row index to dataset index for constant-time random access of examples,
         # to avoid having to search for the appropriate dataset each time and slice is asked for.
         for dataset,k in enumerate(datasets[0:-1]):
@@ -636,9 +642,10 @@
             assert L<DataSet.infinity
             for i in xrange(L):
                 self.index2dataset[self.length+i]=k
+            self.datasets_start_row.append(self.length)
             self.length+=L
             assert dataset.fieldNames()==fieldnames
-        self.last_start=self.length
+        self.datasets_start_row.append(self.length)
         self.length+=len(datasets[-1])
         # If length is very large, we should use a more memory-efficient mechanism
         # that does not store all indices
@@ -656,17 +663,23 @@
     def hasFields(self,*fieldnames):
         return self.datasets[0].hasFields(*fieldnames)
 
+    def locate_row(self,row):
+        """Return (dataset_index, row_within_dataset) for global row number"""
+        dataset_index = self.index2dataset[row]
+        row_within_dataset = self.datasets_start_row[dataset_index]
+        return dataset_index, row_within_dataset
+        
     def minibatches_nowrap(self,
                            fieldnames = minibatches_fieldnames,
                            minibatch_size = minibatches_minibatch_size,
                            n_batches = minibatches_n_batches,
                            offset = 0):
+            
         class Iterator(object):
             def __init__(self,vsds):
                 self.vsds=vsds
                 self.next_row=offset
-                self.next_dataset_index=0
-                self.next_dataset_row=0
+                self.next_dataset_index,self.next_dataset_row=self.vsds.locate_row(offset)
                 self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
                   self.next_iterator(vsds.datasets[0],offset,n_batches)
 
@@ -678,16 +691,23 @@
                 if minibatch_size>L:
                     ds_minibatch_size=L
                     n_left_in_mb=minibatch_size-L
-                else: n_left_in_mb=0
+                    ds_nbatches=1
+                else:
+                    n_left_in_mb=0
                 return dataset.minibatches(fieldnames,minibatch_size,ds_nbatches,starting_offset), \
                        L-(starting_offset+ds_nbatches*minibatch_size), n_left_in_mb
 
             def move_to_next_dataset(self):
-                self.next_dataset_index +=1
-                if self.next_dataset_index==len(self.vsds.datasets):
-                    self.next_dataset_index = 0
-                self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
-                   self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches)
+                if self.n_left_at_the_end_of_ds>0:
+                    self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
+                      self.next_iterator(vsds.datasets[self.next_dataset_index],
+                                         self.n_left_at_the_end_of_ds,1)
+                else:
+                    self.next_dataset_index +=1
+                    if self.next_dataset_index==len(self.vsds.datasets):
+                        self.next_dataset_index = 0
+                    self.current_iterator,self.n_left_at_the_end_of_ds,self.n_left_in_mb= \
+                      self.next_iterator(vsds.datasets[self.next_dataset_index],starting_offset,n_batches)
                 
             def __iter__(self):
                 return self
@@ -696,15 +716,17 @@
                 dataset=self.vsds.datasets[self.next_dataset_index]
                 mb = self.next_iterator.next()
                 if self.n_left_in_mb:
-                    names=self.vsds.fieldNames()
                     extra_mb = []
                     while self.n_left_in_mb>0:
                         self.move_to_next_dataset()
                         extra_mb.append(self.next_iterator.next())
                     mb = Example(names,
                                  [dataset.valuesVStack(name,[mb[name]]+[b[name] for b in extra_mb])
-                                  for name in names])
+                                  for name in fieldnames])
                 self.next_row+=minibatch_size
+                self.next_dataset_row+=minibatch_size
+                if self.next_row+minibatch_size>len(dataset):
+                    self.move_to_next_dataset()
                 return mb