Mercurial > pylearn
comparison dataset.py @ 48:b6730f9a336d
Fixing MinibatchDataSet getitem
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Tue, 29 Apr 2008 13:40:13 -0400 |
parents | c5b07e87b0cb |
children | e3ac93e27e16 66619ce44497 |
comparison
equal
deleted
inserted
replaced
47:7086cfcd8ed6 | 48:b6730f9a336d |
---|---|
6 from sys import maxint | 6 from sys import maxint |
7 import numpy | 7 import numpy |
8 | 8 |
9 class AbstractFunction (Exception): """Derived class must override this function""" | 9 class AbstractFunction (Exception): """Derived class must override this function""" |
10 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" | 10 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" |
11 #class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" | |
12 | 11 |
13 class DataSet(object): | 12 class DataSet(object): |
14 """A virtual base class for datasets. | 13 """A virtual base class for datasets. |
15 | 14 |
16 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction | 15 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction |
17 with learning algorithms (for training and testing them): rows/records are called examples, and | 16 with learning algorithms (for training and testing them): rows/records are called examples, and |
18 columns/attributes are called fields. The field value for a particular example can be an arbitrary | 17 columns/attributes are called fields. The field value for a particular example can be an arbitrary |
19 python object, which depends on the particular dataset. | 18 python object, which depends on the particular dataset. |
20 | 19 |
21 We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method | 20 We call a DataSet a 'stream' when its length is unbounded (otherwise its __len__ method |
22 should raise an UnboundedDataSet exception). | 21 should return sys.maxint). |
23 | 22 |
24 A DataSet is a generator of iterators; these iterators can run through the | 23 A DataSet is a generator of iterators; these iterators can run through the |
25 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite | 24 examples or the fields in a variety of ways. A DataSet need not necessarily have a finite |
26 or known length, so this class can be used to interface to a 'stream' which | 25 or known length, so this class can be used to interface to a 'stream' which |
27 feeds on-line learning (however, as noted below, some operations are not | 26 feeds on-line learning (however, as noted below, some operations are not |
302 raise AbstractFunction() | 301 raise AbstractFunction() |
303 | 302 |
304 def __len__(self): | 303 def __len__(self): |
305 """ | 304 """ |
306 len(dataset) returns the number of examples in the dataset. | 305 len(dataset) returns the number of examples in the dataset. |
307 By default, a DataSet is a 'stream', i.e. it has an unbounded length (raises UnboundedDataSet). | 306 By default, a DataSet is a 'stream', i.e. it has an unbounded length (sys.maxint). |
308 Sub-classes which implement finite-length datasets should redefine this method. | 307 Sub-classes which implement finite-length datasets should redefine this method. |
309 Some methods only make sense for finite-length datasets. | 308 Some methods only make sense for finite-length datasets. |
310 """ | 309 """ |
311 raise UnboundedDataSet() | 310 return sys.maxint |
311 | |
312 def is_unbounded(self): | |
313 """ | |
314 Tests whether a dataset is unbounded (e.g. a stream). | |
315 """ | |
316 return len(self)==sys.maxint | |
312 | 317 |
313 def hasFields(self,*fieldnames): | 318 def hasFields(self,*fieldnames): |
314 """ | 319 """ |
315 Return true if the given field name (or field names, if multiple arguments are | 320 Return true if the given field name (or field names, if multiple arguments are |
316 given) is recognized by the DataSet (i.e. can be used as a field name in one | 321 given) is recognized by the DataSet (i.e. can be used as a field name in one |
378 rows = range(i.start,i.stop,i.step) | 383 rows = range(i.start,i.stop,i.step) |
379 # or a list of indices | 384 # or a list of indices |
380 elif type(i) is list: | 385 elif type(i) is list: |
381 rows = i | 386 rows = i |
382 if rows is not None: | 387 if rows is not None: |
383 fields_values = zip(*[self[row] for row in rows]) | 388 examples = [self[row] for row in rows] |
389 fields_values = zip(*examples) | |
384 return MinibatchDataSet( | 390 return MinibatchDataSet( |
385 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 391 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) |
386 for fieldname,field_values | 392 for fieldname,field_values |
387 in zip(self.fieldNames(),fields_values)])) | 393 in zip(self.fieldNames(),fields_values)])) |
388 # else check for a fieldname | 394 # else check for a fieldname |
590 | 596 |
591 def __len__(self): | 597 def __len__(self): |
592 return self.length | 598 return self.length |
593 | 599 |
594 def __getitem__(self,i): | 600 def __getitem__(self,i): |
595 return DataSetFields(MinibatchDataSet( | 601 if type(i) in (int,slice,list): |
596 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) | 602 return DataSetFields(MinibatchDataSet( |
603 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) | |
604 if self.hasFields(i): | |
605 return self.fields[i] | |
606 return self.__dict__[i] | |
597 | 607 |
598 def fieldNames(self): | 608 def fieldNames(self): |
599 return self.fields.keys() | 609 return self.fields.keys() |
600 | 610 |
601 def hasFields(self,*fieldnames): | 611 def hasFields(self,*fieldnames): |
602 for fieldname in fieldnames: | 612 for fieldname in fieldnames: |
603 if fieldname not in self.fields: | 613 if fieldname not in self.fields.keys(): |
604 return False | 614 return False |
605 return True | 615 return True |
606 | 616 |
607 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 617 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
608 class Iterator(object): | 618 class Iterator(object): |
747 fieldnames = datasets[-1].fieldNames() | 757 fieldnames = datasets[-1].fieldNames() |
748 self.datasets_start_row=[] | 758 self.datasets_start_row=[] |
749 # We use this map from row index to dataset index for constant-time random access of examples, | 759 # We use this map from row index to dataset index for constant-time random access of examples, |
750 # to avoid having to search for the appropriate dataset each time and slice is asked for. | 760 # to avoid having to search for the appropriate dataset each time and slice is asked for. |
751 for dataset,k in enumerate(datasets[0:-1]): | 761 for dataset,k in enumerate(datasets[0:-1]): |
752 try: | 762 assert dataset.is_unbounded() # All VStacked datasets (except possibly the last) must be bounded (have a length). |
753 L=len(dataset) | 763 L=len(dataset) |
754 except UnboundedDataSet: | |
755 print "All VStacked datasets (except possibly the last) must be bounded (have a length)." | |
756 assert False | |
757 for i in xrange(L): | 764 for i in xrange(L): |
758 self.index2dataset[self.length+i]=k | 765 self.index2dataset[self.length+i]=k |
759 self.datasets_start_row.append(self.length) | 766 self.datasets_start_row.append(self.length) |
760 self.length+=L | 767 self.length+=L |
761 assert dataset.fieldNames()==fieldnames | 768 assert dataset.fieldNames()==fieldnames |