Mercurial > pylearn
comparison learner.py @ 92:c4726e19b8ec
Finished first draft of TLearner
author | Yoshua Bengio <bengioy@iro.umontreal.ca> |
---|---|
date | Mon, 05 May 2008 18:14:32 -0400 |
parents | 3499918faa9d |
children | c4916445e025 |
comparison
equal
deleted
inserted
replaced
84:aa9e786ee849 | 92:c4726e19b8ec |
---|---|
55 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made | 55 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made |
56 visible in the output DataSet returned by this method. | 56 visible in the output DataSet returned by this method. |
57 """ | 57 """ |
58 raise NotImplementedError | 58 raise NotImplementedError |
59 | 59 |
60 def attribute_names(self): | 60 def attributeNames(self): |
61 """ | 61 """ |
62 A Learner may have attributes that it wishes to export to other objects. To automate | 62 A Learner may have attributes that it wishes to export to other objects. To automate |
63 such export, sub-classes should define here the names (list of strings) of these attributes. | 63 such export, sub-classes should define here the names (list of strings) of these attributes. |
64 """ | 64 """ |
65 return [] | 65 return [] |
83 The sub-class constructor defines the relations between | 83 The sub-class constructor defines the relations between |
84 the Theano variables that may be used by 'use' and 'update' | 84 the Theano variables that may be used by 'use' and 'update' |
85 or by a stats collector. | 85 or by a stats collector. |
86 - defaultOutputFields(input_fields): return a list of default dataset output fields when | 86 - defaultOutputFields(input_fields): return a list of default dataset output fields when |
87 None are provided by the caller of use. | 87 None are provided by the caller of use. |
88 - | 88 - update_start(), update_end(), update_minibatch(minibatch): functions |
89 | 89 executed at the beginning, the end, and in the middle |
90 (for each minibatch) of the update method. This model only | |
91 works for 'online' or one-shot learning that requires | |
92 going only once through the training data. For more complicated | |
93 models, more specialized subclasses of TLearner should be used | |
94 or a learning-algorithm specific update method should be defined. | |
95 | |
96 The following naming convention is assumed and important. | |
97 Attributes whose names are listed in attributeNames() can be of any type, | |
98 but those that can be referenced as input/output dataset fields or as | |
99 output attributes in 'use' or as input attributes in the stats collector | |
100 should be associated with a Theano Result variable. If the exported attribute | |
101 name is <name>, the corresponding Result name (an internal attribute of | |
102 the TLearner, created in the sub-class constructor) should be _<name>. | |
103 Typically <name> will be a numpy ndarray and _<name> will be the corresponding | |
104 Theano Tensor (for symbolic manipulation). | |
90 """ | 105 """ |
106 | |
def __init__(self):
    """
    Initialize the base Learner and the cache of compiled 'use' functions.
    """
    Learner.__init__(self)
    # _minibatchwise_use_functions caches its compiled functions in this
    # dictionary; create it here unless something already provided one.
    if not hasattr(self, 'use_functions_dictionary'):
        self.use_functions_dictionary = {}
109 | |
def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
    """
    Private helper function called by the generic TLearner.use. It returns a function
    that can map the given input fields to the given output fields (along with the
    attributes that the stats collector needs for its computation).

    The functions are cached in self.use_functions_dictionary, keyed by the
    (input field names, output field names) pair, so that each combination is
    only built once.

    input_fields: sequence of input field names.
    output_fields: sequence of output field names; when empty or None,
        self.defaultOutputFields(input_fields) is used instead.
    stats_collector: optional object whose inputUpdateAttributes() names extra
        attributes; those that are not input fields are appended to the outputs.
    """
    if not output_fields:
        output_fields = self.defaultOutputFields(input_fields)
    # Work on a copy: appending to the caller's list (or to the list handed
    # back by defaultOutputFields) would make it grow on every call.
    output_fields = list(output_fields)
    if stats_collector:
        for attribute in stats_collector.inputUpdateAttributes():
            if attribute not in input_fields:
                output_fields.append(attribute)
    # Lists are unhashable, so the dictionary key must be made of tuples.
    key = (tuple(input_fields), tuple(output_fields))
    try:
        cache = self.use_functions_dictionary
    except AttributeError:
        # Nothing has created the cache yet: create it on first use.
        cache = self.use_functions_dictionary = {}
    if key not in cache:
        # NOTE(review): Function is assumed to be built from the Result objects
        # rather than from their current .data values -- confirm against the
        # Function API in use.
        cache[key] = Function(self._names2attributes(input_fields, return_Result=True),
                              self._names2attributes(output_fields, return_Result=True))
    return cache[key]
128 | |
def attributes(self, return_copy=False):
    """
    Return a list with the values of the learner's attributes (or optionally, a deep copy).

    The attributes are those named by self.attributeNames(), in the same order.
    """
    # Forward return_copy: it used to be accepted but silently ignored.
    return self._names2attributes(self.attributeNames(), return_copy=return_copy)
134 | |
def _names2attributes(self, names, return_Result=False, return_copy=False):
    """
    Private helper function that maps a list of attribute names to a list
    of (optionally copies) values or of the Result objects that own these values.

    names: attribute names to look up on this learner.
    return_Result: if true, return the looked-up objects themselves; otherwise
        return their .data member.
    return_copy: if true, each returned element is a deep copy.
    """
    # Use the getattr builtin: calling self.__getattr__ directly only works on
    # classes that define that fallback hook and bypasses normal attribute lookup.
    found = [getattr(self, name) for name in names]
    if not return_Result:
        found = [result.data for result in found]
    if return_copy:
        found = [copy.deepcopy(item) for item in found]
    return found
150 | |
def use(self, input_dataset, output_fieldnames=None, output_attributes=None,
        test_stats_collector=None, copy_inputs=True):
    """
    The learner tries to compute in the output dataset the output fields specified
    by output_fieldnames (the learner's default output fields when none are given).

    input_dataset: dataset whose fields (input_dataset.fieldNames()) are the inputs.
    output_fieldnames: names of the output fields to compute, or None.
    output_attributes: names of learner attributes to copy into the output dataset;
        they must all be listed in self.attributeNames().
    test_stats_collector: optional collector that is updated with the output
        dataset and whose attributes are then copied into it.
    copy_inputs: if true, the returned dataset is input_dataset | output_dataset.

    Returns the output dataset.
    """
    # The helper is a method: it has to be called through self.
    minibatchwise_use_function = self._minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                   output_fieldnames,
                                                                   test_stats_collector)
    virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                  minibatchwise_use_function,
                                                  True, DataSet.numpy_vstack,
                                                  DataSet.numpy_hstack)
    # actually force the computation
    output_dataset = CachedDataSet(virtual_output_dataset, True)
    if copy_inputs:
        output_dataset = input_dataset | output_dataset
    # copy the wanted attributes in the dataset
    if output_attributes:
        assert set(output_attributes) <= set(self.attributeNames())
        output_dataset.setAttributes(output_attributes,
                                     self._names2attributes(output_attributes, return_copy=True))
    if test_stats_collector:
        test_stats_collector.update(output_dataset)
        output_dataset.setAttributes(test_stats_collector.attributeNames(),
                                     test_stats_collector.attributes())
    return output_dataset
177 | |
def update_start(self):
    """Hook called by update() before the first minibatch; does nothing by default."""
def update_end(self):
    """Hook called by update() after the last minibatch; does nothing by default."""
def update_minibatch(self, minibatch):
    """
    Hook called by update() once for each minibatch.

    This default implementation always raises AbstractFunction.
    """
    raise AbstractFunction()
91 | 182 |
def update(self, training_set, train_stats_collector=None):
    """
    Go once through training_set, minibatch by minibatch.

    Calls update_start(), then update_minibatch(minibatch) for each minibatch of
    training_set.minibatches(self.training_set_input_fields,
    minibatch_size=self.minibatch_size), then update_end().

    If train_stats_collector is given, after each minibatch the learner's
    attributes are set on that minibatch's examples and the collector is
    updated with them.

    Returns self.use.
    """
    self.update_start()
    batches = training_set.minibatches(self.training_set_input_fields,
                                       minibatch_size=self.minibatch_size)
    for batch in batches:
        self.update_minibatch(batch)
        if not train_stats_collector:
            continue
        batch_examples = batch.examples()
        batch_examples.setAttributes(self.attributeNames(), self.attributes())
        train_stats_collector.update(batch_examples)
    self.update_end()
    return self.use
195 |