Mercurial > pylearn
comparison dataset.py @ 296:f5d33f9c0b9c
ApplyFunctionDataSet passing
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Fri, 06 Jun 2008 17:50:29 -0400 |
parents | 4bfdda107a17 |
children | 923de30457f0 |
comparison
equal
deleted
inserted
replaced
294:f7924e13e426 | 296:f5d33f9c0b9c |
---|---|
1202 class ApplyFunctionDataSet(DataSet): | 1202 class ApplyFunctionDataSet(DataSet): |
1203 """ | 1203 """ |
1204 A L{DataSet} that contains as fields the results of applying a | 1204 A L{DataSet} that contains as fields the results of applying a |
1205 given function example-wise or minibatch-wise to all the fields of | 1205 given function example-wise or minibatch-wise to all the fields of |
1206 an input dataset. The output of the function should be an iterable | 1206 an input dataset. The output of the function should be an iterable |
1207 (e.g. a list or a Example) over the resulting values. | 1207 (e.g. a list or a LookupList) over the resulting values. |
1208 | 1208 |
1209 The function take as input the fields of the dataset, not the examples. | 1209 The function take as input the fields of the dataset, not the examples. |
1210 | 1210 |
1211 In minibatch mode, the function is expected to work on minibatches | 1211 In minibatch mode, the function is expected to work on minibatches |
1212 (takes a minibatch in input and returns a minibatch in output). More | 1212 (takes a minibatch in input and returns a minibatch in output). More |
1219 The function is applied each time an example or a minibatch is accessed. | 1219 The function is applied each time an example or a minibatch is accessed. |
1220 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. | 1220 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. |
1221 | 1221 |
1222 If the values_{h,v}stack functions are not provided, then | 1222 If the values_{h,v}stack functions are not provided, then |
1223 the input_dataset.values{H,V}Stack functions are used by default. | 1223 the input_dataset.values{H,V}Stack functions are used by default. |
1224 """ | 1224 |
1225 """ | |
1226 | |
1225 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, | 1227 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, |
1226 values_hstack=None,values_vstack=None, | 1228 values_hstack=None,values_vstack=None, |
1227 description=None,fieldtypes=None): | 1229 description=None,fieldtypes=None): |
1228 """ | 1230 """ |
1229 Constructor takes an input dataset that has as many fields as the function | 1231 Constructor takes an input dataset that has as many fields as the function |
1251 | 1253 |
1252 def fieldNames(self): | 1254 def fieldNames(self): |
1253 return self.output_names | 1255 return self.output_names |
1254 | 1256 |
1255 def minibatches_nowrap(self, fieldnames, *args, **kwargs): | 1257 def minibatches_nowrap(self, fieldnames, *args, **kwargs): |
1256 for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): | 1258 all_input_fieldNames = self.input_dataset.fieldNames() |
1257 | 1259 mbnw = self.input_dataset.minibatches_nowrap |
1258 #function_inputs = self.input_iterator.next() | 1260 |
1261 for input_fields in mbnw(all_input_fieldNames, *args, **kwargs): | |
1259 if self.minibatch_mode: | 1262 if self.minibatch_mode: |
1260 function_outputs = self.function(*input_fields) | 1263 all_output_fields = self.function(*input_fields) |
1261 else: | 1264 else: |
1262 input_examples = zip(*input_fields) | 1265 input_examples = zip(*input_fields) #makes so that [i] means example i |
1263 output_examples = [self.function(*input_example) | 1266 output_examples = [self.function(*input_example) |
1264 for input_example in input_examples] | 1267 for input_example in input_examples] |
1265 function_outputs = [self.valuesVStack(name,values) | 1268 all_output_fields = zip(*output_examples) |
1266 for name,values in zip(self.output_names, | 1269 |
1267 zip(*output_examples))] | 1270 all_outputs = Example(self.output_names, all_output_fields) |
1268 all_outputs = Example(self.output_names, function_outputs) | 1271 #print 'input_fields', input_fields |
1269 print 'input_fields', input_fields | 1272 #print 'all_outputs', all_outputs |
1270 print 'all_outputs', all_outputs | |
1271 if fieldnames==self.output_names: | 1273 if fieldnames==self.output_names: |
1272 rval = all_outputs | 1274 rval = all_outputs |
1273 else: | 1275 else: |
1274 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames]) | 1276 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames]) |
1275 print 'rval', rval | 1277 #print 'rval', rval |
1276 print '--------' | 1278 #print '--------' |
1277 yield rval | 1279 yield rval |
1278 | 1280 |
1279 def untested__iter__(self): # only implemented for increased efficiency | 1281 def untested__iter__(self): # only implemented for increased efficiency |
1280 class ApplyFunctionSingleExampleIterator(object): | 1282 class ApplyFunctionSingleExampleIterator(object): |
1281 def __init__(self,output_dataset): | 1283 def __init__(self,output_dataset): |
1293 function_inputs = self.input_iterator.next() | 1295 function_inputs = self.input_iterator.next() |
1294 function_outputs = self.output_dataset.function(*function_inputs) | 1296 function_outputs = self.output_dataset.function(*function_inputs) |
1295 return Example(self.output_dataset.output_names,function_outputs) | 1297 return Example(self.output_dataset.output_names,function_outputs) |
1296 return ApplyFunctionSingleExampleIterator(self) | 1298 return ApplyFunctionSingleExampleIterator(self) |
1297 | 1299 |
1298 | |
1299 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): | 1300 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): |
1300 """ | 1301 """ |
1301 Wraps an arbitrary L{DataSet} into one for supervised learning tasks | 1302 Wraps an arbitrary L{DataSet} into one for supervised learning tasks |
1302 by forcing the user to define a set of fields as the 'input' field | 1303 by forcing the user to define a set of fields as the 'input' field |
1303 and a set of fields as the 'target' field. Optionally, a single | 1304 and a set of fields as the 'target' field. Optionally, a single |