diff dataset.py @ 353:47538a45b878

Cached dataset seems debugged, now using n_batches... is n_batches here to stay?
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Tue, 17 Jun 2008 17:12:43 -0400
parents 7545207466d4
children d580b3a369a4
--- a/dataset.py	Tue Jun 17 16:50:35 2008 -0400
+++ b/dataset.py	Tue Jun 17 17:12:43 2008 -0400
@@ -381,7 +381,8 @@
         any other object that supports integer indexing and slicing.
 
         @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
-        batches only, raise StopIteration
+        batches only and raise StopIteration.
+        @ATTENTION: minibatches returns a LookupList, so we cannot iterate over its examples.
 
         """
         #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\
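
To make the complete-batches-only contract concrete, here is a minimal sketch over a plain Python list; complete_minibatches is an illustrative name, not part of dataset.py:

    # Emit only complete batches, silently dropping a trailing partial one,
    # which is the behaviour minibatches_nowrap is documented to have above.
    def complete_minibatches(data, minibatch_size):
        i = 0
        while i + minibatch_size <= len(data):
            yield data[i:i + minibatch_size]
            i += minibatch_size

    # Two complete batches of size 2; the fifth example never appears.
    assert list(complete_minibatches(list(range(5)), 2)) == [[0, 1], [2, 3]]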
@@ -869,8 +870,9 @@
                 return self
             def next(self):
                 upper = self.next_example+minibatch_size
-                if upper>self.ds.length:
-                    raise StopIteration
+                if upper > len(self.ds):
+                    raise StopIteration()
+                assert upper <= len(self.ds)  # instead of self.ds.length
                 #minibatch = Example(self.ds._fields.keys(),
                 #                    [field[self.next_example:upper]
                 #                     for field in self.ds._fields])
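
The len()-based guard added above can be read in isolation. A minimal sketch, assuming only that the wrapped object supports len() and slicing; ToyDataset and CompleteBatchIterator are illustrative names, written against the Python 2 next() protocol used in this file:

    class ToyDataset(object):
        def __init__(self, rows):
            self.rows = rows
        def __len__(self):
            # len(ds) dispatches to __len__, so no .length attribute is needed
            return len(self.rows)
        def __getitem__(self, key):
            return self.rows[key]

    class CompleteBatchIterator(object):
        def __init__(self, ds, minibatch_size):
            self.ds = ds
            self.minibatch_size = minibatch_size
            self.next_example = 0
        def __iter__(self):
            return self
        def next(self):
            upper = self.next_example + self.minibatch_size
            if upper > len(self.ds):
                raise StopIteration()
            batch = self.ds[self.next_example:upper]
            self.next_example = upper
            return batch
        __next__ = next  # alias so the sketch also runs under Python 3

    # [[1, 2], [3, 4]]: the incomplete fifth batch is never emitted.
    assert list(CompleteBatchIterator(ToyDataset([1, 2, 3, 4, 5]), 2)) == [[1, 2], [3, 4]]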
@@ -1325,7 +1327,10 @@
           # into memory at once, which may be too much
           # the work could possibly be done by minibatches
           # that are as large as possible but no more than what memory allows.
-          fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+          #
+          # fields_values is supposed to be a DataSetFields, which inherits from LookupList
+          #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+          fields_values = DataSetFields(source_dataset,None)
           assert all([len(self)==len(field_values) for field_values in fields_values])
           for example in fields_values.examples():
               self.cached_examples.append(copy.copy(example))
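
A sketch of the caching step itself, assuming only that a fields object can transpose its per-field columns back into per-example tuples; ToyFields stands in for DataSetFields and is not the real class:

    import copy

    class ToyFields(object):
        def __init__(self, columns):
            # columns: one sequence of values per field
            self.columns = columns
        def examples(self):
            # transpose field columns into per-example tuples
            return zip(*self.columns)

    fields_values = ToyFields([[1, 2, 3], [4.0, 5.0, 6.0]])
    cached_examples = []
    for example in fields_values.examples():
        # copy.copy mirrors the line above; it protects the cache
        # when examples are mutable rows rather than tuples
        cached_examples.append(copy.copy(example))
    assert cached_examples == [(1, 4.0), (2, 5.0), (3, 6.0)]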
@@ -1344,16 +1349,25 @@
               self.dataset=dataset
               self.current=offset
               self.all_fields = self.dataset.fieldNames()==fieldnames
+              self.n_batches = n_batches
+              self.batch_counter = 0
           def __iter__(self): return self
           def next(self):
+              self.batch_counter += 1
+              if self.n_batches and self.batch_counter > self.n_batches:
+                  raise StopIteration()
               upper = self.current+minibatch_size
+              if upper > len(self.dataset.source_dataset):
+                  raise StopIteration()
               cache_len = len(self.dataset.cached_examples)
               if upper>cache_len: # whole minibatch is not already in cache
                   # cache everything from current length to upper
-                  for example in self.dataset.source_dataset[cache_len:upper]:
+                  #for example in self.dataset.source_dataset[cache_len:upper]:
+                  for example in self.dataset.source_dataset.subset[cache_len:upper]:
                       self.dataset.cached_examples.append(example)
               all_fields_minibatch = Example(self.dataset.fieldNames(),
                                              zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
+
               self.current+=minibatch_size
               if self.all_fields:
                   return all_fields_minibatch
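
The two stopping conditions introduced above (the optional n_batches cap and the bounds check against the source dataset) can be sketched in isolation; capped_minibatches is an illustrative name, not part of this file:

    # Stop after n_batches complete batches, or earlier if the data runs out;
    # mirrors the batch_counter and upper checks in the iterator above.
    def capped_minibatches(data, minibatch_size, n_batches=None):
        current, batch_counter = 0, 0
        while True:
            batch_counter += 1
            if n_batches and batch_counter > n_batches:
                return  # the cap is honoured even if more data remains
            upper = current + minibatch_size
            if upper > len(data):
                return  # never emit an incomplete batch
            yield data[current:upper]
            current = upper

    # With n_batches=2, only the first two of three possible batches come out.
    assert list(capped_minibatches(list(range(6)), 2, n_batches=2)) == [[0, 1], [2, 3]]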