comparison dataset.py @ 296:f5d33f9c0b9c

ApplyFunctionDataSet passing
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 06 Jun 2008 17:50:29 -0400
parents 4bfdda107a17
children 923de30457f0
comparison
equal deleted inserted replaced
294:f7924e13e426 296:f5d33f9c0b9c
1202 class ApplyFunctionDataSet(DataSet): 1202 class ApplyFunctionDataSet(DataSet):
1203 """ 1203 """
1204 A L{DataSet} that contains as fields the results of applying a 1204 A L{DataSet} that contains as fields the results of applying a
1205 given function example-wise or minibatch-wise to all the fields of 1205 given function example-wise or minibatch-wise to all the fields of
1206 an input dataset. The output of the function should be an iterable 1206 an input dataset. The output of the function should be an iterable
1207 (e.g. a list or a Example) over the resulting values. 1207 (e.g. a list or a LookupList) over the resulting values.
1208 1208
1209 The function takes as input the fields of the dataset, not the examples. 1209 The function takes as input the fields of the dataset, not the examples.
1210 1210
1211 In minibatch mode, the function is expected to work on minibatches 1211 In minibatch mode, the function is expected to work on minibatches
1212 (takes a minibatch in input and returns a minibatch in output). More 1212 (takes a minibatch in input and returns a minibatch in output). More
1219 The function is applied each time an example or a minibatch is accessed. 1219 The function is applied each time an example or a minibatch is accessed.
1220 To avoid re-doing computation, wrap this dataset inside a CachedDataSet. 1220 To avoid re-doing computation, wrap this dataset inside a CachedDataSet.
1221 1221
1222 If the values_{h,v}stack functions are not provided, then 1222 If the values_{h,v}stack functions are not provided, then
1223 the input_dataset.values{H,V}Stack functions are used by default. 1223 the input_dataset.values{H,V}Stack functions are used by default.
1224 """ 1224
1225 """
1226
1225 def __init__(self,input_dataset,function,output_names,minibatch_mode=True, 1227 def __init__(self,input_dataset,function,output_names,minibatch_mode=True,
1226 values_hstack=None,values_vstack=None, 1228 values_hstack=None,values_vstack=None,
1227 description=None,fieldtypes=None): 1229 description=None,fieldtypes=None):
1228 """ 1230 """
1229 Constructor takes an input dataset that has as many fields as the function 1231 Constructor takes an input dataset that has as many fields as the function
1251 1253
1252 def fieldNames(self): 1254 def fieldNames(self):
1253 return self.output_names 1255 return self.output_names
1254 1256
1255 def minibatches_nowrap(self, fieldnames, *args, **kwargs): 1257 def minibatches_nowrap(self, fieldnames, *args, **kwargs):
1256 for input_fields in self.input_dataset.minibatches_nowrap(fieldnames, *args, **kwargs): 1258 all_input_fieldNames = self.input_dataset.fieldNames()
1257 1259 mbnw = self.input_dataset.minibatches_nowrap
1258 #function_inputs = self.input_iterator.next() 1260
1261 for input_fields in mbnw(all_input_fieldNames, *args, **kwargs):
1259 if self.minibatch_mode: 1262 if self.minibatch_mode:
1260 function_outputs = self.function(*input_fields) 1263 all_output_fields = self.function(*input_fields)
1261 else: 1264 else:
1262 input_examples = zip(*input_fields) 1265 input_examples = zip(*input_fields) #makes so that [i] means example i
1263 output_examples = [self.function(*input_example) 1266 output_examples = [self.function(*input_example)
1264 for input_example in input_examples] 1267 for input_example in input_examples]
1265 function_outputs = [self.valuesVStack(name,values) 1268 all_output_fields = zip(*output_examples)
1266 for name,values in zip(self.output_names, 1269
1267 zip(*output_examples))] 1270 all_outputs = Example(self.output_names, all_output_fields)
1268 all_outputs = Example(self.output_names, function_outputs) 1271 #print 'input_fields', input_fields
1269 print 'input_fields', input_fields 1272 #print 'all_outputs', all_outputs
1270 print 'all_outputs', all_outputs
1271 if fieldnames==self.output_names: 1273 if fieldnames==self.output_names:
1272 rval = all_outputs 1274 rval = all_outputs
1273 else: 1275 else:
1274 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames]) 1276 rval = Example(fieldnames,[all_outputs[name] for name in fieldnames])
1275 print 'rval', rval 1277 #print 'rval', rval
1276 print '--------' 1278 #print '--------'
1277 yield rval 1279 yield rval
1278 1280
1279 def untested__iter__(self): # only implemented for increased efficiency 1281 def untested__iter__(self): # only implemented for increased efficiency
1280 class ApplyFunctionSingleExampleIterator(object): 1282 class ApplyFunctionSingleExampleIterator(object):
1281 def __init__(self,output_dataset): 1283 def __init__(self,output_dataset):
1293 function_inputs = self.input_iterator.next() 1295 function_inputs = self.input_iterator.next()
1294 function_outputs = self.output_dataset.function(*function_inputs) 1296 function_outputs = self.output_dataset.function(*function_inputs)
1295 return Example(self.output_dataset.output_names,function_outputs) 1297 return Example(self.output_dataset.output_names,function_outputs)
1296 return ApplyFunctionSingleExampleIterator(self) 1298 return ApplyFunctionSingleExampleIterator(self)
1297 1299
1298
1299 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None): 1300 def supervised_learning_dataset(src_dataset,input_fields,target_fields,weight_field=None):
1300 """ 1301 """
1301 Wraps an arbitrary L{DataSet} into one for supervised learning tasks 1302 Wraps an arbitrary L{DataSet} into one for supervised learning tasks
1302 by forcing the user to define a set of fields as the 'input' field 1303 by forcing the user to define a set of fields as the 'input' field
1303 and a set of fields as the 'target' field. Optionally, a single 1304 and a set of fields as the 'target' field. Optionally, a single