diff dataset.py @ 42:9b68774fcc6b

Testing basic functionality and removing obvious bugs
author bengioy@grenat.iro.umontreal.ca
date Fri, 25 Apr 2008 16:00:31 -0400
parents 283e95c15b47
children e92244f30116
line wrap: on
line diff
--- a/dataset.py	Fri Apr 25 12:04:55 2008 -0400
+++ b/dataset.py	Fri Apr 25 16:00:31 2008 -0400
@@ -1,13 +1,13 @@
 
 from lookup_list import LookupList
 Example = LookupList
-from misc import *
-import copy
-import string
+from misc import unique_elements_list_intersection
+from string import join
+from sys import maxint
 
 class AbstractFunction (Exception): """Derived class must override this function"""
 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
-class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)"""
+#class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)"""
 
 class DataSet(object):
     """A virtual base class for datasets.
@@ -124,7 +124,7 @@
     def __init__(self,description=None,field_types=None):
         if description is None:
             # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)"
-            description = type(self).__name__ + " ( " + string.join([x.__name__ for x in type(self).__bases__]) + " )"
+            description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )"
         self.description=description
         self.field_types=field_types
     
@@ -143,7 +143,7 @@
             return self
         def next(self):
             size1_minibatch = self.minibatch_iterator.next()
-            return Example(size1_minibatch.keys,[value[0] for value in size1_minibatch.values()])
+            return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
         
         def next_index(self):
             return self.minibatch_iterator.next_index()
@@ -197,11 +197,11 @@
             return self.next_row
 
         def next(self):
-            if self.n_batches and self.n_batches_done==self.n_batches:
+            if self.n_batches and self.n_batches_done==self.n_batches:
                 raise StopIteration
-            upper = self.next_row+minibatch_size
+            upper = self.next_row+self.minibatch_size
             if upper <=self.L:
-                minibatch = self.minibatch_iterator.next()
+                minibatch = self.iterator.next()
             else:
                 if not self.n_batches:
                     raise StopIteration
@@ -214,8 +214,8 @@
                                      for name in self.fieldnames])
             self.next_row=upper
             self.n_batches_done+=1
-            if upper >= L:
-                self.next_row -= L
+            if upper >= self.L:
+                self.next_row -= self.L
             return minibatch
 
 
@@ -275,7 +275,7 @@
         any other object that supports integer indexing and slicing.
 
         """
-        return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)
+        return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         """
@@ -322,14 +322,14 @@
         """
         Return a dataset that sees only the fields whose name are specified.
         """
-        assert self.hasFields(fieldnames)
-        return self.fields(fieldnames).examples()
+        assert self.hasFields(*fieldnames)
+        return self.fields(*fieldnames).examples()
 
     def fields(self,*fieldnames):
         """
         Return a DataSetFields object associated with this dataset.
         """
-        return DataSetFields(self,fieldnames)
+        return DataSetFields(self,*fieldnames)
 
     def __getitem__(self,i):
         """
@@ -371,7 +371,7 @@
             rows = i
         if rows is not None:
             fields_values = zip(*[self[row] for row in rows])
-            return MinibatchDataSet(
+            return DataSet.MinibatchDataSet(
                 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                             for fieldname,field_values
                                             in zip(self.fieldNames(),fields_values)]))
@@ -459,7 +459,41 @@
         return datasets[0]
     return VStackedDataSet(datasets)
 
+class FieldsSubsetDataSet(DataSet):
+    """
+    A sub-class of DataSet that selects a subset of the fields.
+    """
+    def __init__(self,src,fieldnames):
+        self.src=src
+        self.fieldnames=fieldnames
+        assert src.hasFields(*fieldnames)
+        self.valuesHStack = src.valuesHStack
+        self.valuesVStack = src.valuesVStack
 
+    def __len__(self): return len(self.src)
+    
+    def fieldNames(self):
+        return self.fieldnames
+
+    def __iter__(self):
+        class Iterator(object):
+            def __init__(self,ds):
+                self.ds=ds
+                self.src_iter=ds.src.__iter__()
+            def __iter__(self): return self
+            def next(self):
+                example = self.src_iter.next()
+                return Example(self.ds.fieldnames,
+                               [example[field] for field in self.ds.fieldnames])
+        return Iterator(self)
+
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        assert self.hasFields(*fieldnames)
+        return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
+    def __getitem__(self,i):
+        return FieldsSubsetDataSet(self.src[i],self.fieldnames)
+    
+        
 class DataSetFields(LookupList):
     """
     Although a DataSet iterates over examples (like rows of a matrix), an associated
@@ -488,13 +522,18 @@
     the examples.
     """
     def __init__(self,dataset,*fieldnames):
-        self.dataset=dataset
         if not fieldnames:
             fieldnames=dataset.fieldNames()
+        elif fieldnames is not dataset.fieldNames():
+            dataset = FieldsSubsetDataSet(dataset,fieldnames)
         assert dataset.hasFields(*fieldnames)
-        LookupList.__init__(self,dataset.fieldNames(),
-                            dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(),
-                                                minibatch_size=len(dataset)).next())
+        self.dataset=dataset
+        minibatch_iterator = dataset.minibatches(fieldnames,
+                                                 minibatch_size=len(dataset),
+                                                 n_batches=1)
+        minibatch=minibatch_iterator.next()
+        LookupList.__init__(self,fieldnames,minibatch)
+        
     def examples(self):
         return self.dataset
     
@@ -813,16 +852,16 @@
 
     """
     Construct an ArrayDataSet from the underlying numpy array (data) and
-    a map from fieldnames to field columns. The columns of a field are specified
+    a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
     using the standard arguments for indexing/slicing: integer for a column index,
     slice for an interval of columns (with possible stride), or iterable of column indices.
     """
-    def __init__(self, data_array, fields_names_columns):
+    def __init__(self, data_array, fields_columns):
         self.data=data_array
-        self.fields=fields_names_columns
+        self.fields_columns=fields_columns
 
         # check consistency and complete slices definitions
-        for fieldname, fieldcolumns in self.fields.items():
+        for fieldname, fieldcolumns in self.fields_columns.items():
             if type(fieldcolumns) is int:
                 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
             elif type(fieldcolumns) is slice:
@@ -832,38 +871,38 @@
                 if not fieldcolumns.step:
                     step=1
                 if start or step:
-                    self.fields[fieldname]=slice(start,fieldcolumns.stop,step)
+                    self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step)
             elif hasattr(fieldcolumns,"__iter__"): # something like a list
                 for i in fieldcolumns:
                     assert i>=0 and i<data_array.shape[1]
 
-        def fieldNames(self):
-            return self.fields.keys()
+    def fieldNames(self):
+        return self.fields_columns.keys()
 
-        def __len__(self):
-            return len(self.data)
+    def __len__(self):
+        return len(self.data)
 
-        #def __getitem__(self,i):
-        #    """More efficient implementation than the default"""
+    #def __getitem__(self,i):
+    #    """More efficient implementation than the default"""
             
-        def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
-            class Iterator(LookupList): # store the result in the lookup-list values
-                def __init__(dataset,fieldnames,minibatch_size,n_batches,offset):
-                    if fieldnames is None: fieldnames = dataset.fieldNames()
-                    LookupList.__init__(self,fieldnames,[0]*len(fieldnames))
-                    self.dataset=dataset
-                    self.minibatch_size=minibatch_size
-                    assert offset>=0 and offset<len(dataset.data)
-                    assert offset+minibatch_size<len(dataset.data)
-                    self.current=offset
-                def __iter__(self):
-                    return self
-                def next(self):
-                    sub_data =  self.dataset.data[self.current:self.current+self.minibatch_size]
-                    self._values = [sub_data[:,self.dataset.fields[f]] for f in self._names]
-                    return self
-                
-            return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        class Iterator(LookupList): # store the result in the lookup-list values
+            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                if fieldnames is None: fieldnames = dataset.fieldNames()
+                LookupList.__init__(self,fieldnames,[0]*len(fieldnames))
+                self.dataset=dataset
+                self.minibatch_size=minibatch_size
+                assert offset>=0 and offset<len(dataset.data)
+                assert offset+minibatch_size<=len(dataset.data)
+                self.current=offset
+            def __iter__(self):
+                return self
+            def next(self):
+                sub_data =  self.dataset.data[self.current:self.current+self.minibatch_size]
+                self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names]
+                return self
+
+        return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
         
 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
     """