Mercurial > pylearn
diff dataset.py @ 252:856d14dc4468
implemented CachedDataSet.__iter__ as an optimization
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:22:45 -0400 |
parents | 7e6edee187e3 |
children | 394e07e2b0fd |
line wrap: on
line diff
--- a/dataset.py Tue Jun 03 12:25:53 2008 -0400
+++ b/dataset.py Tue Jun 03 13:22:45 2008 -0400
@@ -1162,7 +1162,58 @@
             return self.cached_examples[i]
         else:
             return self.source_dataset[i]
-
+
+    def __iter__(self):
+        class CacheIteratorIter(object):
+            def __init__(self,dataset):
+                self.dataset=dataset
+                self.l = len(dataset)
+                self.current = 0
+                self.fieldnames = self.dataset.fieldNames()
+                self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
+            def __iter__(self): return self
+            def next(self):
+                if self.current>=self.l:
+                    raise StopIteration
+                cache_len = len(self.dataset.cached_examples)
+                if self.current>=cache_len: # whole minibatch is not already in cache
+                    # cache everything from current length to upper
+                    self.dataset.cached_examples.append(
+                        self.dataset.source_dataset[self.current])
+                self.example._values = self.dataset.cached_examples[self.current]
+                self.current+=1
+                return self.example
+
+        return CacheIteratorIter(self)
+
+#     class CachedDataSetIterator(object):
+#         def __init__(self,dataset,fieldnames):#,minibatch_size,n_batches,offset):
+# #            if fieldnames is None: fieldnames = dataset.fieldNames()
+#             # store the resulting minibatch in a lookup-list of values
+#             self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+#             self.dataset=dataset
+# #            self.minibatch_size=minibatch_size
+# #            assert offset>=0 and offset<len(dataset.data)
+# #            assert offset+minibatch_size<=len(dataset.data)
+#             self.current=0
+#             self.columns = [self.dataset.fields_columns[f]
+#                             for f in self.minibatch._names]
+#             self.l = len(self.dataset)
+#         def __iter__(self):
+#             return self
+#         def next(self):
+#             #@todo: we suppose that we need to stop only when minibatch_size == 1.
+#             # Otherwise, MinibatchWrapAroundIterator do it.
+#             if self.current>=self.l:
+#                 raise StopIteration
+#             sub_data = self.dataset.data[self.current]
+#             self.minibatch._values = [sub_data[c] for c in self.columns]
+
+#             self.current+=self.minibatch_size
+#             return self.minibatch
+
+#         return CachedDataSetIterator(self,self.fieldNames())#,1,0,0)
+
 class ApplyFunctionDataSet(DataSet):
     """
     A L{DataSet} that contains as fields the results of applying a