comparison dataset.py @ 243:c8f19a9eb10f

Optimisation in ArrayDataSet::__getitem__
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Mon, 02 Jun 2008 11:59:41 -0400
parents ef70a665aaaf
children 7e6edee187e3 4ad6bc9b4f03
comparison
equal deleted inserted replaced
242:ef70a665aaaf 243:c8f19a9eb10f
1014 return len(self.data) 1014 return len(self.data)
1015 1015
1016 def __getitem__(self,key): 1016 def __getitem__(self,key):
1017 """More efficient implementation than the default __getitem__""" 1017 """More efficient implementation than the default __getitem__"""
1018 fieldnames=self.fields_columns.keys() 1018 fieldnames=self.fields_columns.keys()
1019 values=self.fields_columns.values()
1019 if type(key) is int: 1020 if type(key) is int:
1020 return Example(fieldnames, 1021 return Example(fieldnames,
1021 [self.data[key,self.fields_columns[f]] for f in fieldnames]) 1022 [self.data[key,col] for col in values])
1022 if type(key) is slice: 1023 if type(key) is slice:
1023 return MinibatchDataSet(Example(fieldnames, 1024 return MinibatchDataSet(Example(fieldnames,
1024 [self.data[key,self.fields_columns[f]] for f in fieldnames])) 1025 [self.data[key,col] for col in values]))
1025 if type(key) is list: 1026 if type(key) is list:
1026 for i in range(len(key)): 1027 for i in range(len(key)):
1027 if self.hasFields(key[i]): 1028 if self.hasFields(key[i]):
1028 key[i]=self.fields_columns[key[i]] 1029 key[i]=self.fields_columns[key[i]]
1029 return MinibatchDataSet(Example(fieldnames, 1030 return MinibatchDataSet(Example(fieldnames,
1030 #we must separate differently for list as numpy 1031 #we must separate differently for list as numpy
1031 # doesn't support self.data[[i1,...],[i2,...]] 1032 # doesn't support self.data[[i1,...],[i2,...]]
1032 # when their is more then two i1 and i2 1033 # when their is more then two i1 and i2
1033 [self.data[key,:][:,self.fields_columns[f]] 1034 [self.data[key,:][:,col]
1034 if isinstance(self.fields_columns[f],list) else 1035 if isinstance(col,list) else
1035 self.data[key,self.fields_columns[f]] for f in fieldnames]), 1036 self.data[key,col] for col in values]),
1037
1036 1038
1037 self.valuesVStack,self.valuesHStack) 1039 self.valuesVStack,self.valuesHStack)
1038 1040
1039 # else check for a fieldname 1041 # else check for a fieldname
1040 if self.hasFields(key): 1042 if self.hasFields(key):