diff dataset.py @ 152:3f627e844cba

Fixes in CacheDataSet
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 12 May 2008 16:11:24 -0400
parents 39bb21348fdf
children f8a1ae7eb83e
line wrap: on
line diff
--- a/dataset.py	Mon May 12 15:51:43 2008 -0400
+++ b/dataset.py	Mon May 12 16:11:24 2008 -0400
@@ -1045,15 +1045,16 @@
   def __init__(self,source_dataset,cache_all_upon_construction=False):
       self.source_dataset=source_dataset
       self.cache_all_upon_construction=cache_all_upon_construction
+      self.cached_examples = []
       if cache_all_upon_construction:
           # this potentially brings all the source examples
           # into memory at once, which may be too much
           # the work could possibly be done by minibatches
           # that are as large as possible but no more than what memory allows.
           fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
-          self.cached_examples = zip(*fields_values)
-      else:
-          self.cached_examples = []
+          assert all([len(self)==len(field_values) for field_values in fields_values])
+          for example in fields_values.examples():
+              self.cached_examples.append(example)
 
       self.fieldNames = source_dataset.fieldNames
       self.hasFields = source_dataset.hasFields
@@ -1077,7 +1078,7 @@
                   for example in self.dataset.source_dataset[cache_len:upper]:
                       self.dataset.cached_examples.append(example)
               all_fields_minibatch = Example(self.dataset.fieldNames(),
-                                             *self.dataset.cached_examples[self.current:self.current+minibatch_size])
+                                             zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
               if self.dataset.fieldNames()==fieldnames:
                   return all_fields_minibatch
               return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])