diff dataset.py @ 252:856d14dc4468

implemented CachedDataSet.__iter__ as an optimization
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 13:22:45 -0400
parents 7e6edee187e3
children 394e07e2b0fd
line wrap: on
line diff
--- a/dataset.py	Tue Jun 03 12:25:53 2008 -0400
+++ b/dataset.py	Tue Jun 03 13:22:45 2008 -0400
@@ -1162,7 +1162,58 @@
           return self.cached_examples[i]
       else:
           return self.source_dataset[i]
-                      
+      
+  def __iter__(self):
+      """
+      Return an iterator over the examples of this dataset, one example at a
+      time. Examples that are not yet in C{self.cached_examples} are read from
+      C{self.source_dataset} and appended to the cache as they are reached.
+
+      The iterator returns the SAME L{LookupList} object on every call to
+      C{next()}: only its C{_values} attribute is rebound to the current
+      example. A reference kept from an earlier iteration is therefore the
+      same object as the one returned later.
+      """
+      class CacheIteratorIter(object):
+          # Iterator using the Python 2 protocol (a C{next} method).
+          def __init__(self,dataset):
+              self.dataset=dataset
+              # the length is read once, here; next() stops after that many examples
+              self.l = len(dataset)
+              # index of the next example to return
+              self.current = 0
+              self.fieldnames = self.dataset.fieldNames()
+              # single container reused for every returned example; the zeros
+              # are placeholders that next() replaces before returning it
+              self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
+          def __iter__(self): return self
+          def next(self):
+              # Return the example at index self.current and advance by one;
+              # raise StopIteration once self.l examples have been returned.
+              if self.current>=self.l:
+                  raise StopIteration
+              cache_len = len(self.dataset.cached_examples)
+              if self.current>=cache_len: # this example is not in the cache yet
+                  # fetch it from the source dataset and append it to the cache
+                  # NOTE(review): append() stores it at index cache_len, so the
+                  # lookup below only finds it if cache_len == self.current,
+                  # i.e. the cache is filled in index order -- confirm that
+                  # nothing else shrinks or reorders cached_examples.
+                  self.dataset.cached_examples.append(
+                      self.dataset.source_dataset[self.current])
+              # rebind the values of the reused container instead of building a
+              # new LookupList for each example
+              # NOTE(review): the cached entry is stored as returned by
+              # source_dataset[i]; presumably that is acceptable as _values --
+              # verify against LookupList.
+              self.example._values = self.dataset.cached_examples[self.current]
+              self.current+=1
+              return self.example
+
+      return CacheIteratorIter(self)
+
+#       class CachedDataSetIterator(object):
+#           def __init__(self,dataset,fieldnames):#,minibatch_size,n_batches,offset):
+# #              if fieldnames is None: fieldnames = dataset.fieldNames()
+#   # store the resulting minibatch in a lookup-list of values
+#               self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+#               self.dataset=dataset
+# #              self.minibatch_size=minibatch_size
+# #              assert offset>=0 and offset<len(dataset.data)
+# #              assert offset+minibatch_size<=len(dataset.data)
+#               self.current=0
+#               self.columns = [self.dataset.fields_columns[f] 
+#                               for f in self.minibatch._names]
+#               self.l = len(self.dataset)
+#           def __iter__(self):
+#               return self
+#           def next(self):
+#               #@todo: we assume that we only need to stop when minibatch_size == 1.
+#               # Otherwise, MinibatchWrapAroundIterator does it.
+#               if self.current>=self.l:
+#                   raise StopIteration
+#               sub_data =  self.dataset.data[self.current]
+#               self.minibatch._values = [sub_data[c] for c in self.columns]
+              
+#               self.current+=self.minibatch_size
+#               return self.minibatch
+
+#         return CachedDataSetIterator(self,self.fieldNames())#,1,0,0)
+    
 class ApplyFunctionDataSet(DataSet):
   """
   A L{DataSet} that contains as fields the results of applying a