Mercurial > pylearn
comparison dataset.py @ 258:19b14afe04b7
merged
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 03 Jun 2008 16:06:21 -0400 |
parents | 4ad6bc9b4f03 8ec867d12428 |
children | 6e69fb91f3c0 6226ebafefc3 |
comparison
equal
deleted
inserted
replaced
257:4ad6bc9b4f03 | 258:19b14afe04b7 |
---|---|
1148 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): | 1148 def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset): |
1149 class CacheIterator(object): | 1149 class CacheIterator(object): |
1150 def __init__(self,dataset): | 1150 def __init__(self,dataset): |
1151 self.dataset=dataset | 1151 self.dataset=dataset |
1152 self.current=offset | 1152 self.current=offset |
1153 self.all_fields = self.dataset.fieldNames()==fieldnames | |
1153 def __iter__(self): return self | 1154 def __iter__(self): return self |
1154 def next(self): | 1155 def next(self): |
1155 upper = self.current+minibatch_size | 1156 upper = self.current+minibatch_size |
1156 cache_len = len(self.dataset.cached_examples) | 1157 cache_len = len(self.dataset.cached_examples) |
1157 if upper>cache_len: # whole minibatch is not already in cache | 1158 if upper>cache_len: # whole minibatch is not already in cache |
1159 for example in self.dataset.source_dataset[cache_len:upper]: | 1160 for example in self.dataset.source_dataset[cache_len:upper]: |
1160 self.dataset.cached_examples.append(example) | 1161 self.dataset.cached_examples.append(example) |
1161 all_fields_minibatch = Example(self.dataset.fieldNames(), | 1162 all_fields_minibatch = Example(self.dataset.fieldNames(), |
1162 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) | 1163 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) |
1163 self.current+=minibatch_size | 1164 self.current+=minibatch_size |
1164 if self.dataset.fieldNames()==fieldnames: | 1165 if self.all_fields: |
1165 return all_fields_minibatch | 1166 return all_fields_minibatch |
1166 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1167 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1167 return CacheIterator(self) | 1168 return CacheIterator(self) |
1168 | 1169 |
1169 def __getitem__(self,i): | 1170 def __getitem__(self,i): |
1170 if type(i)==int and len(self.cached_examples)>i: | 1171 if type(i)==int and len(self.cached_examples)>i: |
1171 return self.cached_examples[i] | 1172 return self.cached_examples[i] |
1172 else: | 1173 else: |
1173 return DataSet.__getitem__(self,i) | 1174 return self.source_dataset[i] |
1174 | 1175 |
1176 def __iter__(self): | |
1177 class CacheIteratorIter(object): | |
1178 def __init__(self,dataset): | |
1179 self.dataset=dataset | |
1180 self.l = len(dataset) | |
1181 self.current = 0 | |
1182 self.fieldnames = self.dataset.fieldNames() | |
1183 self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames)) | |
1184 def __iter__(self): return self | |
1185 def next(self): | |
1186 if self.current>=self.l: | |
1187 raise StopIteration | |
1188 cache_len = len(self.dataset.cached_examples) | |
1189 if self.current>=cache_len: # current example is not already in cache | |
1190 # fetch the current example from the source dataset and append it to the cache | |
1191 self.dataset.cached_examples.append( | |
1192 self.dataset.source_dataset[self.current]) | |
1193 self.example._values = self.dataset.cached_examples[self.current] | |
1194 self.current+=1 | |
1195 return self.example | |
1196 | |
1197 return CacheIteratorIter(self) | |
1198 | |
1175 class ApplyFunctionDataSet(DataSet): | 1199 class ApplyFunctionDataSet(DataSet): |
1176 """ | 1200 """ |
1177 A L{DataSet} that contains as fields the results of applying a | 1201 A L{DataSet} that contains as fields the results of applying a |
1178 given function example-wise or minibatch-wise to all the fields of | 1202 given function example-wise or minibatch-wise to all the fields of |
1179 an input dataset. The output of the function should be an iterable | 1203 an input dataset. The output of the function should be an iterable |