comparison dataset.py @ 42:9b68774fcc6b

Testing basic functionality and removing obvious bugs
author bengioy@grenat.iro.umontreal.ca
date Fri, 25 Apr 2008 16:00:31 -0400
parents 283e95c15b47
children e92244f30116
left column: 41:283e95c15b47 (parent)   right column: 42:9b68774fcc6b (this revision)
1 1
2 from lookup_list import LookupList 2 from lookup_list import LookupList
3 Example = LookupList 3 Example = LookupList
4 from misc import * 4 from misc import unique_elements_list_intersection
5 import copy 5 from string import join
6 import string 6 from sys import maxint
7 7
8 class AbstractFunction (Exception): """Derived class must override this function""" 8 class AbstractFunction (Exception): """Derived class must override this function"""
9 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" 9 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented"""
10 class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" 10 #class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)"""
11 11
12 class DataSet(object): 12 class DataSet(object):
13 """A virtual base class for datasets. 13 """A virtual base class for datasets.
14 14
15 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction 15 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction
122 """ 122 """
123 123
124 def __init__(self,description=None,field_types=None): 124 def __init__(self,description=None,field_types=None):
125 if description is None: 125 if description is None:
126 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" 126 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)"
127 description = type(self).__name__ + " ( " + string.join([x.__name__ for x in type(self).__bases__]) + " )" 127 description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )"
128 self.description=description 128 self.description=description
129 self.field_types=field_types 129 self.field_types=field_types
130 130
131 class MinibatchToSingleExampleIterator(object): 131 class MinibatchToSingleExampleIterator(object):
132 """ 132 """
141 self.minibatch_iterator = minibatch_iterator 141 self.minibatch_iterator = minibatch_iterator
142 def __iter__(self): #makes for loop work 142 def __iter__(self): #makes for loop work
143 return self 143 return self
144 def next(self): 144 def next(self):
145 size1_minibatch = self.minibatch_iterator.next() 145 size1_minibatch = self.minibatch_iterator.next()
146 return Example(size1_minibatch.keys,[value[0] for value in size1_minibatch.values()]) 146 return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()])
147 147
148 def next_index(self): 148 def next_index(self):
149 return self.minibatch_iterator.next_index() 149 return self.minibatch_iterator.next_index()
150 150
151 def __iter__(self): 151 def __iter__(self):
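The adapter above peels the leading size-1 dimension off every field of a size-1 minibatch (note the keys -> keys() fix on the right). A standalone sketch of that unwrapping, with a made-up stub standing in for a real minibatch iterator, assuming the elided constructor only stores its argument:

    from lookup_list import LookupList
    from dataset import MinibatchToSingleExampleIterator   # assumed module name
    Example = LookupList

    class StubMinibatchIterator(object):
        # endless stream of size-1 minibatches: each field value has length 1
        def __iter__(self): return self
        def next(self): return Example(['x', 'y'], [[1.5], [7]])

    it = MinibatchToSingleExampleIterator(StubMinibatchIterator())
    print it.next()          # an Example with x=1.5, y=7: leading dimension removed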
195 195
196 def next_index(self): 196 def next_index(self):
197 return self.next_row 197 return self.next_row
198 198
199 def next(self): 199 def next(self):
200 if self.n_batches and self.n_batches_done==self.n_batches: 200 if self.n_batches and self.n_batches_done==self.n_batches:
201 raise StopIteration 201 raise StopIteration
202 upper = self.next_row+minibatch_size 202 upper = self.next_row+self.minibatch_size
203 if upper <=self.L: 203 if upper <=self.L:
204 minibatch = self.minibatch_iterator.next() 204 minibatch = self.iterator.next()
205 else: 205 else:
206 if not self.n_batches: 206 if not self.n_batches:
207 raise StopIteration 207 raise StopIteration
208 # we must concatenate (vstack) the bottom and top parts of our minibatch 208 # we must concatenate (vstack) the bottom and top parts of our minibatch
209 # first get the beginning of our minibatch (top of dataset) 209 # first get the beginning of our minibatch (top of dataset)
212 minibatch = Example(self.fieldnames, 212 minibatch = Example(self.fieldnames,
213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) 213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]])
214 for name in self.fieldnames]) 214 for name in self.fieldnames])
215 self.next_row=upper 215 self.next_row=upper
216 self.n_batches_done+=1 216 self.n_batches_done+=1
217 if upper >= L: 217 if upper >= self.L:
218 self.next_row -= L 218 self.next_row -= self.L
219 return minibatch 219 return minibatch
220 220
221 221
222 minibatches_fieldnames = None 222 minibatches_fieldnames = None
223 minibatches_minibatch_size = 1 223 minibatches_minibatch_size = 1
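The wrap-around branch above stitches the tail of the dataset to its head with valuesVStack when a minibatch crosses the end. The same index bookkeeping in isolation, as a plain numeric illustration independent of the class:

    L, minibatch_size, next_row = 10, 4, 8    # dataset length, batch size, cursor
    upper = next_row + minibatch_size         # 12 > L: this batch crosses the end
    rows = range(next_row, L) + range(0, upper - L)    # rows [8, 9] then [0, 1]
    next_row = upper - L                      # cursor wraps to 2 for the next batch
    print rows, next_row                      # [8, 9, 0, 1] 2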
273 273
274 Note: A list-like container is something like a tuple, list, numpy.ndarray or 274 Note: A list-like container is something like a tuple, list, numpy.ndarray or
275 any other object that supports integer indexing and slicing. 275 any other object that supports integer indexing and slicing.
276 276
277 """ 277 """
278 return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) 278 return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)
279 279
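A hedged sketch of calling minibatches, again leaning on the ArrayDataSet defined further down; the field names are made up and the default offset is assumed to be 0:

    import numpy
    from dataset import ArrayDataSet    # assumed module name

    ds = ArrayDataSet(numpy.arange(20.).reshape(10, 2), {'input': 0, 'target': 1})
    # two batches of five examples; each batch is a LookupList of field arrays
    for batch in ds.minibatches(['input', 'target'], minibatch_size=5, n_batches=2):
        print batch['input'], batch['target']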
280 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 280 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
281 """ 281 """
282 This is the minibatches iterator generator that sub-classes must define. 282 This is the minibatches iterator generator that sub-classes must define.
283 It does not need to worry about wrapping around multiple times across the dataset, 283 It does not need to worry about wrapping around multiple times across the dataset,
320 320
321 def __call__(self,*fieldnames): 321 def __call__(self,*fieldnames):
322 """ 322 """
323 Return a dataset that sees only the fields whose name are specified. 323 Return a dataset that sees only the fields whose name are specified.
324 """ 324 """
325 assert self.hasFields(fieldnames) 325 assert self.hasFields(*fieldnames)
326 return self.fields(fieldnames).examples() 326 return self.fields(*fieldnames).examples()
327 327
328 def fields(self,*fieldnames): 328 def fields(self,*fieldnames):
329 """ 329 """
330 Return a DataSetFields object associated with this dataset. 330 Return a DataSetFields object associated with this dataset.
331 """ 331 """
332 return DataSetFields(self,fieldnames) 332 return DataSetFields(self,*fieldnames)
333 333
334 def __getitem__(self,i): 334 def __getitem__(self,i):
335 """ 335 """
336 dataset[i] returns the (i+1)-th example of the dataset. 336 dataset[i] returns the (i+1)-th example of the dataset.
337 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. 337 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1.
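The two indexing forms just described, sketched against an ArrayDataSet; the field names are illustrative and the integer/slice paths rely on the elided generic __getitem__ code behaving as its docstring promises:

    import numpy
    from dataset import ArrayDataSet    # assumed module name

    ds = ArrayDataSet(numpy.arange(12.).reshape(6, 2), {'x': 0, 'y': 1})
    example = ds[3]                     # single Example: the fourth row
    print example['x'], example['y']    # 6.0 7.0
    subset = ds[1:4]                    # sub-dataset holding examples 1, 2 and 3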
369 # or a list of indices 369 # or a list of indices
370 elif type(i) is list: 370 elif type(i) is list:
371 rows = i 371 rows = i
372 if rows is not None: 372 if rows is not None:
373 fields_values = zip(*[self[row] for row in rows]) 373 fields_values = zip(*[self[row] for row in rows])
374 return MinibatchDataSet( 374 return DataSet.MinibatchDataSet(
375 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) 375 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values)
376 for fieldname,field_values 376 for fieldname,field_values
377 in zip(self.fieldNames(),fields_values)])) 377 in zip(self.fieldNames(),fields_values)]))
378 # else check for a fieldname 378 # else check for a fieldname
379 if self.hasFields(i): 379 if self.hasFields(i):
457 assert len(datasets)>0 457 assert len(datasets)>0
458 if len(datasets)==1: 458 if len(datasets)==1:
459 return datasets[0] 459 return datasets[0]
460 return VStackedDataSet(datasets) 460 return VStackedDataSet(datasets)
461 461
462 462 class FieldsSubsetDataSet(DataSet):
463 """
464 A sub-class of DataSet that selects a subset of the fields.
465 """
466 def __init__(self,src,fieldnames):
467 self.src=src
468 self.fieldnames=fieldnames
469 assert src.hasFields(*fieldnames)
470 self.valuesHStack = src.valuesHStack
471 self.valuesVStack = src.valuesVStack
472
473 def __len__(self): return len(self.src)
474
475 def fieldNames(self):
476 return self.fieldnames
477
478 def __iter__(self):
479 class Iterator(object):
480 def __init__(self,ds):
481 self.ds=ds
482 self.src_iter=ds.src.__iter__()
483 def __iter__(self): return self
484 def next(self):
485 example = self.src_iter.next()
486 return Example(self.ds.fieldnames,
487 [example[field] for field in self.ds.fieldnames])
488 return Iterator(self)
489
490 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
491 assert self.hasFields(*fieldnames)
492 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
493 def __getitem__(self,i):
494 return FieldsSubsetDataSet(self.src[i],self.fieldnames)
495
496
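A quick hedged sketch of the FieldsSubsetDataSet class introduced on the right-hand side; the base dataset and field names are made up:

    import numpy
    from dataset import ArrayDataSet, FieldsSubsetDataSet   # assumed module name

    base = ArrayDataSet(numpy.ones((4, 3)), {'a': 0, 'b': 1, 'c': 2})
    sub = FieldsSubsetDataSet(base, ['a', 'c'])    # same examples, field 'b' hidden
    print sub.fieldNames()                         # ['a', 'c']
    print len(sub)                                 # 4, delegated to the source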
463 class DataSetFields(LookupList): 497 class DataSetFields(LookupList):
464 """ 498 """
465 Although a DataSet iterates over examples (like rows of a matrix), an associated 499 Although a DataSet iterates over examples (like rows of a matrix), an associated
466 DataSetFields iterates over fields (like columns of a matrix), and can be understood 500 DataSetFields iterates over fields (like columns of a matrix), and can be understood
467 as a transpose of the associated dataset. 501 as a transpose of the associated dataset.
486 DataSetFields can be concatenated vertically or horizontally. To be consistent with 520 DataSetFields can be concatenated vertically or horizontally. To be consistent with
487 the syntax used for DataSets, the | concatenates the fields and the & concatenates 521 the syntax used for DataSets, the | concatenates the fields and the & concatenates
488 the examples. 522 the examples.
489 """ 523 """
490 def __init__(self,dataset,*fieldnames): 524 def __init__(self,dataset,*fieldnames):
491 self.dataset=dataset
492 if not fieldnames: 525 if not fieldnames:
493 fieldnames=dataset.fieldNames() 526 fieldnames=dataset.fieldNames()
527 elif fieldnames is not dataset.fieldNames():
528 dataset = FieldsSubsetDataSet(dataset,fieldnames)
494 assert dataset.hasFields(*fieldnames) 529 assert dataset.hasFields(*fieldnames)
495 LookupList.__init__(self,dataset.fieldNames(), 530 self.dataset=dataset
496 dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(), 531 minibatch_iterator = dataset.minibatches(fieldnames,
497 minibatch_size=len(dataset)).next()) 532 minibatch_size=len(dataset),
533 n_batches=1)
534 minibatch=minibatch_iterator.next()
535 LookupList.__init__(self,fieldnames,minibatch)
536
498 def examples(self): 537 def examples(self):
499 return self.dataset 538 return self.dataset
500 539
501 def __or__(self,other): 540 def __or__(self,other):
502 """ 541 """
811 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). 850 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2).
812 """ 851 """
813 852
814 """ 853 """
815 Construct an ArrayDataSet from the underlying numpy array (data) and 854 Construct an ArrayDataSet from the underlying numpy array (data) and
816 a map from fieldnames to field columns. The columns of a field are specified 855 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
817 using the standard arguments for indexing/slicing: integer for a column index, 856 using the standard arguments for indexing/slicing: integer for a column index,
818 slice for an interval of columns (with possible stride), or iterable of column indices. 857 slice for an interval of columns (with possible stride), or iterable of column indices.
819 """ 858 """
820 def __init__(self, data_array, fields_names_columns): 859 def __init__(self, data_array, fields_columns):
821 self.data=data_array 860 self.data=data_array
822 self.fields=fields_names_columns 861 self.fields_columns=fields_columns
823 862
824 # check consistency and complete slices definitions 863 # check consistency and complete slices definitions
825 for fieldname, fieldcolumns in self.fields.items(): 864 for fieldname, fieldcolumns in self.fields_columns.items():
826 if type(fieldcolumns) is int: 865 if type(fieldcolumns) is int:
827 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] 866 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1]
828 elif type(fieldcolumns) is slice: 867 elif type(fieldcolumns) is slice:
829 start,step=None,None 868 start,step=None,None
830 if not fieldcolumns.start: 869 if not fieldcolumns.start:
831 start=0 870 start=0
832 if not fieldcolumns.step: 871 if not fieldcolumns.step:
833 step=1 872 step=1
834 if start or step: 873 if start or step:
835 self.fields[fieldname]=slice(start,fieldcolumns.stop,step) 874 self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step)
836 elif hasattr(fieldcolumns,"__iter__"): # something like a list 875 elif hasattr(fieldcolumns,"__iter__"): # something like a list
837 for i in fieldcolumns: 876 for i in fieldcolumns:
838 assert i>=0 and i<data_array.shape[1] 877 assert i>=0 and i<data_array.shape[1]
839 878
840 def fieldNames(self): 879 def fieldNames(self):
841 return self.fields.keys() 880 return self.fields_columns.keys()
842 881
843 def __len__(self): 882 def __len__(self):
844 return len(self.data) 883 return len(self.data)
845 884
846 #def __getitem__(self,i): 885 #def __getitem__(self,i):
847 # """More efficient implementation than the default""" 886 # """More efficient implementation than the default"""
848 887
849 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 888 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
850 class Iterator(LookupList): # store the result in the lookup-list values 889 class Iterator(LookupList): # store the result in the lookup-list values
851 def __init__(dataset,fieldnames,minibatch_size,n_batches,offset): 890 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
852 if fieldnames is None: fieldnames = dataset.fieldNames() 891 if fieldnames is None: fieldnames = dataset.fieldNames()
853 LookupList.__init__(self,fieldnames,[0]*len(fieldnames)) 892 LookupList.__init__(self,fieldnames,[0]*len(fieldnames))
854 self.dataset=dataset 893 self.dataset=dataset
855 self.minibatch_size=minibatch_size 894 self.minibatch_size=minibatch_size
856 assert offset>=0 and offset<len(dataset.data) 895 assert offset>=0 and offset<len(dataset.data)
857 assert offset+minibatch_size<len(dataset.data) 896 assert offset+minibatch_size<=len(dataset.data)
858 self.current=offset 897 self.current=offset
859 def __iter__(self): 898 def __iter__(self):
860 return self 899 return self
861 def next(self): 900 def next(self):
862 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] 901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
863 self._values = [sub_data[:,self.dataset.fields[f]] for f in self._names] 902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names]
864 return self 903 return self
865 904
866 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) 905 return Iterator(self,fieldnames,minibatch_size,n_batches,offset)
867 906
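To tie the constructor's contract together, the three fields_columns forms named in the docstring (integer, slice, iterable of indices) in one hedged construction; the column layout is made up, and the full-size minibatch below depends on the <= fix to the offset assertion at line 896:

    import numpy
    from dataset import ArrayDataSet    # assumed module name

    data = numpy.random.rand(5, 6)
    ds = ArrayDataSet(data, {'input':  slice(0, 3),   # an interval of columns
                             'label':  3,             # a single column index
                             'extras': [4, 5]})       # an explicit column list
    batch = ds.minibatches(['input', 'label'], minibatch_size=5).next()
    print batch['input'].shape          # intended: (5, 3), columns 0-2 of every row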
868 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 907 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
869 """ 908 """
870 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the 909 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the
871 user to define a set of fields as the 'input' field and a set of fields 910 user to define a set of fields as the 'input' field and a set of fields