changeset 328:09140ba68e17

Added untested RenamedFieldsDataSet
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 16 Jun 2008 16:06:59 -0400
parents 2480024bf401
children 20e08c52c98c
files dataset.py
diffstat 1 files changed, 44 insertions(+), 1 deletions(-) [+]
line wrap: on
line diff
--- a/dataset.py	Mon Jun 16 12:57:32 2008 -0400
+++ b/dataset.py	Mon Jun 16 16:06:59 2008 -0400
@@ -664,7 +664,50 @@
     def dontuse__getitem__(self,i):
         return FieldsSubsetDataSet(self.src[i],self.fieldnames)
     
-        
+class RenamedFieldsDataSet(DataSet):
+    """
+    A sub-class of L{DataSet} that selects and renames a subset of the fields.
+    """
+    def __init__(self,src,src_fieldnames,new_fieldnames):
+        self.src=src
+        self.src_fieldnames=src_fieldnames
+        self.new_fieldnames=new_fieldnames
+        assert src.hasFields(*src_fieldnames)
+        assert len(src_fieldnames)==len(new_fieldnames)
+        self.valuesHStack = src.valuesHStack
+        self.valuesVStack = src.valuesVStack
+
+    def __len__(self): return len(self.src)
+    
+    def fieldNames(self):
+        return self.new_fieldnames
+
+    def __iter__(self):
+        class FieldsSubsetIterator(object):
+            def __init__(self,ds):
+                self.ds=ds
+                self.src_iter=ds.src.__iter__()
+                self.example=None
+            def __iter__(self): return self
+            def next(self):
+                complete_example = self.src_iter.next()
+                if self.example:
+                    self.example._values=[complete_example[field]
+                                          for field in self.ds.src_fieldnames]
+                else:
+                    self.example=Example(self.ds.new_fieldnames,
+                                         [complete_example[field]
+                                          for field in self.ds.src_fieldnames])
+                return self.example
+        return FieldsSubsetIterator(self)
+
+    def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
+        assert self.hasFields(*fieldnames)
+        return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
+    def __getitem__(self,i):
+        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+
+
 class DataSetFields(Example):
     """
     Although a L{DataSet} iterates over examples (like rows of a matrix), an associated