pylearn: diff dataset.py @ 376:c9a89be5cb0a
Redesigning linear_regression
| author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
| --- | --- |
| date | Mon, 07 Jul 2008 10:08:35 -0400 |
| parents | 18702ceb2096 |
| children | 835830e52b42 |
```diff
--- a/dataset.py	Mon Jun 16 17:47:36 2008 -0400
+++ b/dataset.py	Mon Jul 07 10:08:35 2008 -0400
@@ -1,6 +1,6 @@
 from lookup_list import LookupList as Example
-from misc import unique_elements_list_intersection
+from common.misc import unique_elements_list_intersection
 from string import join
 from sys import maxint
 import numpy, copy
@@ -381,7 +381,8 @@
         any other object that supports integer indexing and slicing.
 
         @ATTENTION: now minibatches returns minibatches_nowrap, which is supposed to return complete
-        batches only, raise StopIteration
+        batches only, raise StopIteration.
+        @ATTENTION: minibatches returns a LookupList, we can't iterate over examples on it.
 
         """
         #return DataSet.MinibatchWrapAroundIterator(self,fieldnames,minibatch_size,n_batches,offset)\
@@ -435,6 +436,16 @@
         Return a dataset that sees only the fields whose name are specified.
         """
         assert self.hasFields(*fieldnames)
+        #return self.fields(*fieldnames).examples()
+        fieldnames_list = list(fieldnames)
+        return FieldsSubsetDataSet(self,fieldnames_list)
+
+    def cached_fields_subset(self,*fieldnames) :
+        """
+        Behaviour is supposed to be the same as __call__(*fieldnames), but the dataset returned is cached.
+        @see : dataset.__call__
+        """
+        assert self.hasFields(*fieldnames)
         return self.fields(*fieldnames).examples()
 
     def fields(self,*fieldnames):
@@ -692,6 +703,7 @@
         assert len(src_fieldnames)==len(new_fieldnames)
         self.valuesHStack = src.valuesHStack
         self.valuesVStack = src.valuesVStack
+        self.lookup_fields = Example(new_fieldnames,src_fieldnames)
 
     def __len__(self): return len(self.src)
 
@@ -719,9 +731,18 @@
 
     def minibatches_nowrap(self,fieldnames,minibatch_size,n_batches,offset):
         assert self.hasFields(*fieldnames)
-        return self.src.minibatches_nowrap(fieldnames,minibatch_size,n_batches,offset)
+        cursor = Example(fieldnames,[0]*len(fieldnames))
+        for batch in self.src.minibatches_nowrap([self.lookup_fields[f] for f in fieldnames],minibatch_size,n_batches,offset):
+            cursor._values=batch._values
+            yield cursor
+
     def __getitem__(self,i):
-        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+#        return FieldsSubsetDataSet(self.src[i],self.new_fieldnames)
+        complete_example = self.src[i]
+        return Example(self.new_fieldnames,
+                       [complete_example[field]
+                        for field in self.src_fieldnames])
+
 
 class DataSetFields(Example):
@@ -859,7 +880,9 @@
                 return self
             def next(self):
                 upper = self.next_example+minibatch_size
-                assert upper<=self.ds.length
+                if upper > len(self.ds) :
+                    raise StopIteration()
+                assert upper<=len(self.ds) # instead of self.ds.length
                 #minibatch = Example(self.ds._fields.keys(),
                 #                    [field[self.next_example:upper]
                 #                     for field in self.ds._fields])
```
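The @ATTENTION note and the @@ -859 hunk above encode the same contract: the minibatch iterators now yield complete batches only, and signal exhaustion with StopIteration rather than returning a short final batch. Below is a minimal, self-contained sketch of that contract; the class and variable names are illustrative, not pylearn's actual API, and it uses Python 2 style (`next()` rather than `__next__`) to match the code in the diff.

```python
# Sketch of the "complete batches only" iteration contract (hypothetical
# names, not pylearn's classes). Python 2 iterator protocol, as in the diff.

class CompleteBatchIterator:
    def __init__(self, data, minibatch_size):
        self.data = data                    # any sliceable sequence
        self.minibatch_size = minibatch_size
        self.next_example = 0

    def __iter__(self):
        return self

    def next(self):
        upper = self.next_example + self.minibatch_size
        if upper > len(self.data):
            # Never return a partial batch: stop as soon as the next
            # minibatch would run past the end of the dataset.
            raise StopIteration()
        batch = self.data[self.next_example:upper]
        self.next_example = upper
        return batch

# With 10 examples and minibatch_size=4 this yields two batches of 4;
# the trailing 2 examples are dropped rather than returned short.
```

Note also that FieldsSubsetDataSet.minibatches_nowrap above reuses a single `cursor` Example and swaps its `_values` on every batch, which is why the second @ATTENTION warns that the returned LookupList cannot be held across iterations.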
```diff
@@ -1314,7 +1337,10 @@
         # into memory at once, which may be too much
         # the work could possibly be done by minibatches
         # that are as large as possible but no more than what memory allows.
-        fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+        #
+        # field_values is supposed to be an DataSetFields, that inherits from LookupList
+        #fields_values = source_dataset.minibatches(minibatch_size=len(source_dataset)).__iter__().next()
+        fields_values = DataSetFields(source_dataset,None)
         assert all([len(self)==len(field_values) for field_values in fields_values])
         for example in fields_values.examples():
             self.cached_examples.append(copy.copy(example))
@@ -1333,16 +1359,25 @@
             self.dataset=dataset
             self.current=offset
             self.all_fields = self.dataset.fieldNames()==fieldnames
+            self.n_batches = n_batches
+            self.batch_counter = 0
         def __iter__(self): return self
         def next(self):
+            self.batch_counter += 1
+            if self.n_batches and self.batch_counter > self.n_batches :
+                raise StopIteration()
             upper = self.current+minibatch_size
+            if upper > len(self.dataset.source_dataset):
+                raise StopIteration()
             cache_len = len(self.dataset.cached_examples)
             if upper>cache_len: # whole minibatch is not already in cache
                 # cache everything from current length to upper
-                for example in self.dataset.source_dataset[cache_len:upper]:
+                #for example in self.dataset.source_dataset[cache_len:upper]:
+                for example in self.dataset.source_dataset.subset[cache_len:upper]:
                     self.dataset.cached_examples.append(example)
             all_fields_minibatch = Example(self.dataset.fieldNames(),
                                            zip(*self.dataset.cached_examples[self.current:self.current+minibatch_size]))
+            self.current+=minibatch_size
             if self.all_fields:
                 return all_fields_minibatch
```
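The @@ -1333 hunk adds three pieces of bookkeeping to CachedDataSet's minibatch iterator: a cap at n_batches (when given), a StopIteration guard before reading past the source dataset, and the previously missing advance of `self.current`, without which the iterator returned the same minibatch forever. The following is a hedged, self-contained sketch of that logic under assumed names (it is not pylearn's actual API):

```python
# Sketch of the iterator bookkeeping added in the @@ -1333 hunk: cap the
# number of batches, stop before running past the source, grow the cache
# lazily, and advance the cursor after each batch. Names are illustrative.

class CachedMinibatchIterator:
    def __init__(self, source, minibatch_size, n_batches=None, offset=0):
        self.source = source              # any sliceable sequence of examples
        self.cache = []                   # examples materialized so far
        self.minibatch_size = minibatch_size
        self.n_batches = n_batches
        self.batch_counter = 0
        self.current = offset

    def __iter__(self):
        return self

    def next(self):
        self.batch_counter += 1
        if self.n_batches and self.batch_counter > self.n_batches:
            raise StopIteration()
        upper = self.current + self.minibatch_size
        if upper > len(self.source):
            raise StopIteration()
        if upper > len(self.cache):
            # Cache everything from the current cache length up to `upper`,
            # so later passes over the same examples hit the cache.
            self.cache.extend(self.source[len(self.cache):upper])
        batch = self.cache[self.current:upper]
        self.current = upper  # the bug fixed by `self.current+=minibatch_size`
        return batch
```

Caching by example rather than by field keeps the lazy-fill logic simple: the cache is always a prefix of the source, so a single length comparison decides whether anything new must be materialized.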