comparison linear_regression.py @ 92:c4726e19b8ec

Finished first draft of TLearner
author Yoshua Bengio <bengioy@iro.umontreal.ca>
date Mon, 05 May 2008 18:14:32 -0400
parents 3499918faa9d
children c4916445e025
comparison
equal deleted inserted replaced
84:aa9e786ee849 92:c4726e19b8ec
class LinearRegression(Learner):
    """
    Implement linear regression, with or without L2 regularization
    (the former is called Ridge Regression and the latter Ordinary Least Squares).

    The predictor parameters are obtained analytically from the training set.
    Training can proceed sequentially (with multiple calls to update with
    different disjoint subsets of the training sets). After each call to
    update the predictor is ready to be used (and optimized for the union
    of all the training sets passed to update since construction or since
    the last call to forget).

    The L2 regularization coefficient is obtained analytically.
    For each (input[t],output[t]) pair in a minibatch,::

       output_t = b + W * input_t

    - 'output' (optionally produced by use as an output dataset field)
    - 'squared_error' (optionally produced by use as an output dataset field, needs 'target') = example-wise squared error

    - optional input attributes (optionally expected as input_dataset attributes)

    - optional attributes (optionally expected as input_dataset attributes)
      (warning, this may be dangerous, the 'use' method will use those provided in the
      input_dataset rather than those learned during 'update'; currently no support
      for providing these to update):

      - 'lambda'
      - 'b'
      - 'W'
      - 'regularization_term'
      - 'XtX'
      - 'XtY'
    """
61 65
def attributeNames(self):
    """Return the names of the attributes this learner exposes (see class docstring)."""
    return "lambda b W regularization_term XtX XtY".split()
# definitions specific to linear regression:
70
63 71
64 def global_inputs(self): 72 def global_inputs(self):
65 self.lambda = as_scalar(0.,'lambda') 73 self.lambda = as_scalar(0.,'lambda')
66 self.theta = t.matrix('theta') 74 self.theta = t.matrix('theta')
67 self.W = self.theta[:,1:] 75 self.W = self.theta[:,1:]
105 output_fields.append("squared_error") 113 output_fields.append("squared_error")
106 return output_fields 114 return output_fields
107 115
# general machinery built on top of these functions
109 117
def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
    """
    Return (compiling and caching it on first use) a Function mapping the
    given input fields to the given output fields, applied one minibatch
    at a time.

    If output_fields is empty/None, the default output fields for these
    inputs are used.  If a stats_collector is given, any attribute it
    needs that is not already an input is added to the outputs so the
    collector can be updated from the result.
    """
    if not output_fields:
        output_fields = self.defaultOutputFields(input_fields)
    if stats_collector:
        # make sure the fields the stats collector needs are also computed
        stats_collector_inputs = stats_collector.inputUpdateAttributes()
        for attribute in stats_collector_inputs:
            if attribute not in input_fields:
                output_fields.append(attribute)
    # BUG FIX: the field-name containers may be lists, which are unhashable
    # and cannot serve as dictionary keys; convert to tuples first.
    key = (tuple(input_fields), tuple(output_fields))
    if key not in self.use_functions_dictionary:
        self.use_functions_dictionary[key] = Function(self.names2attributes(input_fields),
                                                      self.names2attributes(output_fields))
    return self.use_functions_dictionary[key]
123
def attributes(self, return_copy=False):
    """
    Return the values of this learner's attributes, in the order given by
    attributeNames().

    @param return_copy: if True, return deep copies so the caller cannot
        mutate the learner's internal state.
    """
    # BUG FIX: return_copy was accepted but ignored; forward it so callers
    # such as use(...return_copy=True...) actually get copies.
    return self.names2attributes(self.attributeNames(), return_copy=return_copy)
126
def names2attributes(self, names, return_Result=False, return_copy=False):
    """
    Map attribute names to their values on this learner.

    If return_Result is True, the Result objects themselves are returned;
    otherwise their .data fields.  If return_copy is True, deep copies of
    the values are returned instead of the values themselves.
    """
    if return_Result:
        values = [self.__getattr__(name) for name in names]
    else:
        values = [self.__getattr__(name).data for name in names]
    if return_copy:
        return [copy.deepcopy(value) for value in values]
    return values
138
def use(self, input_dataset, output_fieldnames=None, test_stats_collector=None, copy_inputs=True):
    """
    Apply the predictor to input_dataset and return a dataset with the
    requested output fields (defaultOutputFields if none are given),
    prepending the input fields when copy_inputs is True.

    The learner's attributes are copied onto the returned dataset; if
    test_stats_collector is given it is updated on the result and its
    attributes are copied onto the returned dataset as well.
    """
    # BUG FIX: minibatchwise_use_functions is a method of self; the bare
    # name is not defined at module scope and raised NameError.
    minibatchwise_use_function = self.minibatchwise_use_functions(input_dataset.fieldNames(),
                                                                  output_fieldnames,
                                                                  test_stats_collector)
    virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                  minibatchwise_use_function,
                                                  True, DataSet.numpy_vstack,
                                                  DataSet.numpy_hstack)
    # actually force the computation
    output_dataset = CachedDataSet(virtual_output_dataset, True)
    if copy_inputs:
        output_dataset = input_dataset | output_dataset
    # copy the attributes that should travel with the dataset
    output_dataset.setAttributes(self.attributeNames(), self.attributes(return_copy=True))
    if test_stats_collector:
        test_stats_collector.update(output_dataset)
        for attribute in test_stats_collector.attributeNames():
            output_dataset[attribute] = copy.deepcopy(test_stats_collector[attribute])
    return output_dataset
156
def update(self, training_set, train_stats_collector=None):
    """
    Train (or continue training) on training_set, one minibatch at a time.

    After each minibatch the optional train_stats_collector is updated
    with the current learner attributes attached to that minibatch.
    Returns the use method, so the freshly trained predictor can be
    applied immediately.
    """
    self.update_start()
    minibatch_stream = training_set.minibatches(self.training_set_input_fields,
                                                minibatch_size=self.minibatch_size)
    for minibatch in minibatch_stream:
        self.update_minibatch(minibatch)
        if train_stats_collector:
            minibatch_set = minibatch.examples()
            minibatch_set.setAttributes(self.attributeNames(), self.attributes())
            train_stats_collector.update(minibatch_set)
    self.update_end()
    return self.use
167 118
168 def __init__(self,lambda=0.,max_memory_use=500): 119 def __init__(self,lambda=0.,max_memory_use=500):
169 """ 120 """
170 @type lambda: float 121 @type lambda: float
171 @param lambda: regularization coefficient 122 @param lambda: regularization coefficient