comparison learner.py @ 135:0d8e721cc63c

Fixed bugs in dataset to make test_mlp.py work
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 12 May 2008 14:30:21 -0400
parents 3f4e5c9bdc5e
children ceae4de18981
134:3f4e5c9bdc5e 135:0d8e721cc63c
9 Base class for learning algorithms, provides an interface 9 Base class for learning algorithms, provides an interface
10 that allows various algorithms to be applicable to generic learning 10 that allows various algorithms to be applicable to generic learning
11 algorithms. 11 algorithms.
12 12
13 A L{Learner} can be seen as a learning algorithm, a function that when 13 A L{Learner} can be seen as a learning algorithm, a function that when
14 applied to training data returns a learned function, an object that 14 applied to training data returns a learned function (which is an object that
15 can be applied to other data and return some output data. 15 can be applied to other data and return some output data).
16
16 """ 17 """
17 18
18 def __init__(self): 19 def __init__(self):
19 pass 20 pass
20 21
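The docstring in this hunk describes a Learner as a function that maps training data to a learned function. The sketch below illustrates that view. It uses a hypothetical `MeanLearner` that is not part of this library. It copies the `train` pattern from the next hunk, where `train` calls `forget()` and then `update()`. Returning `self` from `update` is an assumption made for the illustration.

```python
class MeanLearner:
    """Hypothetical toy learner: predicts the mean of the training targets."""

    def __init__(self):
        self.forget()

    def forget(self):
        # Reset the learned state; train() calls this before update().
        self.total = 0.0
        self.count = 0

    def update(self, training_set, train_stats_collector=None):
        # Absorb more training examples incrementally (stats collector ignored here).
        for target in training_set:
            self.total += target
            self.count += 1
        return self  # assumption: update returns the learner itself

    def train(self, training_set, train_stats_collector=None):
        # Same pattern as in the diff: forget, then update.
        self.forget()
        return self.update(training_set, train_stats_collector)

    def use(self, inputs):
        # The trained learner acts as the learned function applied to new data.
        # Raises ZeroDivisionError if called before any training data was seen.
        mean = self.total / self.count
        return [mean for _ in inputs]


learner = MeanLearner()
learner.train([1.0, 2.0, 3.0])
print(learner.use(["a", "b"]))  # [2.0, 2.0]
```

Calling `update` again after `train` keeps refining the same learned function. Calling `train` starts over from a blank state.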
49 """ 50 """
50 self.forget() 51 self.forget()
51 return self.update(training_set,train_stats_collector) 52 return self.update(training_set,train_stats_collector)
52 53
53 def use(self,input_dataset,output_fieldnames=None, 54 def use(self,input_dataset,output_fieldnames=None,
54 test_stats_collector=None,copy_inputs=True, 55 test_stats_collector=None,copy_inputs=False,
55 put_stats_in_output_dataset=True, 56 put_stats_in_output_dataset=True,
56 output_attributes=[]): 57 output_attributes=[]):
57 """ 58 """
58 Once a L{Learner} has been trained by one or more call to 'update', it can 59 Once a L{Learner} has been trained by one or more call to 'update', it can
59 be used with one or more calls to 'use'. The argument is an input L{DataSet} (possibly 60 be used with one or more calls to 'use'. The argument is an input L{DataSet} (possibly
83 If output_attributes is None then all of the attributes in self.AttributeNames() 84 If output_attributes is None then all of the attributes in self.AttributeNames()
84 are copied in the output dataset, but if it is [] (the default), then none are copied. 85 are copied in the output dataset, but if it is [] (the default), then none are copied.
85 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) 86 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames())
86 are also copied into the output dataset attributes. 87 are also copied into the output dataset attributes.
87 """ 88 """
88 minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(), 89 input_fieldnames = input_dataset.fieldNames()
90 if not output_fieldnames:
91 output_fieldnames = self.defaultOutputFields(input_fieldnames)
92
93 minibatchwise_use_function = self.minibatchwiseUseFunction(input_fieldnames,
89 output_fieldnames, 94 output_fieldnames,
90 test_stats_collector) 95 test_stats_collector)
91 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, 96 virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
92 minibatchwise_use_function, 97 minibatchwise_use_function,
98 output_fieldnames,
93 True,DataSet.numpy_vstack, 99 True,DataSet.numpy_vstack,
94 DataSet.numpy_hstack) 100 DataSet.numpy_hstack)
95 # actually force the computation 101 # actually force the computation
96 output_dataset = CachedDataSet(virtual_output_dataset,True) 102 output_dataset = CachedDataSet(virtual_output_dataset,True)
97 if copy_inputs: 103 if copy_inputs:
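In the new revision, `use()` fills in the default output field names before it builds the minibatch function. It then passes the same resolved list to both `minibatchwiseUseFunction` and `ApplyFunctionDataSet`. The sketch below isolates that resolution step. `HypotheticalLearner` and `resolve_output_fieldnames` are illustrative names, not part of the library. The behaviour assumed for `defaultOutputFields`, which returns a single `"prediction"` field, is also invented for the example.

```python
class HypotheticalLearner:
    """Stand-in for a Learner subclass; defaultOutputFields' behaviour is assumed."""

    def defaultOutputFields(self, input_fieldnames):
        # Assumption for illustration: a single prediction field.
        return ["prediction"]


def resolve_output_fieldnames(learner, input_dataset_fieldnames, output_fieldnames=None):
    """Mirror of the new logic at the top of use(): fill in defaults once."""
    input_fieldnames = list(input_dataset_fieldnames)
    if not output_fieldnames:  # both None and [] fall back to the defaults
        output_fieldnames = learner.defaultOutputFields(input_fieldnames)
    return input_fieldnames, output_fieldnames


learner = HypotheticalLearner()
print(resolve_output_fieldnames(learner, ["x", "y"])[1])             # ['prediction']
print(resolve_output_fieldnames(learner, ["x", "y"], ["score"])[1])  # ['score']
```

The defaults are resolved once, in the caller. As a result, the function that computes each minibatch and the dataset that wraps it work from the same field-name list.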
210 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): 216 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector):
211 """ 217 """
212 Implement minibatchwiseUseFunction by exploiting Theano compilation 218 Implement minibatchwiseUseFunction by exploiting Theano compilation
213 and the expression graph defined by a sub-class constructor. 219 and the expression graph defined by a sub-class constructor.
214 """ 220 """
215 if not output_fields:
216 output_fields = self.defaultOutputFields(input_fields)
217 if stats_collector: 221 if stats_collector:
218 stats_collector_inputs = stats_collector.input2UpdateAttributes() 222 stats_collector_inputs = stats_collector.input2UpdateAttributes()
219 for attribute in stats_collector_inputs: 223 for attribute in stats_collector_inputs:
220 if attribute not in input_fields: 224 if attribute not in input_fields:
221 output_fields.append(attribute) 225 output_fields.append(attribute)
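This loop requests every stats-collector input that is not already an input field as an extra output field. It does this with `list.append`, which modifies the list in place. In the new revision, `use()` passes its own `output_fieldnames` list here, so any appended names are still present when that list is later given to `ApplyFunctionDataSet`. The diff does not say whether this sharing is intentional. The sketch below reproduces the loop on plain lists. The function name and the field names are hypothetical.

```python
def extend_with_stats_inputs(input_fields, output_fields, stats_collector_inputs):
    """Same loop as in minibatchwiseUseFunction (hypothetical standalone name)."""
    for attribute in stats_collector_inputs:
        if attribute not in input_fields:
            output_fields.append(attribute)  # mutates the caller's list in place
    return output_fields


input_fieldnames = ["input", "target"]
output_fieldnames = ["output"]
# Hypothetical collector needing "target" (already an input) and "squared_error".
extend_with_stats_inputs(input_fieldnames, output_fieldnames, ["target", "squared_error"])
print(output_fieldnames)  # ['output', 'squared_error']
```

The membership test checks only `input_fields`. A collector input that already appears in `output_fields` is therefore appended a second time.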