pylearn: comparison learner.py @ 107:c4916445e025
Comments from Pascal V.
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Tue, 06 May 2008 19:54:43 -0400 |
parents | c4726e19b8ec |
children | d97f6fe6bdf9 |
97:05cfe011ca20 | 107:c4916445e025 |
---|---|
59 | 59 |
60 def attributeNames(self): | 60 def attributeNames(self): |
61 """ | 61 """ |
62 A Learner may have attributes that it wishes to export to other objects. To automate | 62 A Learner may have attributes that it wishes to export to other objects. To automate |
63 such export, sub-classes should define here the names (list of strings) of these attributes. | 63 such export, sub-classes should define here the names (list of strings) of these attributes. |
| 64 |
| 65 @todo By default, attributeNames looks for all dictionary entries whose name does not start with _. |
64 """ | 66 """ |
65 return [] | 67 return [] |
66 | 68 |
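For illustration only (this sketch is not part of the changeset; the sub-class and attribute names are invented): a Learner sub-class exports attributes simply by listing their names in attributeNames.

    class LinearRegressor(Learner):
        # hypothetical sub-class used only to illustrate attribute export
        def __init__(self):
            Learner.__init__(self)
            self.W = None   # weight matrix, filled in by update()
            self.b = None   # bias vector, filled in by update()

        def attributeNames(self):
            # names (list of strings) of the attributes exported to other objects
            return ['W', 'b']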
67 class TLearner(Learner): | 69 class TLearner(Learner): |
68 """ | 70 """ |
69 TLearner is a virtual class of Learners that attempts to factor out of the definition | 71 TLearner is a virtual class of Learners that attempts to factor out of the definition |
70 of a learner the steps that are common to many implementations of learning algorithms, | 72 of a learner the steps that are common to many implementations of learning algorithms, |
71 so as to leave only "the equations" to define in particular sub-classes, using Theano. | 73 so as to leave only 'the equations' to define in particular sub-classes, using Theano. |
72 | 74 |
73 In the default implementations of use and update, it is assumed that the 'use' and 'update' methods | 75 In the default implementations of use and update, it is assumed that the 'use' and 'update' methods |
74 visit examples in the input dataset sequentially. In the 'use' method only one pass through the dataset is done, | 76 visit examples in the input dataset sequentially. In the 'use' method only one pass through the dataset is done, |
75 whereas the sub-learner may wish to iterate over the examples multiple times. Subclasses where this | 77 whereas the sub-learner may wish to iterate over the examples multiple times. Subclasses where this |
76 basic model is not appropriate can simply redefine update or use. | 78 basic model is not appropriate can simply redefine update or use. |
83 The sub-class constructor defines the relations between | 85 The sub-class constructor defines the relations between |
84 the Theano variables that may be used by 'use' and 'update' | 86 the Theano variables that may be used by 'use' and 'update' |
85 or by a stats collector. | 87 or by a stats collector. |
86 - defaultOutputFields(input_fields): return a list of default dataset output fields when | 88 - defaultOutputFields(input_fields): return a list of default dataset output fields when |
87 None are provided by the caller of use. | 89 None are provided by the caller of use. |
88 - update_start(), update_end(), update_minibatch(minibatch): functions | |
89 executed at the beginning, the end, and in the middle | |
90 (for each minibatch) of the update method. This model only | |
91 works for 'online' or one-shot learning that requires | |
92 going only once through the training data. For more complicated | |
93 models, more specialized subclasses of TLearner should be used | |
94 or a learning-algorithm specific update method should be defined. | |
95 | |
96 The following naming convention is assumed and important. | 90 The following naming convention is assumed and important. |
97 Attributes whose names are listed in attributeNames() can be of any type, | 91 Attributes whose names are listed in attributeNames() can be of any type, |
98 but those that can be referenced as input/output dataset fields or as | 92 but those that can be referenced as input/output dataset fields or as |
99 output attributes in 'use' or as input attributes in the stats collector | 93 output attributes in 'use' or as input attributes in the stats collector |
100 should be associated with a Theano Result variable. If the exported attribute | 94 should be associated with a Theano Result variable. If the exported attribute |
101 name is <name>, the corresponding Result name (an internal attribute of | 95 name is <name>, the corresponding Result name (an internal attribute of |
102 the TLearner, created in the sub-class constructor) should be _<name>. | 96 the TLearner, created in the sub-class constructor) should be _<name>. |
103 Typically <name> will be a numpy ndarray and _<name> will be the corresponding | 97 Typically <name> will be a numpy ndarray and _<name> will be the corresponding |
104 Theano Tensor (for symbolic manipulation). | 98 Theano Tensor (for symbolic manipulation). |
| 99 |
| 100 @todo push into Learner all the plumbing that can be moved there without |
| 101 depending on Theano |
105 """ | 102 """ |
106 | 103 |
107 def __init__(self): | 104 def __init__(self): |
108 Learner.__init__(self) | 105 Learner.__init__(self) |
109 | 106 |
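A minimal sketch of the <name> / _<name> naming convention described above, assuming Theano's tensor module is imported as t (the sub-class, its constructor argument and the attribute b are invented for this example):

    import numpy
    from theano import tensor as t

    class MyTLearner(TLearner):
        # hypothetical sub-class illustrating the naming convention only
        def __init__(self, n_inputs):
            TLearner.__init__(self)
            self._b = t.dvector('b')          # symbolic Result used by 'use'/'update'
            self.b = numpy.zeros(n_inputs)    # exported numpy ndarray of the same name

        def attributeNames(self):
            return ['b']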
146 if return_copy: | 143 if return_copy: |
147 return [copy.deepcopy(self.__getattr__(name).data) for name in names] | 144 return [copy.deepcopy(self.__getattr__(name).data) for name in names] |
148 else: | 145 else: |
149 return [self.__getattr__(name).data for name in names] | 146 return [self.__getattr__(name).data for name in names] |
150 | 147 |
151 def use(self,input_dataset,output_fieldnames=None,output_attributes=None, | 148 def use(self,input_dataset,output_fieldnames=None,output_attributes=[], |
152 test_stats_collector=None,copy_inputs=True): | 149 test_stats_collector=None,copy_inputs=True): |
153 """ | 150 """ |
154 The learner tries to compute in the output dataset the output fields specified | 151 The learner tries to compute in the output dataset the output fields specified |
| 152 |
| 153 @todo check if some of the learner attributes are actually SPECIFIED |
| 154 as attributes of the input_dataset, and if so use their values instead |
| 155 of the ones in the learner. |
155 """ | 156 """ |
156 minibatchwise_use_function = self._minibatchwise_use_functions(input_dataset.fieldNames(), | 157 minibatchwise_use_function = self._minibatchwise_use_functions(input_dataset.fieldNames(), |
157 output_fieldnames, | 158 output_fieldnames, |
158 test_stats_collector) | 159 test_stats_collector) |
159 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, | 160 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, |
163 # actually force the computation | 164 # actually force the computation |
164 output_dataset = CachedDataSet(virtual_output_dataset,True) | 165 output_dataset = CachedDataSet(virtual_output_dataset,True) |
165 if copy_inputs: | 166 if copy_inputs: |
166 output_dataset = input_dataset | output_dataset | 167 output_dataset = input_dataset | output_dataset |
167 # copy the wanted attributes in the dataset | 168 # copy the wanted attributes in the dataset |
| 169 if output_attributes is None: |
| 170 output_attributes = self.attributeNames() |
168 if output_attributes: | 171 if output_attributes: |
169 assert set(output_attributes) <= set(self.attributeNames()) | 172 assert set(output_attributes) <= set(self.attributeNames()) |
170 output_dataset.setAttributes(output_attributes, | 173 output_dataset.setAttributes(output_attributes, |
171 self._names2attributes(output_attributes,return_copy=True)) | 174 self._names2attributes(output_attributes,return_copy=True)) |
172 if test_stats_collector: | 175 if test_stats_collector: |
173 test_stats_collector.update(output_dataset) | 176 test_stats_collector.update(output_dataset) |
174 output_dataset.setAttributes(test_stats_collector.attributeNames(), | 177 output_dataset.setAttributes(test_stats_collector.attributeNames(), |
175 test_stats_collector.attributes()) | 178 test_stats_collector.attributes()) |
176 return output_dataset | 179 return output_dataset |
177 | 180 |
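For context, a hedged sketch of how use is meant to be called (learner, test_set and the field name 'output' are placeholders, not defined in this changeset):

    # learner: a trained TLearner sub-class instance; test_set: a DataSet whose
    # fields match what the learner expects.  Both are placeholders.
    output_ds = learner.use(test_set,
                            output_fieldnames=['output'],  # placeholder field name
                            output_attributes=['b'],       # must be a subset of learner.attributeNames()
                            copy_inputs=True)              # keep the input fields in the returned dataset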
| 181 |
| 182 class OneShotTLearner(TLearner): |
| 183 """ |
| 184 This adds to TLearner a |
| 185 - update_start(), update_end(), update_minibatch(minibatch), end_epoch(): |
| 186 functions executed at the beginning, the end, in the middle |
| 187 (for each minibatch) of the update method, and at the end |
| 188 of each epoch. This model only |
| 189 works for 'online' or one-shot learning that requires |
| 190 going only once through the training data. For more complicated |
| 191 models, more specialized subclasses of TLearner should be used |
| 192 or a learning-algorithm specific update method should be defined. |
| 193 """ |
| 194 |
| 195 def __init__(self): |
| 196 TLearner.__init__(self) |
| 197 |
178 def update_start(self): pass | 198 def update_start(self): pass |
179 def update_end(self): pass | 199 def update_end(self): pass |
180 def update_minibatch(self,minibatch): | 200 def update_minibatch(self,minibatch): |
181 raise AbstractFunction() | 201 raise AbstractFunction() |
182 | 202 |
183 def update(self,training_set,train_stats_collector=None): | 203 def update(self,training_set,train_stats_collector=None): |
184 | 204 """ |
| 205 @todo check if some of the learner attributes are actually SPECIFIED |
| 206 as attributes of the training_set. |
| 207 """ |
185 self.update_start() | 208 self.update_start() |
186 for minibatch in training_set.minibatches(self.training_set_input_fields, | 209 stop=False |
187 minibatch_size=self.minibatch_size): | 210 while not stop: |
188 self.update_minibatch(minibatch) | |
189 if train_stats_collector: | 211 if train_stats_collector: |
190 minibatch_set = minibatch.examples() | 212 train_stats_collector.forget() # restart stats collection at the beginning of each epoch |
191 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) | 213 for minibatch in training_set.minibatches(self.training_set_input_fields, |
192 train_stats_collector.update(minibatch_set) | 214 minibatch_size=self.minibatch_size): |
| 215 self.update_minibatch(minibatch) |
| 216 if train_stats_collector: |
| 217 minibatch_set = minibatch.examples() |
| 218 minibatch_set.setAttributes(self.attributeNames(),self.attributes()) |
| 219 train_stats_collector.update(minibatch_set) |
| 220 stop = self.end_epoch() |
193 self.update_end() | 221 self.update_end() |
194 return self.use | 222 return self.use |
195 | 223 |
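Finally, a hedged sketch of a OneShotTLearner sub-class wiring the hooks used by the new epoch loop above (the class, its input fields and minibatch size are invented; only the hook names come from this changeset). end_epoch() returning True is what breaks the 'while not stop' loop in update:

    class MeanLearner(OneShotTLearner):
        # hypothetical sub-class; field names and minibatch size are placeholders
        training_set_input_fields = ['input']
        minibatch_size = 32

        def __init__(self):
            OneShotTLearner.__init__(self)
            self.n_epochs = 1
            self._epochs_done = 0

        def update_start(self):
            self._epochs_done = 0      # called once at the start of update()

        def update_minibatch(self, minibatch):
            pass                       # update model parameters from one minibatch

        def end_epoch(self):
            self._epochs_done += 1     # called after each pass over the training set
            return self._epochs_done >= self.n_epochs   # True stops the epoch loop

        def update_end(self):
            pass                       # called once before update() returns self.use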