Mercurial > pylearn
comparison dataset.py @ 254:8ec867d12428
optimization in CachedDataSet.minibatches_nowrap
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 13:24:41 -0400 |
parents | 394e07e2b0fd |
children | 19b14afe04b7 1cafd495098c |
comparison
equal
deleted
inserted
replaced
253:394e07e2b0fd | 254:8ec867d12428 |
---|---|
1139 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1139 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1140 class CacheIterator(object): | 1140 class CacheIterator(object): |
1141 def __init__(self,dataset): | 1141 def __init__(self,dataset): |
1142 self.dataset=dataset | 1142 self.dataset=dataset |
1143 self.current=offset | 1143 self.current=offset |
| | 1144 self.all_fields = self.dataset.fieldNames()==fieldnames |
1144 def __iter__(self): return self | 1145 def __iter__(self): return self |
1145 def next(self): | 1146 def next(self): |
1146 upper = self.current+minibatch_size | 1147 upper = self.current+minibatch_size |
1147 cache_len = len(self.dataset.cached_examples) | 1148 cache_len = len(self.dataset.cached_examples) |
1148 if upper>cache_len: # whole minibatch is not already in cache | 1149 if upper>cache_len: # whole minibatch is not already in cache |
1150 for example in self.dataset.source_dataset[cache_len:upper]: | 1151 for example in self.dataset.source_dataset[cache_len:upper]: |
1151 self.dataset.cached_examples.append(example) | 1152 self.dataset.cached_examples.append(example) |
1152 all_fields_minibatch = Example(self.dataset.fieldNames(), | 1153 all_fields_minibatch = Example(self.dataset.fieldNames(), |
1153 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) | 1154 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) |
1154 self.current+=minibatch_size | 1155 self.current+=minibatch_size |
1155 if self.dataset.fieldNames()==fieldnames: | 1156 if self.all_fields: |
1156 return all_fields_minibatch | 1157 return all_fields_minibatch |
1157 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1158 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1158 return CacheIterator(self) | 1159 return CacheIterator(self) |
1159 | 1160 |
1160 def __getitem__(self,i): | 1161 def __getitem__(self,i): |