comparison linear_regression.py @ 78:3499918faa9d

In the middle of designing TLearner
author bengioy@bengiomac.local
date Mon, 05 May 2008 09:35:30 -0400
parents 1e2bb5bad636
children c4726e19b8ec
comparison
legend: equal | deleted | inserted | replaced
left column: 77:1e2bb5bad636 (parent) | right column: 78:3499918faa9d (this changeset)
94 # self.input is a (n_examples, n_inputs) minibatch matrix 94 # self.input is a (n_examples, n_inputs) minibatch matrix
95 self.extended_input = t.prepend_one_to_each_row(self.input) 95 self.extended_input = t.prepend_one_to_each_row(self.input)
96 self.output = t.dot(self.input,self.W.T) + self.b # (n_examples , n_outputs) matrix 96 self.output = t.dot(self.input,self.W.T) + self.b # (n_examples , n_outputs) matrix
97 self.squared_error = t.sum_within_rows(t.sqr(self.output-self.target)) # (n_examples ) vector 97 self.squared_error = t.sum_within_rows(t.sqr(self.output-self.target)) # (n_examples ) vector
98 98
99 def attribute_names(self): 99 def attributeNames(self):
100 return ["lambda","b","W","regularization_term","XtX","XtY"] 100 return ["lambda","b","W","regularization_term","XtX","XtY"]
101 101
102 def default_output_fields(self, input_fields): 102 def defaultOutputFields(self, input_fields):
103 output_fields = ["output"] 103 output_fields = ["output"]
104 if "target" in input_fields: 104 if "target" in input_fields:
105 output_fields.append("squared_error") 105 output_fields.append("squared_error")
106 return output_fields 106 return output_fields
107 107
108 # poutine generale basee sur ces fonctions 108 # poutine generale basee sur ces fonctions
109 109
110 def minibatchwise_use_functions(self, input_fields, output_fields): 110 def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
111 if not output_fields: 111 if not output_fields:
112 output_fields = self.default_output_fields(input_fields) 112 output_fields = self.defaultOutputFields(input_fields)
113 if stats_collector:
114 stats_collector_inputs = stats_collector.inputUpdateAttributes()
115 for attribute in stats_collector_inputs:
116 if attribute not in input_fields:
117 output_fields.append(attribute)
113 key = (input_fields,output_fields) 118 key = (input_fields,output_fields)
114 if key not in use_functions_dictionary: 119 if key not in self.use_functions_dictionary:
115 use_functions_dictionary[key]=Function(self.names2attributes(input_fields), 120 self.use_functions_dictionary[key]=Function(self.names2attributes(input_fields),
116 self.names2attributes(output_fields)) 121 self.names2attributes(output_fields))
117 return use_functions_dictionary[key] 122 return self.use_functions_dictionary[key]
118 123
119 def names2attributes(self,names,return_Result=True): 124 def attributes(self,return_copy=False):
125 return self.names2attributes(self.attributeNames())
126
127 def names2attributes(self,names,return_Result=False, return_copy=False):
120 if return_Result: 128 if return_Result:
121 return [self.__getattr__(name) for name in names] 129 if return_copy:
130 return [copy.deepcopy(self.__getattr__(name)) for name in names]
131 else:
132 return [self.__getattr__(name) for name in names]
122 else: 133 else:
123 return [self.__getattr__(name).data for name in names] 134 if return_copy:
135 return [copy.deepcopy(self.__getattr__(name).data) for name in names]
136 else:
137 return [self.__getattr__(name).data for name in names]
124 138
125 def use(self,input_dataset,output_fieldnames=None,test_stats_collector=None,copy_inputs=True): 139 def use(self,input_dataset,output_fieldnames=None,test_stats_collector=None,copy_inputs=True):
126 minibatchwise_use_function = use_functions(input_dataset.fieldNames(),output_fieldnames) 140 minibatchwise_use_function = minibatchwise_use_functions(input_dataset.fieldNames(),output_fieldnames,test_stats_collector)
127 virtual_output_dataset = ApplyFunctionDataSet(input_dataset, 141 virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
128 minibatchwise_use_function, 142 minibatchwise_use_function,
129 True,DataSet.numpy_vstack, 143 True,DataSet.numpy_vstack,
130 DataSet.numpy_hstack) 144 DataSet.numpy_hstack)
131 # actually force the computation 145 # actually force the computation
132 output_dataset = CachedDataSet(virtual_output_dataset,True) 146 output_dataset = CachedDataSet(virtual_output_dataset,True)
133 if copy_inputs: 147 if copy_inputs:
134 output_dataset = input_dataset | output_dataset 148 output_dataset = input_dataset | output_dataset
135 # compute the attributes that should be copied in the dataset 149 # compute the attributes that should be copied in the dataset
136 for attribute in self.attribute_names(): 150 output_dataset.setAttributes(self.attributeNames(),self.attributes(return_copy=True))
137 # .data assumes that all attributes are Result objects
138 output_dataset.__setattr__(attribute) = copy.deepcopy(self.__getattr__(attribute).data)
139 if test_stats_collector: 151 if test_stats_collector:
140 test_stats_collector.update(output_dataset) 152 test_stats_collector.update(output_dataset)
141 for attribute in test_stats_collector.attribute_names(): 153 for attribute in test_stats_collector.attributeNames():
142 output_dataset[attribute] = copy.deepcopy(test_stats_collector[attribute]) 154 output_dataset[attribute] = copy.deepcopy(test_stats_collector[attribute])
143 return output_dataset 155 return output_dataset
144 156
145 def update(self,training_set,train_stats_collector=None): 157 def update(self,training_set,train_stats_collector=None):
146 158 self.update_start()
159 for minibatch in training_set.minibatches(self.training_set_input_fields, minibatch_size=self.minibatch_size):
160 self.update_minibatch(minibatch)
161 if train_stats_collector:
162 minibatch_set = minibatch.examples()
163 minibatch_set.setAttributes(self.attributeNames(),self.attributes())
164 train_stats_collector.update(minibatch_set)
165 self.update_end()
166 return self.use
147 167
148 def __init__(self,lambda=0.,max_memory_use=500): 168 def __init__(self,lambda=0.,max_memory_use=500):
149 """ 169 """
150 @type lambda: float 170 @type lambda: float
151 @param lambda: regularization coefficient 171 @param lambda: regularization coefficient