Mercurial > pylearn
diff learner.py @ 135:0d8e721cc63c
Fixed bugs in dataset to make test_mlp.py work
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 14:30:21 -0400 |
parents | 3f4e5c9bdc5e |
children | ceae4de18981 |
line wrap: on
line diff
--- a/learner.py Fri May 09 17:38:57 2008 -0400 +++ b/learner.py Mon May 12 14:30:21 2008 -0400 @@ -11,8 +11,9 @@ algorithms. A L{Learner} can be seen as a learning algorithm, a function that when - applied to training data returns a learned function, an object that - can be applied to other data and return some output data. + applied to training data returns a learned function (which is an object that + can be applied to other data and return some output data). + """ def __init__(self): @@ -51,7 +52,7 @@ return self.update(training_set,train_stats_collector) def use(self,input_dataset,output_fieldnames=None, - test_stats_collector=None,copy_inputs=True, + test_stats_collector=None,copy_inputs=False, put_stats_in_output_dataset=True, output_attributes=[]): """ @@ -85,11 +86,16 @@ If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) are also copied into the output dataset attributes. """ - minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(), + input_fieldnames = input_dataset.fieldNames() + if not output_fieldnames: + output_fieldnames = self.defaultOutputFields(input_fieldnames) + + minibatchwise_use_function = self.minibatchwiseUseFunction(input_fieldnames, output_fieldnames, test_stats_collector) virtual_output_dataset = ApplyFunctionDataSet(input_dataset, minibatchwise_use_function, + output_fieldnames, True,DataSet.numpy_vstack, DataSet.numpy_hstack) # actually force the computation @@ -212,8 +218,6 @@ Implement minibatchwiseUseFunction by exploiting Theano compilation and the expression graph defined by a sub-class constructor. """ - if not output_fields: - output_fields = self.defaultOutputFields(input_fields) if stats_collector: stats_collector_inputs = stats_collector.input2UpdateAttributes() for attribute in stats_collector_inputs: