Mercurial > pylearn
diff dataset.py @ 80:40476a7746e8
bugfix
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 10:56:58 -0400 |
parents | dde1fb1b63ba |
children | 158653a9bc7c |
line wrap: on
line diff
--- a/dataset.py Mon May 05 10:28:58 2008 -0400 +++ b/dataset.py Mon May 05 10:56:58 2008 -0400 @@ -259,7 +259,7 @@ The minibatches iterator is expected to return upon each call to next() a DataSetFields object, which is a LookupList (indexed by the field names) whose - elements are iterable over the minibatch examples, and which keeps a pointer to + elements are iterable and indexable over the minibatch examples, and which keeps a pointer to a sub-dataset that can be used to iterate over the individual examples in the minibatch. Hence a minibatch can be converted back to a regular dataset or its fields can be looked at individually (and possibly iterated over). @@ -609,9 +609,13 @@ return self.length def __getitem__(self,i): - if type(i) in (int,slice,list): + if type(i) in (slice,list): return DataSetFields(MinibatchDataSet( - Example(self._fields.keys(),[field[i] for field in self._fields])),self._fields) + Example(self._fields.keys(),[field[i] for field in self._fields])),self.fieldNames()) + if type(i) is int: + return DataSetFields(MinibatchDataSet( + Example(self._fields.keys(),[[field[i]] for field in self._fields])),self.fieldNames()) + if self.hasFields(i): return self._fields[i] assert i in self.__dict__ # else it means we are trying to access a non-existing property @@ -918,21 +922,28 @@ def __len__(self): return len(self.data) - def __getitem__(self,i): + def __getitem__(self,key): """More efficient implementation than the default __getitem__""" fieldnames=self.fields_columns.keys() - if type(i) is int: + if type(key) is int: return Example(fieldnames, - [self.data[i,self.fields_columns[f]] for f in fieldnames]) - if type(i) in (slice,list): + [self.data[key,self.fields_columns[f]] for f in fieldnames]) + if type(key) is slice: return MinibatchDataSet(Example(fieldnames, - [self.data[i,self.fields_columns[f]] for f in fieldnames])) + [self.data[key,self.fields_columns[f]] for f in fieldnames])) + if type(key) is list: + for i in range(len(key)): + if self.hasFields(key[i]): + key[i]=self.fields_columns[key[i]] + return MinibatchDataSet(Example(fieldnames, + [self.data[key,self.fields_columns[f]] for f in fieldnames])) + # else check for a fieldname - if self.hasFields(i): - return Example([i],[self.data[self.fields_columns[i],:]]) + if self.hasFields(key): + return self.data[self.fields_columns[key],:] # else we are trying to access a property of the dataset - assert i in self.__dict__ # else it means we are trying to access a non-existing property - return self.__dict__[i] + assert key in self.__dict__ # else it means we are trying to access a non-existing property + return self.__dict__[key] def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):