pylearn: comparison of dataset.py @ 252:856d14dc4468
implemented CachedDataSet.__iter__ as an optimization
| author | Frederic Bastien <bastienf@iro.umontreal.ca> |
|---|---|
| date | Tue, 03 Jun 2008 13:22:45 -0400 |
| parents | 7e6edee187e3 |
| children | 394e07e2b0fd |
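In outline, the change gives CachedDataSet its own __iter__: sequential iteration fills cached_examples on the first pass and serves every later pass from memory, rather than routing each access through __getitem__. Below is a minimal, runnable sketch of that caching-iterator pattern; SimpleDataset and CachedDataset are illustrative stand-ins, not pylearn classes (the committed code additionally reuses a single LookupList for the yielded examples):

```python
class SimpleDataset(object):
    """Toy source dataset; counts reads so the caching effect is visible."""
    def __init__(self, rows):
        self.rows = rows
        self.reads = 0
    def __len__(self):
        return len(self.rows)
    def __getitem__(self, i):
        self.reads += 1
        return self.rows[i]

class CachedDataset(object):
    """Stand-in for CachedDataSet: wraps a source and caches examples lazily."""
    def __init__(self, source):
        self.source_dataset = source
        self.cached_examples = []            # filled in order during iteration
    def __len__(self):
        return len(self.source_dataset)
    def __getitem__(self, i):
        if isinstance(i, int) and i < len(self.cached_examples):
            return self.cached_examples[i]   # cache hit
        return self.source_dataset[i]        # cache miss: fall through to source
    def __iter__(self):
        # The optimization in this changeset: iterating caches each example,
        # so a second full pass never touches source_dataset again.
        for i in range(len(self)):
            if i >= len(self.cached_examples):
                self.cached_examples.append(self.source_dataset[i])
            yield self.cached_examples[i]

src = SimpleDataset([[1, 2], [3, 4], [5, 6]])
ds = CachedDataset(src)
assert list(ds) == [[1, 2], [3, 4], [5, 6]]  # first pass reads the source
assert list(ds) == [[1, 2], [3, 4], [5, 6]]  # second pass is cache-only
assert src.reads == 3                        # each example read exactly once
```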
comparing revisions 251:7e6edee187e3 and 252:856d14dc4468:

```diff
@@ -1160,11 +1160,62 @@
     def __getitem__(self,i):
         if type(i)==int and len(self.cached_examples)>i:
             return self.cached_examples[i]
         else:
             return self.source_dataset[i]
 
+    def __iter__(self):
+        class CacheIteratorIter(object):
+            def __init__(self,dataset):
+                self.dataset=dataset
+                self.l = len(dataset)
+                self.current = 0
+                self.fieldnames = self.dataset.fieldNames()
+                self.example = LookupList(self.fieldnames,[0]*len(self.fieldnames))
+            def __iter__(self): return self
+            def next(self):
+                if self.current>=self.l:
+                    raise StopIteration
+                cache_len = len(self.dataset.cached_examples)
+                if self.current>=cache_len: # this example is not yet cached
+                    # fetch it from the source dataset and append it to the cache
+                    self.dataset.cached_examples.append(
+                        self.dataset.source_dataset[self.current])
+                self.example._values = self.dataset.cached_examples[self.current]
+                self.current+=1
+                return self.example
+
+        return CacheIteratorIter(self)
+
+    # class CachedDataSetIterator(object):
+    #     def __init__(self,dataset,fieldnames):#,minibatch_size,n_batches,offset):
+    # #       if fieldnames is None: fieldnames = dataset.fieldNames()
+    #         # store the resulting minibatch in a lookup-list of values
+    #         self.minibatch = LookupList(fieldnames,[0]*len(fieldnames))
+    #         self.dataset=dataset
+    # #       self.minibatch_size=minibatch_size
+    # #       assert offset>=0 and offset<len(dataset.data)
+    # #       assert offset+minibatch_size<=len(dataset.data)
+    #         self.current=0
+    #         self.columns = [self.dataset.fields_columns[f]
+    #                         for f in self.minibatch._names]
+    #         self.l = len(self.dataset)
+    #     def __iter__(self):
+    #         return self
+    #     def next(self):
+    #         #@todo: we suppose that we need to stop only when minibatch_size == 1.
+    #         #       Otherwise, MinibatchWrapAroundIterator does it.
+    #         if self.current>=self.l:
+    #             raise StopIteration
+    #         sub_data = self.dataset.data[self.current]
+    #         self.minibatch._values = [sub_data[c] for c in self.columns]
+
+    #         self.current+=self.minibatch_size
+    #         return self.minibatch
+
+    # return CachedDataSetIterator(self,self.fieldNames())#,1,0,0)
+
 class ApplyFunctionDataSet(DataSet):
     """
     A L{DataSet} that contains as fields the results of applying a
     given function example-wise or minibatch-wise to all the fields of
     an input dataset. The output of the function should be an iterable
```
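One consequence of the committed iterator worth noting: it allocates a single LookupList per iterator and only swaps its _values on each next(), so every yielded example is the same mutable object, and code that stores yielded examples will see them all end up holding the last example's values. A small illustration of the pitfall, using a hypothetical stand-in for the reuse strategy:

```python
class Example(object):
    """Stand-in for LookupList: a mutable holder for field values."""
    def __init__(self, values):
        self._values = values

def reusing_iter(rows):
    ex = Example(None)        # one example object, reused across next() calls
    for row in rows:
        ex._values = row      # swap values in place instead of allocating
        yield ex

kept = [e for e in reusing_iter([[1], [2], [3]])]
assert all(e is kept[0] for e in kept)        # same object three times
assert all(e._values == [3] for e in kept)    # ...now holding the last row
copied = [list(e._values) for e in reusing_iter([[1], [2], [3]])]
assert copied == [[1], [2], [3]]              # copying at yield time is safe
```

The trade-off is fewer allocations per example, which is the point of the optimization; callers that need to retain examples must copy them.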