changeset 256:aef979d5bad9

Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 13:25:40 -0400
parents e93e511deb9a bf0a1ebc6e52
children 19b14afe04b7 1cafd495098c
files
diffstat 2 files changed, 36 insertions(+), 12 deletions(-)
--- a/dataset.py	Tue Jun 03 13:18:33 2008 -0400
+++ b/dataset.py	Tue Jun 03 13:25:40 2008 -0400
@@ -1141,6 +1141,7 @@
           def __init__(self,dataset):
               self.dataset=dataset
               self.current=offset
+              self.all_fields = self.dataset.fieldNames()==fieldnames
           def __iter__(self): return self
           def next(self):
               upper = self.current+minibatch_size
@@ -1152,7 +1153,7 @@
               all_fields_minibatch = Example(self.dataset.fieldNames(),
                                              zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
               self.current+=minibatch_size
-              if self.dataset.fieldNames()==fieldnames:
+              if self.all_fields:
                   return all_fields_minibatch
               return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
       return CacheIterator(self)
@@ -1161,8 +1162,31 @@
       if type(i)==int and len(self.cached_examples)>i:
           return self.cached_examples[i]
       else:
-          return DataSet.__getitem__(self,i)
-                      
+          return self.source_dataset[i]
+      
+  def __iter__(self):
+      class CacheIteratorIter(object):
+          def __init__(self,dataset):
+              self.dataset=dataset
+              self.l = len(dataset)
+              self.current = 0
+              self.fieldnames = self.dataset.fieldNames()
+              self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
+          def __iter__(self): return self
+          def next(self):
+              if self.current>=self.l:
+                  raise StopIteration
+              cache_len = len(self.dataset.cached_examples)
+              if self.current>=cache_len: # this example is not yet in the cache
+                  # fetch it from the source dataset and append it to the cache
+                  self.dataset.cached_examples.append(
+                      self.dataset.source_dataset[self.current])
+              self.example._values = self.dataset.cached_examples[self.current]
+              self.current+=1
+              return self.example
+
+      return CacheIteratorIter(self)
+
 class ApplyFunctionDataSet(DataSet):
   """
   A L{DataSet} that contains as fields the results of applying a
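
The new CachedDataSet.__iter__ above fills self.dataset.cached_examples lazily: an example is pulled from source_dataset only the first time it is visited, and later passes read it back from the cache. It also reuses a single LookupList (self.example), overwriting its _values on every call, so callers that keep a reference to an example must copy it before the next step. The sketch below is a minimal, standalone illustration of the same lazy-caching iterator pattern; it is not the pylearn API, and the plain-list source and names used here are assumptions.

# Minimal sketch of the lazy-caching iterator pattern (not the pylearn API).
# `source` is any indexable container; `cache` is a shared list filled on the
# first pass and reused on later passes.
class LazyCacheIterator(object):
    def __init__(self, source, cache):
        self.source = source
        self.cache = cache
        self.current = 0
    def __iter__(self):
        return self
    def next(self):                       # Python 2 iterator protocol, as in the diff
        if self.current >= len(self.source):
            raise StopIteration
        if self.current >= len(self.cache):
            # not cached yet: fetch from the source and remember it
            self.cache.append(self.source[self.current])
        value = self.cache[self.current]
        self.current += 1
        return value
    __next__ = next                       # alias so the sketch also runs on Python 3

# Usage: the second pass is served entirely from `cache`.
source = [[i, i * i] for i in range(5)]
cache = []
first_pass  = list(LazyCacheIterator(source, cache))
second_pass = list(LazyCacheIterator(source, cache))
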
--- a/test_dataset.py	Tue Jun 03 13:18:33 2008 -0400
+++ b/test_dataset.py	Tue Jun 03 13:25:40 2008 -0400
@@ -493,12 +493,11 @@
     raise NotImplementedError()
 
 
-def test_speed():
-    print "test_speed"
-    import time
-    a2 = numpy.random.rand(100000,400)
-    ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)})
+def test_speed(array, ds):
+    print "test_speed", ds.__class__
+
     mat = numpy.random.rand(400,100)
+
     @print_timing
     def f_array_full(a):
         a+1
@@ -540,11 +539,13 @@
             exs[0]+1
 #            ex[0]*mat
 
-    f_array_full(a2)
-    f_array_index(a2)
-    f_array_iter(a2)
+    f_array_full(array)
+    f_array_index(array)
+    f_array_iter(array)
 
     f_ds_index(ds)
+    f_ds_index(ds)
+    f_ds_iter(ds)
     f_ds_iter(ds)
 
     f_ds_mb1(ds,10)
@@ -556,7 +557,6 @@
     f_ds_mb2(ds,1000)
     f_ds_mb2(ds,10000)
 
-    del a2, ds
 
 if __name__=='__main__':
     test1()
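
test_speed no longer builds its own data: it now receives the array and the dataset as parameters, and f_ds_index/f_ds_iter are each timed twice so a second pass can be compared against the first (interesting when ds caches its examples). The real call site is outside the hunks shown above; the driver below is only a hypothetical sketch. CachedDataSet, and a constructor that takes the dataset to wrap, are assumptions; the ArrayDataSet construction is lifted from the lines removed from the old test_speed.

# Hypothetical driver for the refactored test_speed(array, ds); names and
# imports here are assumptions, not part of this changeset.
import numpy
from dataset import ArrayDataSet, CachedDataSet
from test_dataset import test_speed

a2 = numpy.random.rand(100000, 400)
ds = ArrayDataSet(a2, {'all': slice(0, a2.shape[1], 1)})
test_speed(a2, ds)                   # plain ArrayDataSet timings
test_speed(a2, CachedDataSet(ds))    # cached wrapper: the repeated f_ds_index /
                                     # f_ds_iter calls time a cold then a warm cache
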