# HG changeset patch
# User James Bergstra
# Date 1210623442 14400
# Node ID cc8b032417dbeb428a6721f542fc9fa3345a351b
# Parent ae5651a3696b9981b546c6ccc07b7534813c64bc
# Parent f8a1ae7eb83e1ece601e590f03a1016d8c61ee6d
merged

diff -r ae5651a3696b -r cc8b032417db dataset.py
--- a/dataset.py Mon May 12 16:16:32 2008 -0400
+++ b/dataset.py Mon May 12 16:17:22 2008 -0400
@@ -1045,15 +1045,16 @@
     def __init__(self,source_dataset,cache_all_upon_construction=False):
         self.source_dataset=source_dataset
         self.cache_all_upon_construction=cache_all_upon_construction
+        self.cached_examples = []
         if cache_all_upon_construction:
             # this potentially brings all the source examples
             # into memory at once, which may be too much
             # the work could possibly be done by minibatches
             # that are as large as possible but no more than what memory allows.
             fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
-            self.cached_examples = zip(*fields_values)
-        else:
-            self.cached_examples = []
+            assert all([len(self)==len(field_values) for field_values in fields_values])
+            for example in fields_values.examples():
+                self.cached_examples.append(example)
 
         self.fieldNames = source_dataset.fieldNames
         self.hasFields = source_dataset.hasFields
@@ -1077,12 +1078,17 @@
                     for example in self.dataset.source_dataset[cache_len:upper]:
                         self.dataset.cached_examples.append(example)
                 all_fields_minibatch = Example(self.dataset.fieldNames(),
-                                               *self.dataset.cached_examples[self.current:self.current+minibatch_size])
+                                               zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
                 if self.dataset.fieldNames()==fieldnames:
                     return all_fields_minibatch
                 return Example(fieldnames,[all_fields_minibatch[name] for name in fieldnames])
         return CacheIterator(self)
 
+    def __getitem__(self,i):
+        if type(i)==int and len(self.cached_examples)>i:
+            return self.cached_examples[i]
+        else:
+            return DataSet.__getitem__(self,i)
 
 class ApplyFunctionDataSet(DataSet):
     """
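
The comment in the first hunk notes that cache_all_upon_construction materializes the whole dataset as one giant minibatch, and suggests instead filling the cache with minibatches "as large as possible but no more than what memory allows." A minimal sketch of that idea, in the Python 2 style of the patch; cache_by_minibatch and chunk_size are illustrative names, not part of the patch, and the sketch assumes only the calls the patch itself uses (minibatches(minibatch_size=...) yielding Example objects whose examples() method iterates over individual examples):

    def cache_by_minibatch(source_dataset, cached_examples, chunk_size=1024):
        # Fill cached_examples without ever holding more than one
        # chunk_size-sized minibatch in memory at a time.
        n_cached = 0
        for batch in source_dataset.minibatches(minibatch_size=chunk_size):
            for example in batch.examples():
                cached_examples.append(example)
                n_cached += 1
            # guard against iterators that cycle past the end of the dataset
            if n_cached >= len(source_dataset):
                break
        return cached_examples

With chunk_size bounding the minibatch size, peak memory is one minibatch plus the growing cache, rather than a second full copy of the dataset during construction.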