Mercurial > pylearn
comparison learner.py @ 135:0d8e721cc63c
Fixed bugs in dataset to make test_mlp.py work
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Mon, 12 May 2008 14:30:21 -0400 |
parents | 3f4e5c9bdc5e |
children | ceae4de18981 |
comparison
equal
deleted
inserted
replaced
134:3f4e5c9bdc5e | 135:0d8e721cc63c |
---|---|
9 Base class for learning algorithms, provides an interface | 9 Base class for learning algorithms, provides an interface |
10 that allows various algorithms to be applicable to generic learning | 10 that allows various algorithms to be applicable to generic learning |
11 algorithms. | 11 algorithms. |
12 | 12 |
13 A L{Learner} can be seen as a learning algorithm, a function that when | 13 A L{Learner} can be seen as a learning algorithm, a function that when |
14 applied to training data returns a learned function, an object that | 14 applied to training data returns a learned function (which is an object that |
15 can be applied to other data and return some output data. | 15 can be applied to other data and return some output data). |
16 | |
16 """ | 17 """ |
17 | 18 |
18 def __init__(self): | 19 def __init__(self): |
19 pass | 20 pass |
20 | 21 |
49 """ | 50 """ |
50 self.forget() | 51 self.forget() |
51 return self.update(training_set,train_stats_collector) | 52 return self.update(training_set,train_stats_collector) |
52 | 53 |
53 def use(self,input_dataset,output_fieldnames=None, | 54 def use(self,input_dataset,output_fieldnames=None, |
54 test_stats_collector=None,copy_inputs=True, | 55 test_stats_collector=None,copy_inputs=False, |
55 put_stats_in_output_dataset=True, | 56 put_stats_in_output_dataset=True, |
56 output_attributes=[]): | 57 output_attributes=[]): |
57 """ | 58 """ |
58 Once a L{Learner} has been trained by one or more call to 'update', it can | 59 Once a L{Learner} has been trained by one or more call to 'update', it can |
59 be used with one or more calls to 'use'. The argument is an input L{DataSet} (possibly | 60 be used with one or more calls to 'use'. The argument is an input L{DataSet} (possibly |
83 If output_attributes is None then all of the attributes in self.AttributeNames() | 84 If output_attributes is None then all of the attributes in self.AttributeNames() |
84 are copied in the output dataset, but if it is [] (the default), then none are copied. | 85 are copied in the output dataset, but if it is [] (the default), then none are copied. |
85 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) | 86 If a test_stats_collector is provided, then its attributes (test_stats_collector.AttributeNames()) |
86 are also copied into the output dataset attributes. | 87 are also copied into the output dataset attributes. |
87 """ | 88 """ |
88 minibatchwise_use_function = self.minibatchwiseUseFunction(input_dataset.fieldNames(), | 89 input_fieldnames = input_dataset.fieldNames() |
90 if not output_fieldnames: | |
91 output_fieldnames = self.defaultOutputFields(input_fieldnames) | |
92 | |
93 minibatchwise_use_function = self.minibatchwiseUseFunction(input_fieldnames, | |
89 output_fieldnames, | 94 output_fieldnames, |
90 test_stats_collector) | 95 test_stats_collector) |
91 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, | 96 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, |
92 minibatchwise_use_function, | 97 minibatchwise_use_function, |
98 output_fieldnames, | |
93 True,DataSet.numpy_vstack, | 99 True,DataSet.numpy_vstack, |
94 DataSet.numpy_hstack) | 100 DataSet.numpy_hstack) |
95 # actually force the computation | 101 # actually force the computation |
96 output_dataset = CachedDataSet(virtual_output_dataset,True) | 102 output_dataset = CachedDataSet(virtual_output_dataset,True) |
97 if copy_inputs: | 103 if copy_inputs: |
210 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): | 216 def minibatchwiseUseFunction(self, input_fields, output_fields, stats_collector): |
211 """ | 217 """ |
212 Implement minibatchwiseUseFunction by exploiting Theano compilation | 218 Implement minibatchwiseUseFunction by exploiting Theano compilation |
213 and the expression graph defined by a sub-class constructor. | 219 and the expression graph defined by a sub-class constructor. |
214 """ | 220 """ |
215 if not output_fields: | |
216 output_fields = self.defaultOutputFields(input_fields) | |
217 if stats_collector: | 221 if stats_collector: |
218 stats_collector_inputs = stats_collector.input2UpdateAttributes() | 222 stats_collector_inputs = stats_collector.input2UpdateAttributes() |
219 for attribute in stats_collector_inputs: | 223 for attribute in stats_collector_inputs: |
220 if attribute not in input_fields: | 224 if attribute not in input_fields: |
221 output_fields.append(attribute) | 225 output_fields.append(attribute) |