pylearn: comparison of dataset.py @ 353:47538a45b878
Cached dataset seems debugged, using n_batches... is n_batches here to stay?
author:   Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca>
date:     Tue, 17 Jun 2008 17:12:43 -0400
parents:  7545207466d4
children: d580b3a369a4
352:cefa8518ff48 | 353:47538a45b878 |
---|---|
379 | 379 |
380 Note: A list-like container is something like a tuple, list, numpy.ndarray or | 380 Note: A list-like container is something like a tuple, list, numpy.ndarray or |
381 any other object that supports integer indexing and slicing. | 381 any other object that supports integer indexing and slicing. |
382 | 382 |
383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete | 383 @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete |
384 batches only, raise StopIteration | 384 batches only, raise StopIteration. |
| 385 @ATTENTION: minibatches returns a LookupList, we can't iterate over examples on it. |
385 | 386 |
386 """ | 387 """ |
387 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ | 388 #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ |
388 assert offset >= 0 | 389 assert offset >= 0 |
389 assert offset < len(self) | 390 assert offset < len(self) |
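
For illustration, a minimal sketch of the contract described by the @ATTENTION notes above, assuming a hypothetical dataset ds with 10 examples and fields 'x' and 'y' (names and sizes invented for the example):

    # only complete batches are yielded; with 10 examples and
    # minibatch_size=3 the trailing batch of 1 example is dropped
    # and the iterator raises StopIteration after 3 batches
    for batch in ds.minibatches(['x', 'y'], minibatch_size=3):
        # each batch is a LookupList: index it by field name,
        # but don't iterate it expecting examples; iterating a
        # LookupList walks its fields, not its rows
        xs, ys = batch['x'], batch['y']
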
867 raise NotImplementedError() | 868 raise NotImplementedError() |
868 def __iter__(self): | 869 def __iter__(self): |
869 return self | 870 return self |
870 def next(self): | 871 def next(self): |
871 upper = self.next_example+minibatch_size | 872 upper = self.next_example+minibatch_size |
872 if upper>self.ds.length: | 873 if upper > len(self.ds) : |
873 raise StopIteration | 874 raise StopIteration() |
| 875 assert upper<=len(self.ds) # instead of self.ds.length |
874 #minibatch = Example(self.ds._fields.keys(), | 876 #minibatch = Example(self.ds._fields.keys(), |
875 # [field[self.next_example:upper] | 877 # [field[self.next_example:upper] |
876 # for field in self.ds._fields]) | 878 # for field in self.ds._fields]) |
877 # tbm: modif to use fieldnames | 879 # tbm: modif to use fieldnames |
878 values = [] | 880 values = [] |
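
The upper > len(self.ds) check in the hunk above is what enforces complete batches. A self-contained sketch of the same logic, outside the class (the names here are illustrative, not part of the module; Python 2 iterator protocol, matching the file):

    class CompleteBatchIterator(object):
        # yields (start, end) index pairs for complete batches only
        def __init__(self, n_examples, minibatch_size):
            self.n_examples = n_examples
            self.minibatch_size = minibatch_size
            self.next_example = 0
        def __iter__(self):
            return self
        def next(self):
            upper = self.next_example + self.minibatch_size
            if upper > self.n_examples:
                raise StopIteration()  # partial batch left: stop, don't wrap around
            lower = self.next_example
            self.next_example = upper
            return (lower, upper)

    # list(CompleteBatchIterator(10, 3)) == [(0, 3), (3, 6), (6, 9)]
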
1323 if cache_all_upon_construction: | 1325 if cache_all_upon_construction: |
1324 # this potentially brings all the source examples | 1326 # this potentially brings all the source examples |
1325 # into memory at once, which may be too much | 1327 # into memory at once, which may be too much |
1326 # the work could possibly be done by minibatches | 1328 # the work could possibly be done by minibatches |
1327 # that are as large as possible but no more than what memory allows. | 1329 # that are as large as possible but no more than what memory allows. |
1328 fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() | 1330 # |
| 1331 # field_values is supposed to be an DataSetFields, that inherits from LookupList |
| 1332 #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() |
| 1333 fields_values = DataSetFields(source_dataset,None) |
1329 assert all([len(self)==len(field_values) for field_values in fields_values]) | 1334 assert all([len(self)==len(field_values) for field_values in fields_values]) |
1330 for example in fields_values.examples(): | 1335 for example in fields_values.examples(): |
1331 self.cached_examples.append(copy.copy(example)) | 1336 self.cached_examples.append(copy.copy(example)) |
1332 | 1337 |
1333 self.fieldNames = source_dataset.fieldNames | 1338 self.fieldNames = source_dataset.fieldNames |
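
The comment in this hunk notes that caching everything at construction may exhaust memory, and that the work could instead be done minibatch by minibatch. A rough sketch of that alternative, assuming only len() and slicing on the source dataset (the helper name and chunk size are hypothetical):

    import copy

    def cache_in_chunks(source_dataset, cached_examples, chunk_size=1000):
        # fill the cache in bounded slices instead of materializing
        # the whole source dataset with a single full-size minibatch
        n = len(source_dataset)
        for start in xrange(0, n, chunk_size):  # xrange: Python 2, as in the file
            for example in source_dataset[start:min(start + chunk_size, n)]:
                cached_examples.append(copy.copy(example))
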
1342 class CacheIterator(object): | 1347 class CacheIterator(object): |
1343 def __init__(self,dataset): | 1348 def __init__(self,dataset): |
1344 self.dataset=dataset | 1349 self.dataset=dataset |
1345 self.current=offset | 1350 self.current=offset |
1346 self.all_fields = self.dataset.fieldNames()==fieldnames | 1351 self.all_fields = self.dataset.fieldNames()==fieldnames |
| 1352 self.n_batches = n_batches |
| 1353 self.batch_counter = 0 |
1347 def __iter__(self): return self | 1354 def __iter__(self): return self |
1348 def next(self): | 1355 def next(self): |
| 1356 self.batch_counter += 1 |
| 1357 if self.n_batches and self.batch_counter > self.n_batches : |
| 1358 raise StopIteration() |
1349 upper = self.current+minibatch_size | 1359 upper = self.current+minibatch_size |
| 1360 if upper > len(self.dataset.source_dataset): |
| 1361 raise StopIteration() |
1350 cache_len = len(self.dataset.cached_examples) | 1362 cache_len = len(self.dataset.cached_examples) |
1351 if upper>cache_len: # whole minibatch is not already in cache | 1363 if upper>cache_len: # whole minibatch is not already in cache |
1352 # cache everything from current length to upper | 1364 # cache everything from current length to upper |
1353 for example in self.dataset.source_dataset[cache_len:upper]: | 1365 #for example in self.dataset.source_dataset[cache_len:upper]: |
| 1366 for example in self.dataset.source_dataset.subset[cache_len:upper]: |
1354 self.dataset.cached_examples.append(example) | 1367 self.dataset.cached_examples.append(example) |
1355 all_fields_minibatch = Example(self.dataset.fieldNames(), | 1368 all_fields_minibatch = Example(self.dataset.fieldNames(), |
1356 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) | 1369 zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) |
| 1370 |
1357 self.current+=minibatch_size | 1371 self.current+=minibatch_size |
1358 if self.all_fields: | 1372 if self.all_fields: |
1359 return all_fields_minibatch | 1373 return all_fields_minibatch |
1360 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) | 1374 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames]) |
1361 return CacheIterator(self) | 1375 return CacheIterator(self) |
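
The batch_counter bookkeeping added to CacheIterator amounts to capping an iterator at n_batches items. For comparison, the same behaviour is available from the standard library; a sketch (the wrapper name is invented):

    from itertools import islice

    def limit_batches(batch_iter, n_batches):
        # n_batches of None (or 0) means unlimited, mirroring the
        # `if self.n_batches and ...` test in the code above
        if not n_batches:
            return batch_iter
        return islice(batch_iter, n_batches)
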