diff learner.py @ 135:0d8e721cc63c

Fixed bugs in dataset to make test_mlp.py work
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 12 May 2008 14:30:21 -0400
parents 3f4e5c9bdc5e
children ceae4de18981
line wrap: on
line diff
--- a/learner.py	Fri May 09 17:38:57 2008 -0400
+++ b/learner.py	Mon May 12 14:30:21 2008 -0400
@@ -11,8 +11,9 @@
     algorithms.
 
     A L{Learner} can be seen as a learning algorithm, a function that when
-    applied to training data returns a learned function, an object that
-    can be applied to other data and return some output data.
+    applied to training data returns a learned function (which is an object that
+    can be applied to other data and return some output data).
+    
     """
     
     def __init__(self):
@@ -51,7 +52,7 @@
         return self.update(training_set,train_stats_collector)
 
     def use(self,input_dataset,output_fieldnames=None,
-            test_stats_collector=None,copy_inputs=True,
+            test_stats_collector=None,copy_inputs=False,
             put_stats_in_output_dataset=True,
             output_attributes=[]):
         """
@@ -85,11 +86,16 @@
         If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames())
         are also copied into the output dataset attributes.
         """
-        minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(),
+        input_fieldnames = input_dataset.fieldNames()
+        if not output_fieldnames:
+            output_fieldnames = self.defaultOutputFields(input_fieldnames)
+
+        minibatchwise_use_function = self.minibatchwiseUseFunction(input_fieldnames,
                                                                    output_fieldnames,
                                                                    test_stats_collector)
         virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                       minibatchwise_use_function,
+                                                      output_fieldnames,
                                                       True,DataSet.numpy_vstack,
                                                       DataSet.numpy_hstack)
         # actually force the computation
@@ -212,8 +218,6 @@
         Implement minibatchwiseUseFunction by exploiting Theano compilation
         and the expression graph defined by a sub-class constructor.
         """
-        if not output_fields:
-            output_fields = self.defaultOutputFields(input_fields)
         if stats_collector:
             stats_collector_inputs = stats_collector.input2UpdateAttributes()
             for attribute in stats_collector_inputs: