Mercurial > pylearn
changeset 353:47538a45b878
Cached dataset seems debug, using n_batches... is n_batches around to stay?
author | Thierry Bertin-Mahieux <bertinmt@iro.umontreal.ca> |
---|---|
date | Tue, 17 Jun 2008 17:12:43 -0400 |
parents | cefa8518ff48 |
children | d580b3a369a4 |
files | dataset.py |
diffstat | 1 files changed, 19 insertions(+), 5 deletions(-) [+] |
line wrap: on
line diff
--- a/dataset.py Tue Jun 17 16:50:35 2008 -0400 +++ b/dataset.py Tue Jun 17 17:12:43 2008 -0400 @@ -381,7 +381,8 @@ any other object that supports integer indexing and slicing. @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete - batches only, raise StopIteration + batches only, raise StopIteration. + @ATTENTION: minibatches returns a LookupList, we can't iterate over examples on it. """ #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\ @@ -869,8 +870,9 @@ return self def next(self): upper = self.next_example+minibatch_size - if upper>self.ds.length: - raise StopIteration + if upper > len(self.ds) : + raise StopIteration() + assert upper<=len(self.ds) # instead of self.ds.length #minibatch = Example(self.ds._fields.keys(), # [field[self.next_example:upper] # for field in self.ds._fields]) @@ -1325,7 +1327,10 @@ # into memory at once, which may be too much # the work could possibly be done by minibatches # that are as large as possible but no more than what memory allows. - fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() + # + # field_values is supposed to be an DataSetFields, that inherits from LookupList + #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next() + fields_values = DataSetFields(source_dataset,None) assert all([len(self)==len(field_values) for field_values in fields_values]) for example in fields_values.examples(): self.cached_examples.append(copy.copy(example)) @@ -1344,16 +1349,25 @@ self.dataset=dataset self.current=offset self.all_fields = self.dataset.fieldNames()==fieldnames + self.n_batches = n_batches + self.batch_counter = 0 def __iter__(self): return self def next(self): + self.batch_counter += 1 + if self.n_batches and self.batch_counter > self.n_batches : + raise StopIteration() upper = self.current+minibatch_size + if upper > len(self.dataset.source_dataset): + raise StopIteration() cache_len = len(self.dataset.cached_examples) if upper>cache_len: # whole minibatch is not already in cache # cache everything from current length to upper - for example in self.dataset.source_dataset[cache_len:upper]: + #for example in self.dataset.source_dataset[cache_len:upper]: + for example in self.dataset.source_dataset.subset[cache_len:upper]: self.dataset.cached_examples.append(example) all_fields_minibatch = Example(self.dataset.fieldNames(), zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size])) + self.current+=minibatch_size if self.all_fields: return all_fields_minibatch