comparison dataset.py @ 203:80731832c62b

Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author Frederic Bastien <bastienf@iro.umontreal.ca>
date Thu, 15 May 2008 15:21:00 -0400
parents cb6b945acf5a c9c966cab763
children bd728c83faff 6f55e301c687
comparison
193:cb6b945acf5a 203:80731832c62b
1109 """ 1109 """
1110 A L{DataSet} that contains as fields the results of applying a 1110 A L{DataSet} that contains as fields the results of applying a
1111 given function example-wise or minibatch-wise to all the fields of 1111 given function example-wise or minibatch-wise to all the fields of
1112 an input dataset. The output of the function should be an iterable 1112 an input dataset. The output of the function should be an iterable
1113 (e.g. a list or a LookupList) over the resulting values. 1113 (e.g. a list or a LookupList) over the resulting values.
1114
1115 The function takes as input the fields of the dataset, not the examples.
1114 1116
1115 In minibatch mode, the function is expected to work on minibatches 1117 In minibatch mode, the function is expected to work on minibatches
1116 (takes a minibatch in input and returns a minibatch in output). More 1118 (takes a minibatch in input and returns a minibatch in output). More
1117 precisely, it means that each element of the input or output list 1119 precisely, it means that each element of the input or output list
1118 should be iterable and indexable over the individual example values 1120 should be iterable and indexable over the individual example values
1168 all_output_names = self.output_dataset.output_names 1170 all_output_names = self.output_dataset.output_names
1169 if self.output_dataset.minibatch_mode: 1171 if self.output_dataset.minibatch_mode:
1170 function_outputs = self.output_dataset.function(*function_inputs) 1172 function_outputs = self.output_dataset.function(*function_inputs)
1171 else: 1173 else:
1172 input_examples = zip(*function_inputs) 1174 input_examples = zip(*function_inputs)
1173 output_examples = [self.output_dataset.function(input_example) 1175 output_examples = [self.output_dataset.function(*input_example)
1174 for input_example in input_examples] 1176 for input_example in input_examples]
1175 function_outputs = [self.output_dataset.valuesVStack(name,values) 1177 function_outputs = [self.output_dataset.valuesVStack(name,values)
1176 for name,values in zip(all_output_names, 1178 for name,values in zip(all_output_names,
1177 zip(*output_examples))] 1179 zip(*output_examples))]
1178 all_outputs = Example(all_output_names,function_outputs) 1180 all_outputs = Example(all_output_names,function_outputs)
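Note on the hunk above: in example-wise mode the function is now called as `function(*input_example)`, so it receives one argument per field rather than a single tuple. A minimal sketch of that calling convention follows. `add_fields` and the sample values are hypothetical, and a plain `list` stands in for `valuesVStack`, whose behaviour is not shown in this changeset.

```python
# Hypothetical example-wise function: one argument per field, one output per output field.
def add_fields(x, y):
    return [x + y, x * y]

# Field-major layout, as in function_inputs: two fields, three examples each.
function_inputs = [[1, 2, 3], [10, 20, 30]]

# zip(*...) regroups the values into one tuple per example.
input_examples = list(zip(*function_inputs))

# Revised call: each example tuple is unpacked into separate field arguments.
output_examples = [add_fields(*input_example) for input_example in input_examples]

# Transpose back to one sequence per output field
# (assumption: list() stands in for valuesVStack here).
function_outputs = [list(values) for values in zip(*output_examples)]

# The old call style passes the whole tuple as a single argument,
# which fails for a function written to take fields.
try:
    add_fields(input_examples[0])
    old_style_works = True
except TypeError:
    old_style_works = False
```

With the old call `function(input_example)`, a function written to take fields would need to unpack the tuple itself.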
1188 self.current=0 1190 self.current=0
1189 self.output_dataset=output_dataset 1191 self.output_dataset=output_dataset
1190 self.input_iterator=output_dataset.input_dataset.__iter__() 1192 self.input_iterator=output_dataset.input_dataset.__iter__()
1191 def __iter__(self): return self 1193 def __iter__(self): return self
1192 def next(self): 1194 def next(self):
1193 function_inputs = self.input_iterator.next()
1194 if self.output_dataset.minibatch_mode: 1195 if self.output_dataset.minibatch_mode:
1195 function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)] 1196 function_inputs = [[input] for input in self.input_iterator.next()]
1197 outputs = self.output_dataset.function(*function_inputs)
1198 assert all([hasattr(output,'__iter__') for output in outputs])
1199 function_outputs = [output[0] for output in outputs]
1196 else: 1200 else:
1197 function_outputs = self.output_dataset.function(function_inputs) 1201 function_inputs = self.input_iterator.next()
1202 function_outputs = self.output_dataset.function(*function_inputs)
1198 return Example(self.output_dataset.output_names,function_outputs) 1203 return Example(self.output_dataset.output_names,function_outputs)
1199 return ApplyFunctionSingleExampleIterator(self) 1204 return ApplyFunctionSingleExampleIterator(self)
1200 1205
1201 1206
1202 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 1207 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
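Note on the single-example iterator change above: in minibatch mode the revised `next()` wraps each field value in a one-element list, calls the function on that size-one minibatch, and then takes element 0 of each output. A minimal sketch of this wrap and unwrap step follows. `scale_batch` and the sample example are hypothetical and only illustrate the pattern.

```python
# Hypothetical minibatch-wise function: each argument holds the values of one field
# for every example in the minibatch, and each output is a sequence of the same length.
def scale_batch(x, y):
    return [[2 * v for v in x], [v + 1 for v in y]]

# One example from the input iterator: a single value per field.
example = (3, 7)

# Wrap each field value so the function sees a minibatch of size one.
function_inputs = [[value] for value in example]
outputs = scale_batch(*function_inputs)

# Same sanity check as in the changeset: every output must be iterable.
assert all(hasattr(output, '__iter__') for output in outputs)

# Unwrap: keep the first (and only) element of each output field.
function_outputs = [output[0] for output in outputs]
```

This way a minibatch-wise function can also be used when iterating one example at a time, without a separate code path in the function.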