Mercurial > pylearn
comparison dataset.py @ 203:80731832c62b
Automated merge with ssh://p-omega1@lgcm.iro.umontreal.ca/tlearn
author | Frederic Bastien <bastienf@iro.umontreal.ca> |
---|---|
date | Thu, 15 May 2008 15:21:00 -0400 |
parents | cb6b945acf5a c9c966cab763 |
children | bd728c83faff 6f55e301c687 |
comparison
equal
deleted
inserted
replaced
193:cb6b945acf5a | 203:80731832c62b |
---|---|
1109 """ | 1109 """ |
1110 A L{DataSet} that contains as fields the results of applying a | 1110 A L{DataSet} that contains as fields the results of applying a |
1111 given function example-wise or minibatch-wise to all the fields of | 1111 given function example-wise or minibatch-wise to all the fields of |
1112 an input dataset. The output of the function should be an iterable | 1112 an input dataset. The output of the function should be an iterable |
1113 (e.g. a list or a LookupList) over the resulting values. | 1113 (e.g. a list or a LookupList) over the resulting values. |
1114 | |
1115 The function takes as input the fields of the dataset, not the examples. | |
1114 | 1116 |
1115 In minibatch mode, the function is expected to work on minibatches | 1117 In minibatch mode, the function is expected to work on minibatches |
1116 (takes a minibatch in input and returns a minibatch in output). More | 1118 (takes a minibatch in input and returns a minibatch in output). More |
1117 precisely, it means that each element of the input or output list | 1119 precisely, it means that each element of the input or output list |
1118 should be iterable and indexable over the individual example values | 1120 should be iterable and indexable over the individual example values |
1168 all_output_names = self.output_dataset.output_names | 1170 all_output_names = self.output_dataset.output_names |
1169 if self.output_dataset.minibatch_mode: | 1171 if self.output_dataset.minibatch_mode: |
1170 function_outputs = self.output_dataset.function(*function_inputs) | 1172 function_outputs = self.output_dataset.function(*function_inputs) |
1171 else: | 1173 else: |
1172 input_examples = zip(*function_inputs) | 1174 input_examples = zip(*function_inputs) |
1173 output_examples = [self.output_dataset.function(input_example) | 1175 output_examples = [self.output_dataset.function(*input_example) |
1174 for input_example in input_examples] | 1176 for input_example in input_examples] |
1175 function_outputs = [self.output_dataset.valuesVStack(name,values) | 1177 function_outputs = [self.output_dataset.valuesVStack(name,values) |
1176 for name,values in zip(all_output_names, | 1178 for name,values in zip(all_output_names, |
1177 zip(*output_examples))] | 1179 zip(*output_examples))] |
1178 all_outputs = Example(all_output_names,function_outputs) | 1180 all_outputs = Example(all_output_names,function_outputs) |
1188 self.current=0 | 1190 self.current=0 |
1189 self.output_dataset=output_dataset | 1191 self.output_dataset=output_dataset |
1190 self.input_iterator=output_dataset.input_dataset.__iter__() | 1192 self.input_iterator=output_dataset.input_dataset.__iter__() |
1191 def __iter__(self): return self | 1193 def __iter__(self): return self |
1192 def next(self): | 1194 def next(self): |
1193 function_inputs = self.input_iterator.next() | |
1194 if self.output_dataset.minibatch_mode: | 1195 if self.output_dataset.minibatch_mode: |
1195 function_outputs = [output[0] for output in self.output_dataset.function(function_inputs)] | 1196 function_inputs = [[input] for input in self.input_iterator.next()] |
1197 outputs = self.output_dataset.function(*function_inputs) | |
1198 assert all([hasattr(output,'__iter__') for output in outputs]) | |
1199 function_outputs = [output[0] for output in outputs] | |
1196 else: | 1200 else: |
1197 function_outputs = self.output_dataset.function(function_inputs) | 1201 function_inputs = self.input_iterator.next() |
1202 function_outputs = self.output_dataset.function(*function_inputs) | |
1198 return Example(self.output_dataset.output_names,function_outputs) | 1203 return Example(self.output_dataset.output_names,function_outputs) |
1199 return ApplyFunctionSingleExampleIterator(self) | 1204 return ApplyFunctionSingleExampleIterator(self) |
1200 | 1205 |
1201 | 1206 |
1202 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 1207 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |