changeset 19:57f4015e2e09

Iterators extend LookupList
author bergstrj@iro.umontreal.ca
date Thu, 27 Mar 2008 01:59:44 -0400
parents 60b164a0d84a
children 266c68cb6136
files _test_dataset.py dataset.py lookup_list.py
diffstat 3 files changed, 53 insertions(+), 32 deletions(-) [+]
line wrap: on
line diff
--- a/_test_dataset.py	Thu Mar 27 00:19:16 2008 -0400
+++ b/_test_dataset.py	Thu Mar 27 01:59:44 2008 -0400
@@ -73,7 +73,7 @@
 
         for i, x in enumerate(a.minibatches(["x"], minibatch_size=3, n_batches=6)):
             self.failUnless(numpy.all( x == arr2[i*3:i*3+3,0:2]))
-        
+    
 if __name__ == '__main__':
     unittest.main()
     
--- a/dataset.py	Thu Mar 27 00:19:16 2008 -0400
+++ b/dataset.py	Thu Mar 27 01:59:44 2008 -0400
@@ -40,8 +40,7 @@
         i[identifier], but the derived class is free to accept any type of
         identifier, and add extra functionality to the iterator.
         """
-        for i in self.minibatches( minibatch_size = 1):
-            yield Example(i.keys(), [v[0] for v in i.values()])
+        raise AbstractFunction()
 
     def zip(self, *fieldnames):
         """
@@ -61,8 +60,17 @@
         The derived class may accept fieldname arguments of any type.
 
         """
-        for i in self.minibatches(fieldnames, minibatch_size = 1):
-            yield [f[0] for f in i]
+        class Iter(LookupList):
+            def __init__(self, ll):
+                LookupList.__init__(self, ll.keys(), ll.values())
+                self.ll = ll
+            def __iter__(self): #makes for loop work
+                return self
+            def next(self):
+                self.ll.next()
+                self._values = [v[0] for v in self.ll._values]
+                return self
+        return Iter(self.minibatches(fieldnames, minibatch_size = 1))
 
     minibatches_fieldnames = None
     minibatches_minibatch_size = 1
@@ -177,6 +185,8 @@
             assert minibatch_size>=1 and minibatch_size<=len(dataset)
             self.current = -self.minibatch_size
             self.fieldnames = fieldnames
+            if len(dataset) % minibatch_size:
+                raise NotImplementedError()
 
         def __iter__(self):
             return self
@@ -287,11 +297,11 @@
     by the numpy.array(dataset) call.
     """
 
-    class Iterator(object):
+    class Iterator(LookupList):
         """An iterator over a finite dataset that implements wrap-around"""
         def __init__(self, dataset, fieldnames, minibatch_size, next_max):
+            LookupList.__init__(self, fieldnames, [0] * len(fieldnames))
             self.dataset=dataset
-            self.fieldnames = fieldnames
             self.minibatch_size=minibatch_size
             self.next_count = 0
             self.next_max = next_max
@@ -300,8 +310,7 @@
             if minibatch_size >= len(dataset):
                 raise NotImplementedError()
 
-        def __iter__(self):
-            #Why do we do this?  -JB
+        def __iter__(self): #makes for loop work
             return self
 
         @staticmethod
@@ -323,28 +332,29 @@
                 raise StopIteration
 
             #determine the first and last elements of the slice we'll return
+            rows = self.dataset.data.shape[0]
             self.current += self.minibatch_size
-            if self.current >= len(self.dataset):
-                self.current -= len(self.dataset)
+            if self.current >= rows:
+                self.current -= rows
             upper = self.current + self.minibatch_size
 
-            if upper <= len(self.dataset):
+            data = self.dataset.data
+
+            if upper <= rows:
                 #this is the easy case, we only need once slice
-                dataview = self.dataset.data[self.current:upper]
+                dataview = data[self.current:upper]
             else:
                 # the minibatch wraps around the end of the dataset
-                dataview = self.dataset.data[self.current:]
-                upper -= len(self.dataset)
+                dataview = data[self.current:]
+                upper -= rows
                 assert upper > 0
-                dataview = self.matcat(dataview, self.dataset.data[:upper])
+                dataview = self.matcat(dataview, data[:upper])
 
 
-            rval = [dataview[:, self.dataset.fields[f]] for f in self.fieldnames]
+            self._values = [dataview[:, self.dataset.fields[f]]\
+                    for f in self._names]
 
-            if self.fieldnames:
-                rval = Example(self.fieldnames, rval)
-
-            return rval
+            return self
 
 
     def __init__(self, data, fields=None):
@@ -372,6 +382,9 @@
                 # and coherent with the data array
                 assert fieldslice.start >= 0 and fieldslice.stop <= cols
 
+    def __iter__(self):
+        return self.zip(*self.fieldNames())
+
     def minibatches(self,
             fieldnames = DataSet.minibatches_fieldnames,
             minibatch_size = DataSet.minibatches_minibatch_size,
--- a/lookup_list.py	Thu Mar 27 00:19:16 2008 -0400
+++ b/lookup_list.py	Thu Mar 27 01:59:44 2008 -0400
@@ -47,20 +47,28 @@
             if key in self._name2index:
                 self._values[self._name2index[key]]=value
             else:
-                self._name2index[key]=len(self)
-                self._values.append(value)
-                self._names.append(key)
+                raise KeyError(key)
 
     def __getattr__(self,name):
-        return self._values[self._name2index[name]]
+        try:
+            return self._values[self._name2index[name]]
+        except KeyError, e:
+            raise AttributeError(name)
 
-    def __setattr__(self,name,value):
-        if name in self._name2index:
-            self._values[self._name2index[name]]=value
-        else:
-            self._name2index[name]=len(self)
-            self._values.append(value)
-            self._names.append(name)
+    if 0:
+        # This makes subclassing horrible, just call append_keyval if it's
+        # really what you want to do.
+        # -JB
+        def __setattr__(self,name,value):
+            if name in self._name2index:
+                self._values[self._name2index[name]]=value
+            else:
+                raise AttributeError(name)
+
+    def append_keyval(self, key, value):
+        self._name2index[key]=len(self)
+        self._values.append(value)
+        self._names.append(key)
 
     def __len__(self):
         return len(self._values)