diff dataset.py @ 231:38beb81f4e8b
Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
| author   | Frederic Bastien <bastienf@iro.umontreal.ca> |
| ---      | ---                                          |
| date     | Tue, 27 May 2008 13:46:03 -0400              |
| parents  | 17c5d080964b 6f55e301c687                    |
| children | a70f2c973ea5 ddb88a8e9fd2                    |
```diff
--- a/dataset.py	Tue May 27 13:23:05 2008 -0400
+++ b/dataset.py	Tue May 27 13:46:03 2008 -0400
@@ -1043,7 +1043,31 @@
 
         assert key in self.__dict__ # else it means we are trying to access a non-existing property
         return self.__dict__[key]
-
+    def __iter__(self):
+        class ArrayDataSetIterator2(object):
+            def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
+                if fieldnames is None: fieldnames = dataset.fieldNames()
+                # store the resulting minibatch in a lookup-list of values
+                self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+                self.dataset=dataset
+                self.minibatch_size=minibatch_size
+                assert offset>=0 and offset<len(dataset.data)
+                assert offset+minibatch_size<=len(dataset.data)
+                self.current=offset
+            def __iter__(self):
+                return self
+            def next(self):
+                #@todo: we suppose that we need to stop only when minibatch_size == 1.
+                # Otherwise, MinibatchWrapAroundIterator do it.
+                if self.current>=self.dataset.data.shape[0]:
+                    raise StopIteration
+                sub_data = self.dataset.data[self.current]
+                self.minibatch._values = [sub_data[self.dataset.fields_columns[f]] for f in self.minibatch._names]
+                self.current+=self.minibatch_size
+                return self.minibatch
+
+        return ArrayDataSetIterator2(self,self.fieldNames(),1,0,0)
+
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         class ArrayDataSetIterator(object):
             def __init__(self,dataset,fieldnames,minibatch_size,n_batches,offset):
@@ -1058,6 +1082,7 @@
             def __iter__(self):
                 return self
             def next(self):
+                #@todo: we suppose that MinibatchWrapAroundIterator stop the iterator
                 sub_data = self.dataset.data[self.current:self.current+self.minibatch_size]
                 self.minibatch._values = [sub_data[:,self.dataset.fields_columns[f]] for f in self.minibatch._names]
                 self.current+=self.minibatch_size
```
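The net effect of the first hunk is that iterating directly over the dataset yields one example at a time as a `LookupList` of field values, by delegating to an `ArrayDataSetIterator2` constructed with `minibatch_size=1`. The snippet below is a minimal, self-contained sketch of that per-example iteration pattern, not pylearn code: `ExampleIterator`, the field names, and the plain-dict return value are illustrative stand-ins (the real iterator returns and reuses a `LookupList`, and steps by `minibatch_size`).

```python
# Sketch only: mimics the per-example iteration added to ArrayDataSet.__iter__.
# Each step yields one row of a 2-D array, split into named fields via a
# field-name -> column-indices mapping. All names here are hypothetical.
import numpy

class ExampleIterator(object):
    def __init__(self, data, fields_columns, offset=0):
        self.data = data                      # 2-D array: one example per row
        self.fields_columns = fields_columns  # e.g. {'input': slice(0, 3), 'target': 3}
        self.current = offset
    def __iter__(self):
        return self
    def __next__(self):
        # Stop once every row has been visited (the diff's minibatch_size == 1 case).
        if self.current >= self.data.shape[0]:
            raise StopIteration
        row = self.data[self.current]
        example = dict((name, row[cols]) for name, cols in self.fields_columns.items())
        self.current += 1
        return example
    next = __next__  # Python 2 spelling, as used in the diff

data = numpy.arange(12.0).reshape(3, 4)
for example in ExampleIterator(data, {'input': slice(0, 3), 'target': 3}):
    print(example)
```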