comparison dataset.py @ 258:19b14afe04b7

merged
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 03 Jun 2008 16:06:21 -0400
parents 4ad6bc9b4f03 8ec867d12428
children 6e69fb91f3c0 6226ebafefc3
comparison
equal deleted inserted replaced
257:4ad6bc9b4f03 258:19b14afe04b7
1148 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): 1148 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
1149 class CacheIterator(object): 1149 class CacheIterator(object):
1150 def __init__(self,dataset): 1150 def __init__(self,dataset):
1151 self.dataset=dataset 1151 self.dataset=dataset
1152 self.current=offset 1152 self.current=offset
1153 self.all_fields = self.dataset.fieldNames()==fieldnames
1153 def __iter__(self): return self 1154 def __iter__(self): return self
1154 def next(self): 1155 def next(self):
1155 upper = self.current+minibatch_size 1156 upper = self.current+minibatch_size
1156 cache_len = len(self.dataset.cached_examples) 1157 cache_len = len(self.dataset.cached_examples)
1157 if upper>cache_len: # whole minibatch is not already in cache 1158 if upper>cache_len: # whole minibatch is not already in cache
1159 for example in self.dataset.source_dataset[cache_len:upper]: 1160 for example in self.dataset.source_dataset[cache_len:upper]:
1160 self.dataset.cached_examples.append(example) 1161 self.dataset.cached_examples.append(example)
1161 all_fields_minibatch = Example(self.dataset.fieldNames(), 1162 all_fields_minibatch = Example(self.dataset.fieldNames(),
1162 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) 1163 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
1163 self.current+=minibatch_size 1164 self.current+=minibatch_size
1164 if self.dataset.fieldNames()==fieldnames: 1165 if self.all_fields:
1165 return all_fields_minibatch 1166 return all_fields_minibatch
1166 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) 1167 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
1167 return CacheIterator(self) 1168 return CacheIterator(self)
1168 1169
1169 def __getitem__(self,i): 1170 def __getitem__(self,i):
1170 if type(i)==int and len(self.cached_examples)>i: 1171 if type(i)==int and len(self.cached_examples)>i:
1171 return self.cached_examples[i] 1172 return self.cached_examples[i]
1172 else: 1173 else:
1173 return DataSet.__getitem__(self,i) 1174 return self.source_dataset[i]
1174 1175
1176 def __iter__(self):
1177 class CacheIteratorIter(object):
1178 def __init__(self,dataset):
1179 self.dataset=dataset
1180 self.l = len(dataset)
1181 self.current = 0
1182 self.fieldnames = self.dataset.fieldNames()
1183 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
1184 def __iter__(self): return self
1185 def next(self):
1186 if self.current>=self.l:
1187 raise StopIteration
1188 cache_len = len(self.dataset.cached_examples)
1189 if self.current>=cache_len: # current example is not already in cache
1190 # fetch the current example from the source dataset and append it to the cache
1191 self.dataset.cached_examples.append(
1192 self.dataset.source_dataset[self.current])
1193 self.example._values = self.dataset.cached_examples[self.current]
1194 self.current+=1
1195 return self.example
1196
1197 return CacheIteratorIter(self)
1198
1175 class ApplyFunctionDataSet(DataSet): 1199 class ApplyFunctionDataSet(DataSet):
1176 """ 1200 """
1177 A L{DataSet} that contains as fields the results of applying a 1201 A L{DataSet} that contains as fields the results of applying a
1178 given function example-wise or minibatch-wise to all the fields of 1202 given function example-wise or minibatch-wise to all the fields of
1179 an input dataset. The output of the function should be an iterable 1203 an input dataset. The output of the function should be an iterable