pylearn: diff dataset.py @ 48:b6730f9a336d
Fixing MinibatchDataSet getitem
author | bengioy@grenat.iro.umontreal.ca
---|---
date | Tue, 29 Apr 2008 13:40:13 -0400
parents | c5b07e87b0cb
children | e3ac93e27e16 66619ce44497
```diff
--- a/dataset.py	Tue Apr 29 12:39:09 2008 -0400
+++ b/dataset.py	Tue Apr 29 13:40:13 2008 -0400
@@ -8,7 +8,6 @@
 class AbstractFunction (Exception): """Derived class must override this function"""
 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
-#class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)"""
 
 class DataSet(object):
     """A virtual base class for datasets.
@@ -19,7 +18,7 @@
     python object, which depends on the particular dataset.
 
     We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method
-    should raise an UnboundedDataSet exception).
+    should return sys.maxint).
 
     A DataSet is a generator of iterators; these iterators can run through the
     examples or the fields in a variety of ways.  A DataSet need not necessarily have a finite
@@ -304,11 +303,17 @@
     def __len__(self):
         """
         len(dataset) returns the number of examples in the dataset.
-        By default, a DataSet is a 'stream', i.e. it has an unbounded length (raises UnboundedDataSet).
+        By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint).
         Sub-classes which implement finite-length datasets should redefine this method.
         Some methods only make sense for finite-length datasets.
         """
-        raise UnboundedDataSet()
+        return sys.maxint
+
+    def is_unbounded(self):
+        """
+        Tests whether a dataset is unbounded (e.g. a stream).
+        """
+        return len(self)==sys.maxint
 
     def hasFields(self,*fieldnames):
         """
@@ -380,7 +385,8 @@
         elif type(i) is list:
             rows = i
         if rows is not None:
-            fields_values = zip(*[self[row] for row in rows])
+            examples = [self[row] for row in rows]
+            fields_values = zip(*examples)
             return MinibatchDataSet(
                 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
                                             for fieldname,field_values
@@ -592,15 +598,19 @@
         return self.length
 
     def __getitem__(self,i):
-        return DataSetFields(MinibatchDataSet(
-            Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
+        if type(i) in (int,slice,list):
+            return DataSetFields(MinibatchDataSet(
+                Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
+        if self.hasFields(i):
+            return self.fields[i]
+        return self.__dict__[i]
 
     def fieldNames(self):
         return self.fields.keys()
 
     def hasFields(self,*fieldnames):
         for fieldname in fieldnames:
-            if fieldname not in self.fields:
+            if fieldname not in self.fields.keys():
                 return False
         return True
 
@@ -749,11 +759,8 @@
         # We use this map from row index to dataset index for constant-time random access of examples,
         # to avoid having to search for the appropriate dataset each time and slice is asked for.
-        for dataset,k in enumerate(datasets[0:-1]):
-            try:
-                L=len(dataset)
-            except UnboundedDataSet:
-                print "All VStacked datasets (except possibly the last) must be bounded (have a length)."
-                assert False
+        for k,dataset in enumerate(datasets[0:-1]):
+            assert not dataset.is_unbounded() # All VStacked datasets (except possibly the last) must be bounded (have a length).
+            L=len(dataset)
             for i in xrange(L):
                 self.index2dataset[self.length+i]=k
             self.datasets_start_row.append(self.length)
```
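For context, the convention this changeset adopts can be shown with a minimal sketch, assuming Python 2 (the diff relies on `sys.maxint`, which Python 3 removed). `FiniteDataSet` is an illustrative subclass, not part of pylearn; only `__len__` and `is_unbounded` mirror the diff:

```python
import sys

class DataSet(object):
    # Stripped-down stand-in for pylearn's DataSet base class.
    def __len__(self):
        # By default a DataSet is a 'stream': its length is the
        # sentinel sys.maxint rather than a raised exception.
        return sys.maxint

    def is_unbounded(self):
        # A dataset is a stream iff it reports the sentinel length.
        return len(self) == sys.maxint

class FiniteDataSet(DataSet):
    # Hypothetical finite dataset; real subclasses override __len__
    # with their true number of examples, as the docstring asks.
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n

assert DataSet().is_unbounded()
assert not FiniteDataSet(10).is_unbounded()
```

One trade-off of the sentinel approach: a dataset with exactly `sys.maxint` examples would be misreported as a stream, a corner case the old exception-based design could not hit.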
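The `MinibatchDataSet.__getitem__` fix dispatches on the index type: an int, slice, or list selects rows, a string naming a field returns that whole field, and anything else falls back to attribute lookup. Below is a hedged stand-in for that dispatch, using plain dicts and lists in place of pylearn's `Example`/`DataSetFields` machinery (`MiniDataSet` and its field layout are assumptions, not the real API):

```python
class MiniDataSet(object):
    def __init__(self, fields):
        self.fields = fields  # dict mapping fieldname -> list of values

    def hasFields(self, *fieldnames):
        for fieldname in fieldnames:
            if fieldname not in self.fields.keys():
                return False
        return True

    def __getitem__(self, i):
        if type(i) in (int, slice, list):
            # Row access: pick the same row(s) out of every field.
            if type(i) is list:
                return dict((name, [col[j] for j in i])
                            for name, col in self.fields.items())
            return dict((name, col[i]) for name, col in self.fields.items())
        if self.hasFields(i):
            # Field access: a fieldname returns the whole column.
            return self.fields[i]
        return self.__dict__[i]  # last resort, as in the diff

d = MiniDataSet({'x': [1, 2, 3], 'y': [4, 5, 6]})
print d[1]       # {'x': 2, 'y': 5} -- a single example
print d[[0, 2]]  # {'x': [1, 3], 'y': [4, 6]} -- a minibatch
print d['x']     # [1, 2, 3] -- a whole field
```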
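The last hunk precomputes a map from global row index to the dataset that owns it, so vertically stacked datasets keep constant-time random access. Here is a sketch of that bookkeeping under the same bounded-except-last rule; `build_row_map` and `Bounded` are hypothetical names, while `index2dataset` and `datasets_start_row` follow the diff:

```python
import sys

class Bounded(object):
    # Toy bounded dataset standing in for any finite pylearn dataset.
    def __init__(self, n):
        self.n = n
    def __len__(self):
        return self.n
    def is_unbounded(self):
        return len(self) == sys.maxint

def build_row_map(datasets):
    # Every stacked dataset except possibly the last must be bounded,
    # so each global row can be mapped to its owning dataset up front.
    index2dataset = {}
    datasets_start_row = []
    length = 0
    for k, dataset in enumerate(datasets[0:-1]):
        assert not dataset.is_unbounded()
        L = len(dataset)
        for i in xrange(L):
            index2dataset[length + i] = k
        datasets_start_row.append(length)
        length += L
    return index2dataset, datasets_start_row

index2dataset, starts = build_row_map([Bounded(2), Bounded(3), Bounded(5)])
print index2dataset  # {0: 0, 1: 0, 2: 1, 3: 1, 4: 1}
print starts         # [0, 2]
```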