Mercurial > pylearn
comparison dataset.py @ 55:66619ce44497
Efficient implementation of getitem for ArrayDataSet
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Tue, 29 Apr 2008 15:05:12 -0400 |
parents | b6730f9a336d |
children | 1729ad44f175 |
comparison
equal
deleted
inserted
replaced
49:718befdc8671 | 55:66619ce44497 |
---|---|
601 if type(i) in (int,slice,list): | 601 if type(i) in (int,slice,list): |
602 return DataSetFields(MinibatchDataSet( | 602 return DataSetFields(MinibatchDataSet( |
603 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) | 603 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) |
604 if self.hasFields(i): | 604 if self.hasFields(i): |
605 return self.fields[i] | 605 return self.fields[i] |
606 assert i in self.__dict__ # else it means we are trying to access a non-existing property | |
606 return self.__dict__[i] | 607 return self.__dict__[i] |
607 | 608 |
608 def fieldNames(self): | 609 def fieldNames(self): |
609 return self.fields.keys() | 610 return self.fields.keys() |
610 | 611 |
872 whose first axis iterates over examples, second axis determines fields. | 873 whose first axis iterates over examples, second axis determines fields. |
873 If the underlying array is N-dimensional (has N axes), then the field | 874 If the underlying array is N-dimensional (has N axes), then the field |
874 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). | 875 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). |
875 """ | 876 """ |
876 | 877 |
877 """ | |
878 Construct an ArrayDataSet from the underlying numpy array (data) and | |
879 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified | |
880 using the standard arguments for indexing/slicing: integer for a column index, | |
881 slice for an interval of columns (with possible stride), or iterable of column indices. | |
882 """ | |
883 def __init__(self, data_array, fields_columns): | 878 def __init__(self, data_array, fields_columns): |
879 """ | |
880 Construct an ArrayDataSet from the underlying numpy array (data) and | |
881 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified | |
882 using the standard arguments for indexing/slicing: integer for a column index, | |
883 slice for an interval of columns (with possible stride), or iterable of column indices. | |
884 """ | |
884 self.data=data_array | 885 self.data=data_array |
885 self.fields_columns=fields_columns | 886 self.fields_columns=fields_columns |
886 | 887 |
887 # check consistency and complete slices definitions | 888 # check consistency and complete slices definitions |
888 for fieldname, fieldcolumns in self.fields_columns.items(): | 889 for fieldname, fieldcolumns in self.fields_columns.items(): |
904 return self.fields_columns.keys() | 905 return self.fields_columns.keys() |
905 | 906 |
906 def __len__(self): | 907 def __len__(self): |
907 return len(self.data) | 908 return len(self.data) |
908 | 909 |
909 #def __getitem__(self,i): | 910 def __getitem__(self,i): |
910 # """More efficient implementation than the default""" | 911 """More efficient implementation than the default __getitem__""" |
912 fieldnames=self.fields_columns.keys() | |
913 if type(i) is int: | |
914 return Example(fieldnames, | |
915 [self.data[i,self.fields_columns[f]] for f in fieldnames]) | |
916 if type(i) in (slice,list): | |
917 return MinibatchDataSet(Example(fieldnames, | |
918 [self.data[i,self.fields_columns[f]] for f in fieldnames])) | |
919 # else check for a fieldname | |
920 if self.hasFields(i): | |
921 return Example([i],[self.data[self.fields_columns[i],:]]) | |
922 # else we are trying to access a property of the dataset | |
923 assert i in self.__dict__ # else it means we are trying to access a non-existing property | |
924 return self.__dict__[i] | |
925 | |
911 | 926 |
912 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 927 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
913 class ArrayDataSetIterator(object): | 928 class ArrayDataSetIterator(object): |
914 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): | 929 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): |
915 if fieldnames is None: fieldnames = dataset.fieldNames() | 930 if fieldnames is None: fieldnames = dataset.fieldNames() |