comparison dataset.py @ 353:47538a45b878

Cached dataset seems debugged, using n_batches... Is n_batches here to stay?
author Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date Tue, 17 Jun 2008 17:12:43 -0400
parents 7545207466d4
children d580b3a369a4
comparison
equal deleted inserted replaced
352:cefa8518ff48 353:47538a45b878
379 379
380 Note: A list-like container is something like a tuple, list, numpy.ndarray or 380 Note: A list-like container is something like a tuple, list, numpy.ndarray or
381 any other object that supports integer indexing and slicing. 381 any other object that supports integer indexing and slicing.
382 382
383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete 383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
384 batches only, raise StopIteration 384 batches only, raise StopIteration.
385 @ATTENTION: minibatches returns a LookupList, we can't iterate over examples on it.
385 386
386 """ 387 """
387 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ 388 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\
388 assert offset >= 0 389 assert offset >= 0
389 assert offset < len(self) 390 assert offset < len(self)
867 raise NotImplementedError() 868 raise NotImplementedError()
868 def __iter__(self): 869 def __iter__(self):
869 return self 870 return self
870 def next(self): 871 def next(self):
871 upper = self.next_example+minibatch_size 872 upper = self.next_example+minibatch_size
872 if upper>self.ds.length: 873 if upper > len(self.ds) :
873 raise StopIteration 874 raise StopIteration()
875 assert upper<=len(self.ds) # instead of self.ds.length
874 #minibatch = Example(self.ds._fields.keys(), 876 #minibatch = Example(self.ds._fields.keys(),
875 # [field[self.next_example:upper] 877 # [field[self.next_example:upper]
876 # for field in self.ds._fields]) 878 # for field in self.ds._fields])
877 # tbm: modif to use fieldnames 879 # tbm: modif to use fieldnames
878 values = [] 880 values = []
1323 if cache_all_upon_construction: 1325 if cache_all_upon_construction:
1324 # this potentially brings all the source examples 1326 # this potentially brings all the source examples
1325 # into memory at once, which may be too much 1327 # into memory at once, which may be too much
1326 # the work could possibly be done by minibatches 1328 # the work could possibly be done by minibatches
1327 # that are as large as possible but no more than what memory allows. 1329 # that are as large as possible but no more than what memory allows.
1328 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() 1330 #
1331 # field_values is supposed to be an DataSetFields, that inherits from LookupList
1332 #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
1333 fields_values = DataSetFields(source_dataset,None)
1329 assert all([len(self)==len(field_values) for field_values in fields_values]) 1334 assert all([len(self)==len(field_values) for field_values in fields_values])
1330 for example in fields_values.examples(): 1335 for example in fields_values.examples():
1331 self.cached_examples.append(copy.copy(example)) 1336 self.cached_examples.append(copy.copy(example))
1332 1337
1333 self.fieldNames = source_dataset.fieldNames 1338 self.fieldNames = source_dataset.fieldNames
1342 class CacheIterator(object): 1347 class CacheIterator(object):
1343 def __init__(self,dataset): 1348 def __init__(self,dataset):
1344 self.dataset=dataset 1349 self.dataset=dataset
1345 self.current=offset 1350 self.current=offset
1346 self.all_fields = self.dataset.fieldNames()==fieldnames 1351 self.all_fields = self.dataset.fieldNames()==fieldnames
1352 self.n_batches = n_batches
1353 self.batch_counter = 0
1347 def __iter__(self): return self 1354 def __iter__(self): return self
1348 def next(self): 1355 def next(self):
1356 self.batch_counter += 1
1357 if self.n_batches and self.batch_counter > self.n_batches :
1358 raise StopIteration()
1349 upper = self.current+minibatch_size 1359 upper = self.current+minibatch_size
1360 if upper > len(self.dataset.source_dataset):
1361 raise StopIteration()
1350 cache_len = len(self.dataset.cached_examples) 1362 cache_len = len(self.dataset.cached_examples)
1351 if upper>cache_len: # whole minibatch is not already in cache 1363 if upper>cache_len: # whole minibatch is not already in cache
1352 # cache everything from current length to upper 1364 # cache everything from current length to upper
1353 for example in self.dataset.source_dataset[cache_len:upper]: 1365 #for example in self.dataset.source_dataset[cache_len:upper]:
1366 for example in self.dataset.source_dataset.subset[cache_len:upper]:
1354 self.dataset.cached_examples.append(example) 1367 self.dataset.cached_examples.append(example)
1355 all_fields_minibatch = Example(self.dataset.fieldNames(), 1368 all_fields_minibatch = Example(self.dataset.fieldNames(),
1356 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) 1369 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
1370
1357 self.current+=minibatch_size 1371 self.current+=minibatch_size
1358 if self.all_fields: 1372 if self.all_fields:
1359 return all_fields_minibatch 1373 return all_fields_minibatch
1360 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) 1374 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
1361 return CacheIterator(self) 1375 return CacheIterator(self)