pylearn: comparison of dataset.py @ 42:9b68774fcc6b
Testing basic functionality and removing obvious bugs
author   | bengioy@grenat.iro.umontreal.ca |
date     | Fri, 25 Apr 2008 16:00:31 -0400 |
parents  | 283e95c15b47 |
children | e92244f30116 |
41:283e95c15b47 (before) | 42:9b68774fcc6b (after) |
---|---|
1 | 1 |
2 from lookup_list import LookupList | 2 from lookup_list import LookupList |
3 Example = LookupList | 3 Example = LookupList |
4 from misc import * | 4 from misc import unique_elements_list_intersection |
5 import copy | 5 from string import join |
6 import string | 6 from sys import maxint |
7 | 7 |
8 class AbstractFunction (Exception): """Derived class must override this function""" | 8 class AbstractFunction (Exception): """Derived class must override this function""" |
9 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" | 9 class NotImplementedYet (NotImplementedError): """Work in progress, this should eventually be implemented""" |
10 class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" | 10 #class UnboundedDataSet (Exception): """Trying to obtain length of unbounded dataset (a stream)""" |
11 | 11 |
12 class DataSet(object): | 12 class DataSet(object): |
13 """A virtual base class for datasets. | 13 """A virtual base class for datasets. |
14 | 14 |
15 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction | 15 A DataSet can be seen as a generalization of a matrix, meant to be used in conjunction |
122 """ | 122 """ |
123 | 123 |
124 def __init__(self,description=None,field_types=None): | 124 def __init__(self,description=None,field_types=None): |
125 if description is None: | 125 if description is None: |
126 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" | 126 # by default return "<DataSetType>(<SuperClass1>,<SuperClass2>,...)" |
127 description = type(self).__name__ + " ( " + string.join([x.__name__ for x in type(self).__bases__]) + " )" | 127 description = type(self).__name__ + " ( " + join([x.__name__ for x in type(self).__bases__]) + " )" |
128 self.description=description | 128 self.description=description |
129 self.field_types=field_types | 129 self.field_types=field_types |
130 | 130 |
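The default description built on line 127 now calls the directly imported join rather than string.join. A minimal sketch of what that expression yields, with a hypothetical stand-in class (Python 2, matching the file):

    from string import join

    class MyDataSet(object):      # hypothetical stand-in for a DataSet subclass
        pass

    # same expression as line 127 of the new revision
    description = MyDataSet.__name__ + " ( " + \
        join([x.__name__ for x in MyDataSet.__bases__]) + " )"
    # description == "MyDataSet ( object )"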
131 class MinibatchToSingleExampleIterator(object): | 131 class MinibatchToSingleExampleIterator(object): |
132 """ | 132 """ |
141 self.minibatch_iterator = minibatch_iterator | 141 self.minibatch_iterator = minibatch_iterator |
142 def __iter__(self): #makes for loop work | 142 def __iter__(self): #makes for loop work |
143 return self | 143 return self |
144 def next(self): | 144 def next(self): |
145 size1_minibatch = self.minibatch_iterator.next() | 145 size1_minibatch = self.minibatch_iterator.next() |
146 return Example(size1_minibatch.keys,[value[0] for value in size1_minibatch.values()]) | 146 return Example(size1_minibatch.keys(),[value[0] for value in size1_minibatch.values()]) |
147 | 147 |
148 def next_index(self): | 148 def next_index(self): |
149 return self.minibatch_iterator.next_index() | 149 return self.minibatch_iterator.next_index() |
150 | 150 |
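The corrected next() calls keys() instead of passing the bound method keys, then strips the leading batch axis from every field of a size-1 minibatch. A sketch of that unwrapping with a hand-built minibatch; the field names are hypothetical:

    from lookup_list import LookupList
    Example = LookupList

    # a size-1 minibatch: every field holds exactly one value
    mb = Example(['x', 'y'], [[[1.0, 2.0]], [0]])
    ex = Example(mb.keys(), [value[0] for value in mb.values()])
    # ex['x'] == [1.0, 2.0] and ex['y'] == 0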
151 def __iter__(self): | 151 def __iter__(self): |
195 | 195 |
196 def next_index(self): | 196 def next_index(self): |
197 return self.next_row | 197 return self.next_row |
198 | 198 |
199 def next(self): | 199 def next(self): |
200 if self.n_batches and self.n_batches_done==self.n_batches: | 200 if self.n_batches and self.n_batches_done==self.n_batches: |
201 raise StopIteration | 201 raise StopIteration |
202 upper = self.next_row+minibatch_size | 202 upper = self.next_row+self.minibatch_size |
203 if upper <=self.L: | 203 if upper <=self.L: |
204 minibatch = self.minibatch_iterator.next() | 204 minibatch = self.iterator.next() |
205 else: | 205 else: |
206 if not self.n_batches: | 206 if not self.n_batches: |
207 raise StopIteration | 207 raise StopIteration |
208 # we must concatenate (vstack) the bottom and top parts of our minibatch | 208 # we must concatenate (vstack) the bottom and top parts of our minibatch |
209 # first get the beginning of our minibatch (top of dataset) | 209 # first get the beginning of our minibatch (top of dataset) |
212 minibatch = Example(self.fieldnames, | 212 minibatch = Example(self.fieldnames, |
213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) | 213 [self.dataset.valuesVStack(name,[first_part[name],second_part[name]]) |
214 for name in self.fieldnames]) | 214 for name in self.fieldnames]) |
215 self.next_row=upper | 215 self.next_row=upper |
216 self.n_batches_done+=1 | 216 self.n_batches_done+=1 |
217 if upper >= L: | 217 if upper >= self.L: |
218 self.next_row -= L | 218 self.next_row -= self.L |
219 return minibatch | 219 return minibatch |
220 | 220 |
221 | 221 |
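The fixes above qualify minibatch_size and L with self. so the wrap-around branch can actually run. The index bookkeeping it implements, shown standalone with hypothetical values:

    L, minibatch_size, next_row = 10, 4, 8    # dataset length, batch size, cursor
    upper = next_row + minibatch_size         # 12: past the end, must wrap
    rows = range(next_row, L) + range(0, upper - L)   # [8, 9, 0, 1]
    next_row = upper - L                      # 2: cursor position after wrapping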
222 minibatches_fieldnames = None | 222 minibatches_fieldnames = None |
223 minibatches_minibatch_size = 1 | 223 minibatches_minibatch_size = 1 |
273 | 273 |
274 Note: A list-like container is something like a tuple, list, numpy.ndarray or | 274 Note: A list-like container is something like a tuple, list, numpy.ndarray or |
275 any other object that supports integer indexing and slicing. | 275 any other object that supports integer indexing and slicing. |
276 | 276 |
277 """ | 277 """ |
278 return MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) | 278 return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset) |
279 | 279 |
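With the iterator class now referenced as DataSet.MinibatchWrapAroundIterator, callers keep going through minibatches(). A hypothetical call, assuming fields named 'input' and 'target':

    for minibatch in ds.minibatches(['input', 'target'],
                                    minibatch_size=10, n_batches=5):
        inputs, targets = minibatch['input'], minibatch['target']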
280 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 280 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
281 """ | 281 """ |
282 This is the minibatches iterator generator that sub-classes must define. | 282 This is the minibatches iterator generator that sub-classes must define. |
283 It does not need to worry about wrapping around multiple times across the dataset, | 283 It does not need to worry about wrapping around multiple times across the dataset, |
320 | 320 |
321 def __call__(self,*fieldnames): | 321 def __call__(self,*fieldnames): |
322 """ | 322 """ |
323 Return a dataset that sees only the fields whose name are specified. | 323 Return a dataset that sees only the fields whose name are specified. |
324 """ | 324 """ |
325 assert self.hasFields(fieldnames) | 325 assert self.hasFields(*fieldnames) |
326 return self.fields(fieldnames).examples() | 326 return self.fields(*fieldnames).examples() |
327 | 327 |
328 def fields(self,*fieldnames): | 328 def fields(self,*fieldnames): |
329 """ | 329 """ |
330 Return a DataSetFields object associated with this dataset. | 330 Return a DataSetFields object associated with this dataset. |
331 """ | 331 """ |
332 return DataSetFields(self,fieldnames) | 332 return DataSetFields(self,*fieldnames) |
333 | 333 |
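Both call sites now unpack fieldnames with *, matching the hasFields(*fieldnames) and DataSetFields(dataset,*fieldnames) signatures. Hypothetical usage, assuming a dataset ds with fields 'x' and 'y':

    columns = ds.fields('x', 'y')   # DataSetFields: a column-wise view
    xs_only = ds('x')               # DataSet restricted to field 'x'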
334 def __getitem__(self,i): | 334 def __getitem__(self,i): |
335 """ | 335 """ |
336 dataset[i] returns the (i+1)-th example of the dataset. | 336 dataset[i] returns the (i+1)-th example of the dataset. |
337 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. | 337 dataset[i:j] returns the subdataset with examples i,i+1,...,j-1. |
369 # or a list of indices | 369 # or a list of indices |
370 elif type(i) is list: | 370 elif type(i) is list: |
371 rows = i | 371 rows = i |
372 if rows is not None: | 372 if rows is not None: |
373 fields_values = zip(*[self[row] for row in rows]) | 373 fields_values = zip(*[self[row] for row in rows]) |
374 return MinibatchDataSet( | 374 return DataSet.MinibatchDataSet( |
375 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) | 375 Example(self.fieldNames(),[ self.valuesVStack(fieldname,field_values) |
376 for fieldname,field_values | 376 for fieldname,field_values |
377 in zip(self.fieldNames(),fields_values)])) | 377 in zip(self.fieldNames(),fields_values)])) |
378 # else check for a fieldname | 378 # else check for a fieldname |
379 if self.hasFields(i): | 379 if self.hasFields(i): |
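With the nested class referenced as DataSet.MinibatchDataSet, the list-of-indices branch re-stacks the selected rows through valuesVStack. The indexing forms this __getitem__ accepts, sketched with hypothetical names:

    example = ds[3]           # a single Example (the 4th)
    subset  = ds[10:20]       # sub-dataset of examples 10..19
    picked  = ds[[2, 5, 7]]   # rows 2, 5 and 7, re-stacked via valuesVStack
    field   = ds['x']         # falls through to the fieldname branch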
457 assert len(datasets)>0 | 457 assert len(datasets)>0 |
458 if len(datasets)==1: | 458 if len(datasets)==1: |
459 return datasets[0] | 459 return datasets[0] |
460 return VStackedDataSet(datasets) | 460 return VStackedDataSet(datasets) |
461 | 461 |
462 | 462 class FieldsSubsetDataSet(DataSet): |
463 """ | |
464 A sub-class of DataSet that selects a subset of the fields. | |
465 """ | |
466 def __init__(self,src,fieldnames): | |
467 self.src=src | |
468 self.fieldnames=fieldnames | |
469 assert src.hasFields(*fieldnames) | |
470 self.valuesHStack = src.valuesHStack | |
471 self.valuesVStack = src.valuesVStack | |
472 | |
473 def __len__(self): return len(self.src) | |
474 | |
475 def fieldNames(self): | |
476 return self.fieldnames | |
477 | |
478 def __iter__(self): | |
479 class Iterator(object): | |
480 def __init__(self,ds): | |
481 self.ds=ds | |
482 self.src_iter=ds.src.__iter__() | |
483 def __iter__(self): return self | |
484 def next(self): | |
485 example = self.src_iter.next() | |
486 return Example(self.ds.fieldnames, | |
487 [example[field] for field in self.ds.fieldnames]) | |
488 return Iterator(self) | |
489 | |
490 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | |
491 assert self.hasFields(*fieldnames) | |
492 return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset) | |
493 def __getitem__(self,i): | |
494 return FieldsSubsetDataSet(self.src[i],self.fieldnames) | |
495 | |
496 | |
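The new FieldsSubsetDataSet delegates length, value stacking and minibatch iteration to its source while narrowing every example to the requested fields. A hypothetical usage sketch:

    xy = FieldsSubsetDataSet(ds, ['x', 'y'])   # view of two fields of ds
    assert len(xy) == len(ds)
    for example in xy:          # each Example carries only 'x' and 'y'
        x, y = example['x'], example['y']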
463 class DataSetFields(LookupList): | 497 class DataSetFields(LookupList): |
464 """ | 498 """ |
465 Although a DataSet iterates over examples (like rows of a matrix), an associated | 499 Although a DataSet iterates over examples (like rows of a matrix), an associated |
466 DataSetFields iterates over fields (like columns of a matrix), and can be understood | 500 DataSetFields iterates over fields (like columns of a matrix), and can be understood |
467 as a transpose of the associated dataset. | 501 as a transpose of the associated dataset. |
486 DataSetFields can be concatenated vertically or horizontally. To be consistent with | 520 DataSetFields can be concatenated vertically or horizontally. To be consistent with |
487 the syntax used for DataSets, the | concatenates the fields and the & concatenates | 521 the syntax used for DataSets, the | concatenates the fields and the & concatenates |
488 the examples. | 522 the examples. |
489 """ | 523 """ |
490 def __init__(self,dataset,*fieldnames): | 524 def __init__(self,dataset,*fieldnames): |
491 self.dataset=dataset | |
492 if not fieldnames: | 525 if not fieldnames: |
493 fieldnames=dataset.fieldNames() | 526 fieldnames=dataset.fieldNames() |
| 527 elif fieldnames is not dataset.fieldNames(): |
| 528 dataset = FieldsSubsetDataSet(dataset,fieldnames) |
494 assert dataset.hasFields(*fieldnames) | 529 assert dataset.hasFields(*fieldnames) |
495 LookupList.__init__(self,dataset.fieldNames(), | 530 self.dataset=dataset |
496 dataset.minibatches(fieldnames if len(fieldnames)>0 else self.fieldNames(), | 531 minibatch_iterator = dataset.minibatches(fieldnames, |
497 minibatch_size=len(dataset)).next()) | 532 minibatch_size=len(dataset), |
| 533 n_batches=1) |
| 534 minibatch=minibatch_iterator.next() |
| 535 LookupList.__init__(self,fieldnames,minibatch) |
| 536 |
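The rewritten constructor wraps the dataset in a FieldsSubsetDataSet when a subset of fields is requested, then materializes all fields by drawing a single minibatch spanning the whole dataset. The equivalent standalone steps, with a hypothetical ds:

    it = ds.minibatches(['x', 'y'], minibatch_size=len(ds), n_batches=1)
    whole = it.next()   # one minibatch covering every example
    # whole['x'] is the entire 'x' column; DataSetFields stores these values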
498 def examples(self): | 537 def examples(self): |
499 return self.dataset | 538 return self.dataset |
500 | 539 |
501 def __or__(self,other): | 540 def __or__(self,other): |
502 """ | 541 """ |
811 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). | 850 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). |
812 """ | 851 """ |
813 | 852 |
814 """ | 853 """ |
815 Construct an ArrayDataSet from the underlying numpy array (data) and | 854 Construct an ArrayDataSet from the underlying numpy array (data) and |
816 a map from fieldnames to field columns. The columns of a field are specified | 855 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified |
817 using the standard arguments for indexing/slicing: integer for a column index, | 856 using the standard arguments for indexing/slicing: integer for a column index, |
818 slice for an interval of columns (with possible stride), or iterable of column indices. | 857 slice for an interval of columns (with possible stride), or iterable of column indices. |
819 """ | 858 """ |
820 def __init__(self, data_array, fields_names_columns): | 859 def __init__(self, data_array, fields_columns): |
821 self.data=data_array | 860 self.data=data_array |
822 self.fields=fields_names_columns | 861 self.fields_columns=fields_columns |
823 | 862 |
824 # check consistency and complete slices definitions | 863 # check consistency and complete slices definitions |
825 for fieldname, fieldcolumns in self.fields.items(): | 864 for fieldname, fieldcolumns in self.fields_columns.items(): |
826 if type(fieldcolumns) is int: | 865 if type(fieldcolumns) is int: |
827 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] | 866 assert fieldcolumns>=0 and fieldcolumns<data_array.shape[1] |
828 elif type(fieldcolumns) is slice: | 867 elif type(fieldcolumns) is slice: |
829 start,step=None,None | 868 start,step=None,None |
830 if not fieldcolumns.start: | 869 if not fieldcolumns.start: |
831 start=0 | 870 start=0 |
832 if not fieldcolumns.step: | 871 if not fieldcolumns.step: |
833 step=1 | 872 step=1 |
834 if start or step: | 873 if start or step: |
835 self.fields[fieldname]=slice(start,fieldcolumns.stop,step) | 874 self.fields_columns[fieldname]=slice(start,fieldcolumns.stop,step) |
836 elif hasattr(fieldcolumns,"__iter__"): # something like a list | 875 elif hasattr(fieldcolumns,"__iter__"): # something like a list |
837 for i in fieldcolumns: | 876 for i in fieldcolumns: |
838 assert i>=0 and i<data_array.shape[1] | 877 assert i>=0 and i<data_array.shape[1] |
839 | 878 |
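fields_columns (renamed from fields) maps each field name to its columns as an int, a slice, or an iterable of indices; the loop above normalizes open-ended slices. A hypothetical construction exercising all three specs:

    import numpy

    data = numpy.random.rand(100, 4)            # 100 examples, 4 columns
    ds = ArrayDataSet(data, {'x': slice(0, 2),  # columns 0 and 1
                             'y': 2,            # a single column
                             'w': [3]})         # explicit list of indices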
840 def fieldNames(self): | 879 def fieldNames(self): |
841 return self.fields.keys() | 880 return self.fields_columns.keys() |
842 | 881 |
843 def __len__(self): | 882 def __len__(self): |
844 return len(self.data) | 883 return len(self.data) |
845 | 884 |
846 #def __getitem__(self,i): | 885 #def __getitem__(self,i): |
847 # """More efficient implementation than the default""" | 886 # """More efficient implementation than the default""" |
848 | 887 |
849 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 888 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
850 class Iterator(LookupList): # store the result in the lookup-list values | 889 class Iterator(LookupList): # store the result in the lookup-list values |
851 def __init__(dataset,fieldnames,minibatch_size,n_batches,offset): | 890 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): |
852 if fieldnames is None: fieldnames = dataset.fieldNames() | 891 if fieldnames is None: fieldnames = dataset.fieldNames() |
853 LookupList.__init__(self,fieldnames,[0]*len(fieldnames)) | 892 LookupList.__init__(self,fieldnames,[0]*len(fieldnames)) |
854 self.dataset=dataset | 893 self.dataset=dataset |
855 self.minibatch_size=minibatch_size | 894 self.minibatch_size=minibatch_size |
856 assert offset>=0 and offset<len(dataset.data) | 895 assert offset>=0 and offset<len(dataset.data) |
857 assert offset+minibatch_size<len(dataset.data) | 896 assert offset+minibatch_size<=len(dataset.data) |
858 self.current=offset | 897 self.current=offset |
859 def __iter__(self): | 898 def __iter__(self): |
860 return self | 899 return self |
861 def next(self): | 900 def next(self): |
862 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] | 901 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] |
863 self._values = [sub_data[:,self.dataset.fields[f]] for f in self._names] | 902 self._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self._names] |
864 return self | 903 return self |
865 | 904 |
866 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) | 905 return Iterator(self,fieldnames,minibatch_size,n_batches,offset) |
867 | 906 |
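The iterator now takes self as its first argument and slices fields_columns per field, so each minibatch field comes out as a numpy sub-array. Iterating the hypothetical ds built above:

    for mb in ds.minibatches(['x', 'y'], minibatch_size=20, n_batches=5):
        x_batch = mb['x']   # shape (20, 2): columns 0 and 1
        y_batch = mb['y']   # shape (20,): column 2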
868 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 907 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
869 """ | 908 """ |
870 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the | 909 Wraps an arbitrary DataSet into one for supervised learning tasks by forcing the |
871 user to define a set of fields as the 'input' field and a set of fields | 910 user to define a set of fields as the 'input' field and a set of fields |