diff dataset.py @ 48:b6730f9a336d

Fixing MinibatchDataSet getitem
author bengioy@grenat.iro.umontreal.ca
date Tue, 29 Apr 2008 13:40:13 -0400
parents c5b07e87b0cb
children e3ac93e27e16 66619ce44497
line wrap: on
line diff
--- a/dataset.py	Tue Apr 29 12:39:09 2008 -0400
+++ b/dataset.py	Tue Apr 29 13:40:13 2008 -0400
@@ -8,7 +8,6 @@
 
 class AbstractFunction (Exception): """Derived class must override this function"""
 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
-#class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)"""
 
 class DataSet(object):
     """A virtual base class for datasets.
@@ -19,7 +18,7 @@
     python object, which depends on the particular dataset.
     
     We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method
-    should raise an UnboundedDataSet exception).
+    should return sys.maxint).
 
     A DataSet is a generator of iterators; these iterators can run through the
     examples or the fields in a variety of ways.  A DataSet need not necessarily have a finite
@@ -304,11 +303,17 @@
     def __len__(self):
         """
         len(dataset) returns the number of examples in the dataset.
-        By default, a DataSet is a 'stream', i.e. it has an unbounded length (raises UnboundedDataSet).
+        By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint).
         Sub-classes which implement finite-length datasets should redefine this method.
         Some methods only make sense for finite-length datasets.
         """
-        raise UnboundedDataSet()
+        return sys.maxint
+
+    def is_unbounded(self):
+        """
+        Tests whether a dataset is unbounded (e.g. a stream).
+        """
+        return len(self)==sys.maxint
 
     def hasFields(self,*fieldnames):
         """
@@ -380,7 +385,8 @@
         elif type(i) is list:
             rows = i
         if rows is not None:
-            fields_values = zip(*[self[row] for row in rows])
+            examples = [self[row] for row in rows]
+            fields_values = zip(*examples)
             return MinibatchDataSet(
                 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                             for fieldname,field_values
@@ -592,15 +598,19 @@
         return self.length
 
     def __getitem__(self,i):
-        return DataSetFields(MinibatchDataSet(
-            Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
+        if type(i) in (int,slice,list):
+            return DataSetFields(MinibatchDataSet(
+                Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
+        if self.hasFields(i):
+            return self.fields[i]
+        return self.__dict__[i]
 
     def fieldNames(self):
         return self.fields.keys()
 
     def hasFields(self,*fieldnames):
         for fieldname in fieldnames:
-            if fieldname not in self.fields:
+            if fieldname not in self.fields.keys():
                 return False
         return True
 
@@ -749,11 +759,8 @@
         # We use this map from row index to dataset index for constant-time random access of examples,
         # to avoid having to search for the appropriate dataset each time and slice is asked for.
         for dataset,k in enumerate(datasets[0:-1]):
-            try:
-                L=len(dataset)
-            except UnboundedDataSet:
-                print "All VStacked datasets (except possibly the last) must be bounded (have a length)."
-                assert False
+            assert dataset.is_unbounded() # All VStacked datasets (except possibly the last) must be bounded (have a length).
+            L=len(dataset)
             for i in xrange(L):
                 self.index2dataset[self.length+i]=k
             self.datasets_start_row.append(self.length)