comparison learner.py @ 92:c4726e19b8ec

Finished first draft of TLearner
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 05 May 2008 18:14:32 -0400
parents 3499918faa9d
children c4916445e025
comparison
equal deleted inserted replaced
84:aa9e786ee849 92:c4726e19b8ec
55 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made 55 Optionally, if copy_inputs, the input fields (of the input_dataset) can be made
56 visible in the output DataSet returned by this method. 56 visible in the output DataSet returned by this method.
57 """ 57 """
58 raise NotImplementedError 58 raise NotImplementedError
59 59
def attributeNames(self):
    """
    Names (list of strings) of the attributes that this Learner exports to
    other objects. The base implementation exports nothing; sub-classes that
    want automated export of some attributes should override this method and
    list the attribute names here.
    """
    exported_names = []
    return exported_names
83 The sub-class constructor defines the relations between 83 The sub-class constructor defines the relations between
84 the Theano variables that may be used by 'use' and 'update' 84 the Theano variables that may be used by 'use' and 'update'
85 or by a stats collector. 85 or by a stats collector.
86 - defaultOutputFields(input_fields): return a list of default dataset output fields when 86 - defaultOutputFields(input_fields): return a list of default dataset output fields when
87 None are provided by the caller of use. 87 None are provided by the caller of use.
88 - 88 - update_start(), update_end(), update_minibatch(minibatch): functions
89 89 executed at the beginning, the end, and in the middle
90 (for each minibatch) of the update method. This model only
91 works for 'online' or one-shot learning that requires
92 going only once through the training data. For more complicated
93 models, more specialized subclasses of TLearner should be used
94 or a learning-algorithm specific update method should be defined.
95
96 The following naming convention is assumed and important.
97 Attributes whose names are listed in attributeNames() can be of any type,
98 but those that can be referenced as input/output dataset fields or as
99 output attributes in 'use' or as input attributes in the stats collector
100 should be associated with a Theano Result variable. If the exported attribute
101 name is <name>, the corresponding Result name (an internal attribute of
102 the TLearner, created in the sub-class constructor) should be _<name>.
103 Typically <name> will be numpy ndarray and _<name> will be the corresponding
104 Theano Tensor (for symbolic manipulation).
90 """ 105 """
106
def __init__(self):
    """
    Initialize the base Learner state and the cache of compiled 'use' functions.
    """
    Learner.__init__(self)
    # _minibatchwise_use_functions memoizes its compiled functions in this
    # dictionary, so it must exist before the first call to use(). Keep any
    # dictionary that the base-class constructor may already have installed.
    if not hasattr(self, 'use_functions_dictionary'):
        self.use_functions_dictionary = {}
109
110 def _minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
111 """
112 Private helper function called by the generic TLearner.use. It returns a function
113 that can map the given input fields to the given output fields (along with the
114 attributes that the stats collector needs for its computation.
115 """
116 if not output_fields:
117 output_fields = self.defaultOutputFields(input_fields)
118 if stats_collector:
119 stats_collector_inputs = stats_collector.inputUpdateAttributes()
120 for attribute in stats_collector_inputs:
121 if attribute not in input_fields:
122 output_fields.append(attribute)
123 key = (input_fields,output_fields)
124 if key not in self.use_functions_dictionary:
125 self.use_functions_dictionary[key]=Function(self._names2attributes(input_fields),
126 self._names2attributes(output_fields))
127 return self.use_functions_dictionary[key]
128
def attributes(self,return_copy=False):
    """
    Return a list with the values of the learner's attributes (or optionally, a deep copy).

    The attributes are those named by self.attributeNames(), in the same order.
    If return_copy is true, the values are deep-copied before being returned.
    """
    # The helper is the private _names2attributes method; return_copy has to be
    # forwarded to it, otherwise the 'deep copy' option has no effect.
    return self._names2attributes(self.attributeNames(), return_copy=return_copy)
134
135 def _names2attributes(self,names,return_Result=False, return_copy=False):
136 """
137 Private helper function that maps a list of attribute names to a list
138 of (optionally copies) values or of the Result objects that own these values.
139 """
140 if return_Result:
141 if return_copy:
142 return [copy.deepcopy(self.__getattr__(name)) for name in names]
143 else:
144 return [self.__getattr__(name) for name in names]
145 else:
146 if return_copy:
147 return [copy.deepcopy(self.__getattr__(name).data) for name in names]
148 else:
149 return [self.__getattr__(name).data for name in names]
150
def use(self,input_dataset,output_fieldnames=None,output_attributes=None,
        test_stats_collector=None,copy_inputs=True):
    """
    The learner tries to compute in the output dataset the output fields specified
    by output_fieldnames, from the fields of input_dataset, and returns that
    output dataset.

    output_fieldnames: names of the fields to compute; when empty or None the
        choice is left to _minibatchwise_use_functions.
    output_attributes: optional names (a subset of self.attributeNames()) of
        learner attributes to copy into the output dataset.
    test_stats_collector: optional stats collector; it is updated with the
        output dataset and its attributes are then set on that dataset.
    copy_inputs: if true, the returned dataset is input_dataset | output_dataset.
    """
    # The helper is a method of this learner, so it has to be reached through self.
    minibatchwise_use_function = self._minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                   output_fieldnames,
                                                                   test_stats_collector)
    virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                  minibatchwise_use_function,
                                                  True,DataSet.numpy_vstack,
                                                  DataSet.numpy_hstack)
    # actually force the computation
    output_dataset = CachedDataSet(virtual_output_dataset,True)
    if copy_inputs:
        output_dataset = input_dataset | output_dataset
    # copy the wanted attributes in the dataset
    if output_attributes:
        assert set(output_attributes) <= set(self.attributeNames())
        output_dataset.setAttributes(output_attributes,
                                     self._names2attributes(output_attributes,return_copy=True))
    if test_stats_collector:
        test_stats_collector.update(output_dataset)
        output_dataset.setAttributes(test_stats_collector.attributeNames(),
                                     test_stats_collector.attributes())
    return output_dataset
177
def update_start(self):
    """
    Hook called by update() before the first minibatch. Does nothing by default.
    """
    return None
def update_end(self):
    """
    Hook called by update() after the last minibatch. Does nothing by default.
    """
    return None
def update_minibatch(self,minibatch):
    """
    Hook called by update() on each minibatch of the training set.
    There is no default behavior: calling this base version raises
    AbstractFunction, so sub-classes have to override it.
    """
    raise AbstractFunction()
91 182
def update(self,training_set,train_stats_collector=None):
    """
    Go once through training_set, minibatch by minibatch, and return self.use.

    update_start() is called first, then update_minibatch() on every minibatch
    obtained from training_set.minibatches (using self.training_set_input_fields
    and self.minibatch_size), and finally update_end(). When a
    train_stats_collector is given, it is updated after each minibatch with the
    examples of that minibatch, on which the learner's current attributes have
    been set.
    """
    self.update_start()
    batches = training_set.minibatches(self.training_set_input_fields,
                                       minibatch_size=self.minibatch_size)
    for batch in batches:
        self.update_minibatch(batch)
        if not train_stats_collector:
            continue
        batch_examples = batch.examples()
        batch_examples.setAttributes(self.attributeNames(), self.attributes())
        train_stats_collector.update(batch_examples)
    self.update_end()
    return self.use
195