# HG changeset patch
# User Thierry Bertin-Mahieux
# Date 1213737163 14400
# Node ID 47538a45b878dde531c29dd53bde47963530bb9f
# Parent  cefa8518ff48080c1da38fff63d31af6b1cb07e7
Cached dataset seems debugged; now using n_batches... is n_batches here to stay?

diff -r cefa8518ff48 -r 47538a45b878 dataset.py
--- a/dataset.py	Tue Jun 17 16:50:35 2008 -0400
+++ b/dataset.py	Tue Jun 17 17:12:43 2008 -0400
@@ -381,7 +381,8 @@
         any other object that supports integer indexing and slicing.
 
         @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
-        batches only, raise StopIteration
+        batches only, and to raise StopIteration otherwise.
+        @ATTENTION: minibatches returns a LookupList; we can't iterate over its examples.
 
         """
         #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\
@@ -869,8 +870,9 @@
                 return self
             def next(self):
                 upper = self.next_example+minibatch_size
-                if upper>self.ds.length:
-                    raise StopIteration
+                if upper > len(self.ds):
+                    raise StopIteration()
+                assert upper <= len(self.ds)  # instead of self.ds.length
                 #minibatch = Example(self.ds._fields.keys(),
                 #                    [field[self.next_example:upper]
                 #                     for field in self.ds._fields])
@@ -1325,7 +1327,10 @@
         # into memory at once, which may be too much
         # the work could possibly be done by minibatches
         # that are as large as possible but no more than what memory allows.
-        fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+        #
+        # fields_values is supposed to be a DataSetFields, which inherits from LookupList
+        #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+        fields_values = DataSetFields(source_dataset, None)
         assert all([len(self)==len(field_values) for field_values in fields_values])
         for example in fields_values.examples():
             self.cached_examples.append(copy.copy(example))
@@ -1344,16 +1349,25 @@
                 self.dataset=dataset
                 self.current=offset
                 self.all_fields = self.dataset.fieldNames()==fieldnames
+                self.n_batches = n_batches
+                self.batch_counter = 0
             def __iter__(self):
                 return self
             def next(self):
+                self.batch_counter += 1
+                if self.n_batches and self.batch_counter > self.n_batches:
+                    raise StopIteration()
                 upper = self.current+minibatch_size
+                if upper > len(self.dataset.source_dataset):
+                    raise StopIteration()
                 cache_len = len(self.dataset.cached_examples)
                 if upper>cache_len: # whole minibatch is not already in cache
                     # cache everything from current length to upper
-                    for example in self.dataset.source_dataset[cache_len:upper]:
+                    #for example in self.dataset.source_dataset[cache_len:upper]:
+                    for example in self.dataset.source_dataset.subset[cache_len:upper]:
                         self.dataset.cached_examples.append(example)
                 all_fields_minibatch = Example(self.dataset.fieldNames(),
                                                zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
+                self.current += minibatch_size
                 if self.all_fields:
                     return all_fields_minibatch
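
Note: the behavioural core of this patch is the new stopping policy in CachedDataSet's
minibatch iterator: serve at most n_batches minibatches, and serve complete minibatches
only. Below is a minimal standalone sketch of that policy, not the pylearn API; the
class name CompleteBatchIterator and its arguments are hypothetical, and it follows the
Python 2 iterator protocol (next) used by the patched code.

    # Standalone sketch (hypothetical names, not the pylearn API) of the stopping
    # policy introduced by this patch: at most n_batches minibatches, and
    # complete minibatches only.
    class CompleteBatchIterator(object):
        def __init__(self, examples, minibatch_size, n_batches=None, offset=0):
            self.examples = examples        # anything supporting len() and slicing
            self.minibatch_size = minibatch_size
            self.n_batches = n_batches      # None means iterate until data runs out
            self.current = offset
            self.batch_counter = 0
        def __iter__(self):
            return self
        def next(self):                     # Python 2 iterator protocol, as in dataset.py
            self.batch_counter += 1
            # stop once the requested number of minibatches has been served
            if self.n_batches and self.batch_counter > self.n_batches:
                raise StopIteration()
            upper = self.current + self.minibatch_size
            # complete batches only: never return a trailing partial minibatch
            if upper > len(self.examples):
                raise StopIteration()
            minibatch = self.examples[self.current:upper]
            self.current = upper
            return minibatch

    if __name__ == '__main__':
        it = CompleteBatchIterator(range(10), minibatch_size=3, n_batches=2)
        print list(it)                      # [[0, 1, 2], [3, 4, 5]]

Checking the n_batches counter before the bounds check, as the patch does, means an
n_batches larger than the dataset still terminates cleanly at the last complete
minibatch instead of erroring.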