diff dataset.py @ 19:57f4015e2e09

Iterators extend LookupList
author bergstrj@iro.umontreal.ca
date Thu, 27 Mar 2008 01:59:44 -0400
parents 759d17112b23
children 266c68cb6136
line wrap: on
line diff
--- a/dataset.py	Thu Mar 27 00:19:16 2008 -0400
+++ b/dataset.py	Thu Mar 27 01:59:44 2008 -0400
@@ -40,8 +40,7 @@
         i[identifier], but the derived class is free to accept any type of
         identifier, and add extra functionality to the iterator.
         """
-        for i in self.minibatches( minibatch_size = 1):
-            yield Example(i.keys(), [v[0] for v in i.values()])
+        raise AbstractFunction()
 
     def zip(self, *fieldnames):
         """
@@ -61,8 +60,17 @@
         The derived class may accept fieldname arguments of any type.
 
         """
-        for i in self.minibatches(fieldnames, minibatch_size = 1):
-            yield [f[0] for f in i]
+        class Iter(LookupList):
+            def __init__(self, ll):
+                LookupList.__init__(self, ll.keys(), ll.values())
+                self.ll = ll
+            def __iter__(self): #makes for loop work
+                return self
+            def next(self):
+                self.ll.next()
+                self._values = [v[0] for v in self.ll._values]
+                return self
+        return Iter(self.minibatches(fieldnames, minibatch_size = 1))
 
     minibatches_fieldnames = None
     minibatches_minibatch_size = 1
@@ -177,6 +185,8 @@
             assert minibatch_size>=1 and minibatch_size<=len(dataset)
             self.current = -self.minibatch_size
             self.fieldnames = fieldnames
+            if len(dataset) % minibatch_size:
+                raise NotImplementedError()
 
         def __iter__(self):
             return self
@@ -287,11 +297,11 @@
     by the numpy.array(dataset) call.
     """
 
-    class Iterator(object):
+    class Iterator(LookupList):
         """An iterator over a finite dataset that implements wrap-around"""
         def __init__(self, dataset, fieldnames, minibatch_size, next_max):
+            LookupList.__init__(self, fieldnames, [0] * len(fieldnames))
             self.dataset=dataset
-            self.fieldnames = fieldnames
             self.minibatch_size=minibatch_size
             self.next_count = 0
             self.next_max = next_max
@@ -300,8 +310,7 @@
             if minibatch_size >= len(dataset):
                 raise NotImplementedError()
 
-        def __iter__(self):
-            #Why do we do this?  -JB
+        def __iter__(self): #makes for loop work
             return self
 
         @staticmethod
@@ -323,28 +332,29 @@
                 raise StopIteration
 
             #determine the first and last elements of the slice we'll return
+            rows = self.dataset.data.shape[0]
             self.current += self.minibatch_size
-            if self.current >= len(self.dataset):
-                self.current -= len(self.dataset)
+            if self.current >= rows:
+                self.current -= rows
             upper = self.current + self.minibatch_size
 
-            if upper <= len(self.dataset):
+            data = self.dataset.data
+
+            if upper <= rows:
                 #this is the easy case, we only need one slice
-                dataview = self.dataset.data[self.current:upper]
+                dataview = data[self.current:upper]
             else:
                 # the minibatch wraps around the end of the dataset
-                dataview = self.dataset.data[self.current:]
-                upper -= len(self.dataset)
+                dataview = data[self.current:]
+                upper -= rows
                 assert upper > 0
-                dataview = self.matcat(dataview, self.dataset.data[:upper])
+                dataview = self.matcat(dataview, data[:upper])
 
 
-            rval = [dataview[:, self.dataset.fields[f]] for f in self.fieldnames]
+            self._values = [dataview[:, self.dataset.fields[f]]\
+                    for f in self._names]
 
-            if self.fieldnames:
-                rval = Example(self.fieldnames, rval)
-
-            return rval
+            return self
 
 
     def __init__(self, data, fields=None):
@@ -372,6 +382,9 @@
                 # and coherent with the data array
                 assert fieldslice.start >= 0 and fieldslice.stop <= cols
 
+    def __iter__(self):
+        return self.zip(*self.fieldNames())
+
     def minibatches(self,
             fieldnames = DataSet.minibatches_fieldnames,
             minibatch_size = DataSet.minibatches_minibatch_size,