Mercurial > pylearn
diff dataset.py @ 82:158653a9bc7c
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 11:02:03 -0400 |
parents | 3499918faa9d 40476a7746e8 |
children | c0f211213a58 |
line wrap: on
line diff
--- a/dataset.py Mon May 05 09:35:30 2008 -0400 +++ b/dataset.py Mon May 05 11:02:03 2008 -0400 @@ -281,7 +281,7 @@ The minibatches iterator is expected to return upon each call to next() a DataSetFields object, which is a LookupList (indexed by the field names) whose - elements are iterable over the minibatch examples, and which keeps a pointer to + elements are iterable and indexable over the minibatch examples, and which keeps a pointer to a sub-dataset that can be used to iterate over the individual examples in the minibatch. Hence a minibatch can be converted back to a regular dataset or its fields can be looked at individually (and possibly iterated over). @@ -632,12 +632,13 @@ return self.length def __getitem__(self,i): + if type(i) in (slice,list): + return DataSetFields(MinibatchDataSet( + Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames()) if type(i) is int: - return Example(self._fields.keys(),[field[i] for field in self._fields]) - if type(i) in (slice,list): - return MinibatchDataSet(Example(self._fields.keys(), - [field[i] for field in self._fields]), - self.valuesVStack,self.valuesHStack) + return DataSetFields(MinibatchDataSet( + Example(self._fields.keys(),[[field[i]] for field in self._fields])),self.fieldNames()) + if self.hasFields(i): return self._fields[i] assert i in self.__dict__ # else it means we are trying to access a non-existing property @@ -939,22 +940,29 @@ def __len__(self): return len(self.data) - def __getitem__(self,i): + def __getitem__(self,key): """More efficient implementation than the default __getitem__""" fieldnames=self.fields_columns.keys() - if type(i) is int: + if type(key) is int: return Example(fieldnames, - [self.data[i,self.fields_columns[f]] for f in fieldnames]) - if type(i) in (slice,list): + [self.data[key,self.fields_columns[f]] for f in fieldnames]) + if type(key) is slice: return MinibatchDataSet(Example(fieldnames, - [self.data[i,self.fields_columns[f]] for f in fieldnames]), + [self.data[key,self.fields_columns[f]] for f in fieldnames])) + if type(key) is list: + for i in range(len(key)): + if self.hasFields(key[i]): + key[i]=self.fields_columns[key[i]] + return MinibatchDataSet(Example(fieldnames, + [self.data[key,self.fields_columns[f]] for f in fieldnames]), self.valuesVStack,self.valuesHStack) + # else check for a fieldname - if self.hasFields(i): - return Example([i],[self.data[self.fields_columns[i],:]]) + if self.hasFields(key): + return self.data[self.fields_columns[key],:] # else we are trying to access a property of the dataset - assert i in self.__dict__ # else it means we are trying to access a non-existing property - return self.__dict__[i] + assert key in self.__dict__ # else it means we are trying to access a non-existing property + return self.__dict__[key] def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):