comparison dataset.py @ 254:8ec867d12428

optimization in CachedDataSet.minibatches_nowrap (compare fieldnames once per iterator instead of on every next() call)
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Tue, 03 Jun 2008 13:24:41 -0400
parents 394e07e2b0fd
children 19b14afe04b7 1cafd495098c
comparison
equal deleted inserted replaced
253:394e07e2b0fd 254:8ec867d12428
1139 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1139 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1140 class CacheIterator(object): 1140 class CacheIterator(object):
1141 def __init__(self,dataset): 1141 def __init__(self,dataset):
1142 self.dataset=dataset 1142 self.dataset=dataset
1143 self.current=offset 1143 self.current=offset
1144 self.all_fields = self.dataset.fieldNames()==fieldnames
1144 def __iter__(self): return self 1145 def __iter__(self): return self
1145 def next(self): 1146 def next(self):
1146 upper = self.current+minibatch_size 1147 upper = self.current+minibatch_size
1147 cache_len = len(self.dataset.cached_examples) 1148 cache_len = len(self.dataset.cached_examples)
1148 if upper>cache_len: # whole minibatch is not already in cache 1149 if upper>cache_len: # whole minibatch is not already in cache
1150 for example in self.dataset.source_dataset[cache_len:upper]: 1151 for example in self.dataset.source_dataset[cache_len:upper]:
1151 self.dataset.cached_examples.append(example) 1152 self.dataset.cached_examples.append(example)
1152 all_fields_minibatch = Example(self.dataset.fieldNames(), 1153 all_fields_minibatch = Example(self.dataset.fieldNames(),
1153 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) 1154 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
1154 self.current+=minibatch_size 1155 self.current+=minibatch_size
1155 if self.dataset.fieldNames()==fieldnames: 1156 if self.all_fields:
1156 return all_fields_minibatch 1157 return all_fields_minibatch
1157 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) 1158 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
1158 return CacheIterator(self) 1159 return CacheIterator(self)
1159 1160
1160 def __getitem__(self,i): 1161 def __getitem__(self,i):