diff dataset.py @ 44:5a85fda9b19b

Fixed some more iterator bugs
author bengioy@grenat.iro.umontreal.ca
date Mon, 28 Apr 2008 13:52:54 -0400
parents e92244f30116
children a5c70dc42972
line wrap: on
line diff
--- a/dataset.py	Mon Apr 28 11:41:28 2008 -0400
+++ b/dataset.py	Mon Apr 28 13:52:54 2008 -0400
@@ -139,11 +139,16 @@
         """
         def __init__(self, minibatch_iterator):
             self.minibatch_iterator = minibatch_iterator
+            self.minibatch = None
         def __iter__(self): #makes for loop work
             return self
         def next(self):
             size1_minibatch = self.minibatch_iterator.next()
-            return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
+            if not self.minibatch:
+                self.minibatch = Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
+            else:
+                self.minibatch._values = [value[0] for value in size1_minibatch.values()]
+            return self.minibatch
         
         def next_index(self):
             return self.minibatch_iterator.next_index()
@@ -476,16 +481,22 @@
         return self.fieldnames
 
     def __iter__(self):
-        class Iterator(object):
+        class FieldsSubsetIterator(object):
             def __init__(self,ds):
                 self.ds=ds
                 self.src_iter=ds.src.__iter__()
+                self.example=None
             def __iter__(self): return self
             def next(self):
-                example = self.src_iter.next()
-                return Example(self.ds.fieldnames,
-                               [example[field] for field in self.ds.fieldnames])
-        return Iterator(self)
+                complete_example = self.src_iter.next()
+                if self.example:
+                    self.example._values=[complete_example[field]
+                                          for field in self.ds.fieldnames]
+                else:
+                    self.example=Example(self.ds.fieldnames,
+                                         [complete_example[field] for field in self.ds.fieldnames])
+                return self.example
+        return FieldsSubsetIterator(self)
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         assert self.hasFields(*fieldnames)
@@ -670,7 +681,7 @@
             
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
 
-        class Iterator(object):
+        class HStackedIterator(object):
             def __init__(self,hsds,iterators):
                 self.hsds=hsds
                 self.iterators=iterators
@@ -700,7 +711,7 @@
         else:
             datasets=self.datasets
             iterators=[dataset.minibatches(None,minibatch_size,n_batches,offset) for dataset in datasets]
-        return Iterator(self,iterators)
+        return HStackedIterator(self,iterators)
 
 
     def valuesVStack(self,fieldname,fieldvalues):
@@ -768,8 +779,8 @@
         return dataset_index, row_within_dataset
         
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
-            
-        class Iterator(object):
+
+        class VStackedIterator(object):
             def __init__(self,vsds):
                 self.vsds=vsds
                 self.next_row=offset
@@ -824,7 +835,8 @@
                 self.next_dataset_row+=minibatch_size
                 if self.next_row+minibatch_size>len(dataset):
                     self.move_to_next_dataset()
-                return 
+                return examples
+        return VStackedIterator(self)
                         
 class ArrayFieldsDataSet(DataSet):
     """
@@ -886,10 +898,11 @@
     #    """More efficient implementation than the default"""
             
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
-        class Iterator(LookupList): # store the result in the lookup-list values
+        class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
                 if fieldnames is None: fieldnames = dataset.fieldNames()
-                LookupList.__init__(self,fieldnames,[0]*len(fieldnames))
+                # store the resulting minibatch in a lookup-list of values
+                self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
                 self.dataset=dataset
                 self.minibatch_size=minibatch_size
                 assert offset>=0 and offset<len(dataset.data)
@@ -899,11 +912,11 @@
                 return self
             def next(self):
                 sub_data =  self.dataset.data[self.current:self.current+self.minibatch_size]
-                self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names]
+                self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
                 self.current+=self.minibatch_size
-                return self
+                return self.minibatch
 
-        return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
+        return ArrayDataSetIterator(self,fieldnames,minibatch_size,n_batches,offset)
         
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """