comparison dataset.py @ 55:66619ce44497

Efficient implementation of getitem for ArrayDataSet
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Tue, 29 Apr 2008 15:05:12 -0400
parents b6730f9a336d
children 1729ad44f175
comparison
equal deleted inserted replaced
49:718befdc8671 55:66619ce44497
601 if type(i) in (int,slice,list): 601 if type(i) in (int,slice,list):
602 return DataSetFields(MinibatchDataSet( 602 return DataSetFields(MinibatchDataSet(
603 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) 603 Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields)
604 if self.hasFields(i): 604 if self.hasFields(i):
605 return self.fields[i] 605 return self.fields[i]
606 assert i in self.__dict__ # else it means we are trying to access a non-existing property
606 return self.__dict__[i] 607 return self.__dict__[i]
607 608
608 def fieldNames(self): 609 def fieldNames(self):
609 return self.fields.keys() 610 return self.fields.keys()
610 611
872 whose first axis iterates over examples, second axis determines fields. 873 whose first axis iterates over examples, second axis determines fields.
873 If the underlying array is N-dimensional (has N axes), then the field 874 If the underlying array is N-dimensional (has N axes), then the field
874 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). 875 values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2).
875 """ 876 """
876 877
877 """
878 Construct an ArrayDataSet from the underlying numpy array (data) and
879 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
880 using the standard arguments for indexing/slicing: integer for a column index,
881 slice for an interval of columns (with possible stride), or iterable of column indices.
882 """
883 def __init__(self, data_array, fields_columns): 878 def __init__(self, data_array, fields_columns):
879 """
880 Construct an ArrayDataSet from the underlying numpy array (data) and
881 a map (fields_columns) from fieldnames to field columns. The columns of a field are specified
882 using the standard arguments for indexing/slicing: integer for a column index,
883 slice for an interval of columns (with possible stride), or iterable of column indices.
884 """
884 self.data=data_array 885 self.data=data_array
885 self.fields_columns=fields_columns 886 self.fields_columns=fields_columns
886 887
887 # check consistency and complete slices definitions 888 # check consistency and complete slices definitions
888 for fieldname, fieldcolumns in self.fields_columns.items(): 889 for fieldname, fieldcolumns in self.fields_columns.items():
904 return self.fields_columns.keys() 905 return self.fields_columns.keys()
905 906
906 def __len__(self): 907 def __len__(self):
907 return len(self.data) 908 return len(self.data)
908 909
909 #def __getitem__(self,i): 910 def __getitem__(self,i):
910 # """More efficient implementation than the default""" 911 """More efficient implementation than the default __getitem__"""
912 fieldnames=self.fields_columns.keys()
913 if type(i) is int:
914 return Example(fieldnames,
915 [self.data[i,self.fields_columns[f]] for f in fieldnames])
916 if type(i) in (slice,list):
917 return MinibatchDataSet(Example(fieldnames,
918 [self.data[i,self.fields_columns[f]] for f in fieldnames]))
919 # else check for a fieldname
920 if self.hasFields(i):
921 return Example([i],[self.data[self.fields_columns[i],:]])
922 # else we are trying to access a property of the dataset
923 assert i in self.__dict__ # else it means we are trying to access a non-existing property
924 return self.__dict__[i]
925
911 926
912 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 927 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
913 class ArrayDataSetIterator(object): 928 class ArrayDataSetIterator(object):
914 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): 929 def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
915 if fieldnames is None: fieldnames = dataset.fieldNames() 930 if fieldnames is None: fieldnames = dataset.fieldNames()