Mercurial > pylearn
diff dataset.py @ 203:80731832c62b
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Thu, 15 May 2008 15:21:00 -0400 |
parents | cb6b945acf5a c9c966cab763 |
children | bd728c83faff 6f55e301c687 |
line wrap: on
line diff
--- a/dataset.py Thu May 15 12:55:21 2008 -0400 +++ b/dataset.py Thu May 15 15:21:00 2008 -0400 @@ -1111,6 +1111,8 @@ given function example-wise or minibatch-wise to all the fields of an input dataset. The output of the function should be an iterable (e.g. a list or a LookupList) over the resulting values. + + The function take as input the fields of the dataset, not the examples. In minibatch mode, the function is expected to work on minibatches (takes a minibatch in input and returns a minibatch in output). More @@ -1170,7 +1172,7 @@ function_outputs = self.output_dataset.function(*function_inputs) else: input_examples = zip(*function_inputs) - output_examples = [self.output_dataset.function(input_example) + output_examples = [self.output_dataset.function(*input_example) for input_example in input_examples] function_outputs = [self.output_dataset.valuesVStack(name,values) for name,values in zip(all_output_names, @@ -1190,11 +1192,14 @@ self.input_iterator=output_dataset.input_dataset.__iter__() def __iter__(self): return self def next(self): - function_inputs = self.input_iterator.next() if self.output_dataset.minibatch_mode: - function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)] + function_inputs = [[input] for input in self.input_iterator.next()] + outputs = self.output_dataset.function(*function_inputs) + assert all([hasattr(output,'__iter__') for output in outputs]) + function_outputs = [output[0] for output in outputs] else: - function_outputs = self.output_dataset.function(function_inputs) + function_inputs = self.input_iterator.next() + function_outputs = self.output_dataset.function(*function_inputs) return Example(self.output_dataset.output_names,function_outputs) return ApplyFunctionSingleExampleIterator(self)