Mercurial > pylearn
changeset 256:aef979d5bad9
Automated merge with ssh://projects@lgcm.iro.umontreal.ca/hg/pylearn
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:25:40 -0400 |
parents | e93e511deb9a (current diff) bf0a1ebc6e52 (diff) |
children | 19b14afe04b7 1cafd495098c |
files | |
diffstat | 2 files changed, 36 insertions(+), 12 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Tue Jun 03 13:18:33 2008 -0400 +++ b/dataset.py Tue Jun 03 13:25:40 2008 -0400 @@ -1141,6 +1141,7 @@ def __init__(self,dataset): self.dataset=dataset self.current=offset + self.all_fields = self.dataset.fieldNames()==fieldnames def __iter__(self): return self def next(self): upper = self.current+minibatch_size @@ -1152,7 +1153,7 @@ all_fields_minibatch = Example(self.dataset.fieldNames(), zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) self.current+=minibatch_size - if self.dataset.fieldNames()==fieldnames: + if self.all_fields: return all_fields_minibatch return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) return CacheIterator(self) @@ -1161,8 +1162,31 @@ if type(i)==int and len(self.cached_examples)>i: return self.cached_examples[i] else: - return DataSet.__getitem__(self,i) - + return self.source_dataset[i] + + def __iter__(self): + class CacheIteratorIter(object): + def __init__(self,dataset): + self.dataset=dataset + self.l = len(dataset) + self.current = 0 + self.fieldnames = self.dataset.fieldNames() + self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) + def __iter__(self): return self + def next(self): + if self.current>=self.l: + raise StopIteration + cache_len = len(self.dataset.cached_examples) + if self.current>=cache_len: # whole minibatch is not already in cache + # cache everything from current length to upper + self.dataset.cached_examples.append( + self.dataset.source_dataset[self.current]) + self.example._values = self.dataset.cached_examples[self.current] + self.current+=1 + return self.example + + return CacheIteratorIter(self) + class ApplyFunctionDataSet(DataSet): """ A L{DataSet} that contains as fields the results of applying a
--- a/test_dataset.py Tue Jun 03 13:18:33 2008 -0400 +++ b/test_dataset.py Tue Jun 03 13:25:40 2008 -0400 @@ -493,12 +493,11 @@ raise NotImplementedError() -def test_speed(): - print "test_speed" - import time - a2 = numpy.random.rand(100000,400) - ds = ArrayDataSet(a2,{'all':slice(0,a2.shape[1],1)}) +def test_speed(array, ds): + print "test_speed", ds.__class__ + mat = numpy.random.rand(400,100) + @print_timing def f_array_full(a): a+1 @@ -540,11 +539,13 @@ exs[0]+1 # ex[0]*mat - f_array_full(a2) - f_array_index(a2) - f_array_iter(a2) + f_array_full(array) + f_array_index(array) + f_array_iter(array) f_ds_index(ds) + f_ds_index(ds) + f_ds_iter(ds) f_ds_iter(ds) f_ds_mb1(ds,10) @@ -556,7 +557,6 @@ f_ds_mb2(ds,1000) f_ds_mb2(ds,10000) - del a2, ds if __name__=='__main__': test1()