comparison dataset.py @ 48:b6730f9a336d

Fixing MinibatchDataSet getitem
author bengioy@grenat.iro.umontreal.ca
date Tue, 29 Apr 2008 13:40:13 -0400
parents c5b07e87b0cb
children e3ac93e27e16 66619ce44497
comparison
equal deleted inserted replaced
47:7086cfcd8ed6 48:b6730f9a336d
6 from sys import maxint 6 from sys import maxint
7 import numpy 7 import numpy
8 8
9 class AbstractFunction (Exception): """Derived class must override this function""" 9 class AbstractFunction (Exception): """Derived class must override this function"""
10 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" 10 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
11 #class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)"""
12 11
13 class DataSet(object): 12 class DataSet(object):
14 """A virtual base class for datasets. 13 """A virtual base class for datasets.
15 14
16 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction 15 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction
17 with learning algorithms (for training and testing them): rows/records are called examples, and 16 with learning algorithms (for training and testing them): rows/records are called examples, and
18 columns/attributes are called fields. The field value for a particular example can be an arbitrary 17 columns/attributes are called fields. The field value for a particular example can be an arbitrary
19 python object, which depends on the particular dataset. 18 python object, which depends on the particular dataset.
20 19
21 We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method 20 We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method
22 should raise an UnboundedDataSet exception). 21 should return sys.maxint).
23 22
24 A DataSet is a generator of iterators; these iterators can run through the 23 A DataSet is a generator of iterators; these iterators can run through the
25 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite 24 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite
26 or known length, so this class can be used to interface to a 'stream' which 25 or known length, so this class can be used to interface to a 'stream' which
27 feeds on-line learning (however, as noted below, some operations are not 26 feeds on-line learning (however, as noted below, some operations are not
302 raise AbstractFunction() 301 raise AbstractFunction()
303 302
304 def __len__(self): 303 def __len__(self):
305 """ 304 """
306 len(dataset) returns the number of examples in the dataset. 305 len(dataset) returns the number of examples in the dataset.
307 By default, a DataSet is a 'stream', i.e. it has an unbounded length (raises UnboundedDataSet). 306 By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint).
308 Sub-classes which implement finite-length datasets should redefine this method. 307 Sub-classes which implement finite-length datasets should redefine this method.
309 Some methods only make sense for finite-length datasets. 308 Some methods only make sense for finite-length datasets.
310 """ 309 """
311 raise UnboundedDataSet() 310 return sys.maxint
311
312 def is_unbounded(self):
313 """
314 Tests whether a dataset is unbounded (e.g. a stream).
315 """
316 return len(self)==sys.maxint
312 317
313 def hasFields(self,*fieldnames): 318 def hasFields(self,*fieldnames):
314 """ 319 """
315 Return true if the given field name (or field names, if multiple arguments are 320 Return true if the given field name (or field names, if multiple arguments are
316 given) is recognized by the DataSet (i.e. can be used as a field name in one 321 given) is recognized by the DataSet (i.e. can be used as a field name in one
378 rows = range(i.start,i.stop,i.step) 383 rows = range(i.start,i.stop,i.step)
379 # or a list of indices 384 # or a list of indices
380 elif type(i) is list: 385 elif type(i) is list:
381 rows = i 386 rows = i
382 if rows is not None: 387 if rows is not None:
383 fields_values = zip(*[self[row] for row in rows]) 388 examples = [self[row] for row in rows]
389 fields_values = zip(*examples)
384 return MinibatchDataSet( 390 return MinibatchDataSet(
385 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) 391 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
386 for fieldname,field_values 392 for fieldname,field_values
387 in zip(self.fieldNames(),fields_values)])) 393 in zip(self.fieldNames(),fields_values)]))
388 # else check for a fieldname 394 # else check for a fieldname
590 596
591 def __len__(self): 597 def __len__(self):
592 return self.length 598 return self.length
593 599
594 def __getitem__(self,i): 600 def __getitem__(self,i):
595 return DataSetFields(MinibatchDataSet( 601 if type(i) in (int,slice,list):
596 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) 602 return DataSetFields(MinibatchDataSet(
603 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
604 if self.hasFields(i):
605 return self.fields[i]
606 return self.__dict__[i]
597 607
598 def fieldNames(self): 608 def fieldNames(self):
599 return self.fields.keys() 609 return self.fields.keys()
600 610
601 def hasFields(self,*fieldnames): 611 def hasFields(self,*fieldnames):
602 for fieldname in fieldnames: 612 for fieldname in fieldnames:
603 if fieldname not in self.fields: 613 if fieldname not in self.fields.keys():
604 return False 614 return False
605 return True 615 return True
606 616
607 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 617 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
608 class Iterator(object): 618 class Iterator(object):
747 fieldnames = datasets[-1].fieldNames() 757 fieldnames = datasets[-1].fieldNames()
748 self.datasets_start_row=[] 758 self.datasets_start_row=[]
749 # We use this map from row index to dataset index for constant-time random access of examples, 759 # We use this map from row index to dataset index for constant-time random access of examples,
750 # to avoid having to search for the appropriate dataset each time and slice is asked for. 760 # to avoid having to search for the appropriate dataset each time and slice is asked for.
751 for dataset,k in enumerate(datasets[0:-1]): 761 for dataset,k in enumerate(datasets[0:-1]):
752 try: 762 assert dataset.is_unbounded() # All VStacked datasets (except possibly the last) must be bounded (have a length).
753 L=len(dataset) 763 L=len(dataset)
754 except UnboundedDataSet:
755 print "All VStacked datasets (except possibly the last) must be bounded (have a length)."
756 assert False
757 for i in xrange(L): 764 for i in xrange(L):
758 self.index2dataset[self.length+i]=k 765 self.index2dataset[self.length+i]=k
759 self.datasets_start_row.append(self.length) 766 self.datasets_start_row.append(self.length)
760 self.length+=L 767 self.length+=L
761 assert dataset.fieldNames()==fieldnames 768 assert dataset.fieldNames()==fieldnames