Mercurial > pylearn
changeset 228:6f55e301c687
optimisation of ArrayDataSet
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Fri, 16 May 2008 16:38:07 -0400 |
parents | 80731832c62b |
children | d7250ee86f72 |
files | dataset.py |
diffstat | 1 files changed, 26 insertions(+), 1 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Thu May 15 15:21:00 2008 -0400 +++ b/dataset.py Fri May 16 16:38:07 2008 -0400 @@ -1015,7 +1015,31 @@ assert key in self.__dict__ # else it means we are trying to access a non-existing property return self.__dict__[key] - + def __iter__(self): + class ArrayDataSetIterator2(object): + def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): + if fieldnames is None: fieldnames = dataset.fieldNames() + # store the resulting minibatch in a lookup-list of values + self.minibatch = LookupList(fieldnames,[0]*len(fieldnames)) + self.dataset=dataset + self.minibatch_size=minibatch_size + assert offset>=0 and offset<len(dataset.data) + assert offset+minibatch_size<=len(dataset.data) + self.current=offset + def __iter__(self): + return self + def next(self): + #@todo: we suppose that we need to stop only when minibatch_size == 1. + # Otherwise, MinibatchWrapAroundIterator do it. + if self.current>=self.dataset.data.shape[0]: + raise StopIteration + sub_data = self.dataset.data[self.current] + self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names] + self.current+=self.minibatch_size + return self.minibatch + + return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0) + def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): class ArrayDataSetIterator(object): def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset): @@ -1030,6 +1054,7 @@ def __iter__(self): return self def next(self): + #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator sub_data = self.dataset.data[self.current:self.current+self.minibatch_size] self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names] self.current+=self.minibatch_size