Mercurial > pylearn
changeset 55:66619ce44497
Efficient implementation of getitem for ArrayDataSet
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Tue, 29 Apr 2008 15:05:12 -0400 |
parents | 718befdc8671 |
children | 1729ad44f175 |
files | dataset.py |
diffstat | 1 files changed, 23 insertions(+), 8 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Tue Apr 29 14:34:40 2008 -0400 +++ b/dataset.py Tue Apr 29 15:05:12 2008 -0400 @@ -603,6 +603,7 @@ Example(self.fields.keys(),[field[i] for field in self.fields])),self.fields) if self.hasFields(i): return self.fields[i] + assert i in self.__dict__ # else it means we are trying to access a non-existing property return self.__dict__[i] def fieldNames(self): @@ -874,13 +875,13 @@ values are (N-2)-dimensional objects (i.e. ordinary numbers if N=2). """ - """ - Construct an ArrayDataSet from the underlying numpy array (data) and - a map (fields_columns) from fieldnames to field columns. The columns of a field are specified - using the standard arguments for indexing/slicing: integer for a column index, - slice for an interval of columns (with possible stride), or iterable of column indices. - """ def __init__(self, data_array, fields_columns): + """ + Construct an ArrayDataSet from the underlying numpy array (data) and + a map (fields_columns) from fieldnames to field columns. The columns of a field are specified + using the standard arguments for indexing/slicing: integer for a column index, + slice for an interval of columns (with possible stride), or iterable of column indices. + """ self.data=data_array self.fields_columns=fields_columns @@ -906,8 +907,22 @@ def __len__(self): return len(self.data) - #def __getitem__(self,i): - # """More efficient implementation than the default""" + def __getitem__(self,i): + """More efficient implementation than the default __getitem__""" + fieldnames=self.fields_columns.keys() + if type(i) is int: + return Example(fieldnames, + [self.data[i,self.fields_columns[f]] for f in fieldnames]) + if type(i) in (slice,list): + return MinibatchDataSet(Example(fieldnames, + [self.data[i,self.fields_columns[f]] for f in fieldnames])) + # else check for a fieldname + if self.hasFields(i): + return Example([i],[self.data[self.fields_columns[i],:]]) + # else we are trying to access a property of the dataset + assert i in self.__dict__ # else it means we are trying to access a non-existing property + return self.__dict__[i] + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): class ArrayDataSetIterator(object):