comparison linear_regression.py @ 78:3499918faa9d
In the middle of designing TLearner
| author   | bengioy@bengiomac.local |
|----------|-------------------------|
| date     | Mon, 05 May 2008 09:35:30 -0400 |
| parents  | 1e2bb5bad636 |
| children | c4726e19b8ec |
Comparison of 77:1e2bb5bad636 (parent) with 78:3499918faa9d (this changeset); removed lines are prefixed with "-", added lines with "+".
         # self.input is a (n_examples, n_inputs) minibatch matrix
         self.extended_input = t.prepend_one_to_each_row(self.input)
         self.output = t.dot(self.input,self.W.T) + self.b  # (n_examples, n_outputs) matrix
         self.squared_error = t.sum_within_rows(t.sqr(self.output-self.target))  # (n_examples,) vector

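The unchanged lines above build the linear model's expression graph: predictions are an affine map of the minibatch, and squared_error sums the squared residuals within each row. A minimal NumPy sketch of the same arithmetic, assuming ordinary arrays in place of the symbolic variables (all names below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
inputs = rng.standard_normal((4, 3))     # (n_examples, n_inputs) minibatch
target = rng.standard_normal((4, 2))     # (n_examples, n_outputs)
W = rng.standard_normal((2, 3))          # (n_outputs, n_inputs)
b = np.zeros(2)                          # (n_outputs,)

# Counterpart of t.prepend_one_to_each_row: a bias column of ones, presumably
# what the XtX/XtY attributes listed below are accumulated from.
extended_input = np.hstack([np.ones((inputs.shape[0], 1)), inputs])

output = inputs @ W.T + b                               # (n_examples, n_outputs)
squared_error = np.sum((output - target) ** 2, axis=1)  # one value per example
```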
-    def attribute_names(self):
+    def attributeNames(self):
         return ["lambda","b","W","regularization_term","XtX","XtY"]

-    def default_output_fields(self, input_fields):
+    def defaultOutputFields(self, input_fields):
         output_fields = ["output"]
         if "target" in input_fields:
             output_fields.append("squared_error")
         return output_fields
     # general machinery built on top of these functions

-    def minibatchwise_use_functions(self, input_fields, output_fields):
+    def minibatchwise_use_functions(self, input_fields, output_fields, stats_collector):
         if not output_fields:
-            output_fields = self.default_output_fields(input_fields)
+            output_fields = self.defaultOutputFields(input_fields)
+        if stats_collector:
+            stats_collector_inputs = stats_collector.inputUpdateAttributes()
+            for attribute in stats_collector_inputs:
+                if attribute not in input_fields:
+                    output_fields.append(attribute)
         key = (input_fields,output_fields)
-        if key not in use_functions_dictionary:
-            use_functions_dictionary[key] = Function(self.names2attributes(input_fields),
-                                                     self.names2attributes(output_fields))
-        return use_functions_dictionary[key]
+        if key not in self.use_functions_dictionary:
+            self.use_functions_dictionary[key] = Function(self.names2attributes(input_fields),
+                                                          self.names2attributes(output_fields))
+        return self.use_functions_dictionary[key]

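The new argument threads a stats collector into the function cache: any attribute the collector needs as input is appended to the requested outputs, and the compiled use-function is memoized under the (input_fields, output_fields) key. A small self-contained sketch of that memoization pattern; the collector's `input_attributes` field and the `compile_fn` stand-in are inventions for illustration, and the sketch tuples the field lists so the dictionary key stays hashable:

```python
class FunctionCache:
    """Compile a use-function once per (inputs, outputs) signature and reuse it."""

    def __init__(self):
        self.use_functions_dictionary = {}

    def compile_fn(self, input_fields, output_fields):
        # Stand-in for building a Theano-style Function from symbolic attributes.
        return lambda row: {name: ("computed", name, row) for name in output_fields}

    def minibatchwise_use_function(self, input_fields, output_fields, stats_collector=None):
        if stats_collector is not None:
            # Make sure whatever the collector wants to read is also computed.
            for attribute in stats_collector.input_attributes:
                if attribute not in output_fields:
                    output_fields = tuple(output_fields) + (attribute,)
        key = (tuple(input_fields), tuple(output_fields))
        if key not in self.use_functions_dictionary:
            self.use_functions_dictionary[key] = self.compile_fn(*key)
        return self.use_functions_dictionary[key]

cache = FunctionCache()
f1 = cache.minibatchwise_use_function(("input",), ("output",))
f2 = cache.minibatchwise_use_function(("input",), ("output",))
assert f1 is f2   # same signature: compiled once, then reused
```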
-    def names2attributes(self,names,return_Result=True):
+    def attributes(self,return_copy=False):
+        return self.names2attributes(self.attributeNames())
+
+    def names2attributes(self,names,return_Result=False, return_copy=False):
         if return_Result:
-            return [self.__getattr__(name) for name in names]
+            if return_copy:
+                return [copy.deepcopy(self.__getattr__(name)) for name in names]
+            else:
+                return [self.__getattr__(name) for name in names]
         else:
-            return [self.__getattr__(name).data for name in names]
+            if return_copy:
+                return [copy.deepcopy(self.__getattr__(name).data) for name in names]
+            else:
+                return [self.__getattr__(name).data for name in names]

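names2attributes now separates two concerns: whether to return the Result objects themselves or their numeric .data, and whether to deep-copy, so a snapshot taken for an output dataset cannot be mutated by further training. A minimal sketch of the same by-name lookup, with a toy Result stand-in and attribute names invented for illustration:

```python
import copy

class ResultLike:
    """Toy stand-in for a symbolic Result holding a numeric .data payload."""
    def __init__(self, data):
        self.data = data

class AttributeHolder:
    def __init__(self):
        self.W = ResultLike([[1.0, 2.0]])
        self.b = ResultLike([0.5])

    def attributeNames(self):
        return ["W", "b"]

    def names2attributes(self, names, return_Result=False, return_copy=False):
        values = [getattr(self, name) if return_Result else getattr(self, name).data
                  for name in names]
        return [copy.deepcopy(v) for v in values] if return_copy else values

holder = AttributeHolder()
snapshot = holder.names2attributes(holder.attributeNames(), return_copy=True)
holder.W.data[0][0] = 99.0          # later mutation of the learner...
assert snapshot[0][0][0] == 1.0     # ...does not affect the deep-copied snapshot
```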
     def use(self,input_dataset,output_fieldnames=None,test_stats_collector=None,copy_inputs=True):
-        minibatchwise_use_function = use_functions(input_dataset.fieldNames(),output_fieldnames)
+        minibatchwise_use_function = minibatchwise_use_functions(input_dataset.fieldNames(),output_fieldnames,test_stats_collector)
         virtual_output_dataset = ApplyFunctionDataSet(input_dataset,
                                                       minibatchwise_use_function,
                                                       True,DataSet.numpy_vstack,
                                                       DataSet.numpy_hstack)
         # actually force the computation
         output_dataset = CachedDataSet(virtual_output_dataset,True)
         if copy_inputs:
             output_dataset = input_dataset | output_dataset
         # compute the attributes that should be copied in the dataset
-        for attribute in self.attribute_names():
-            # .data assumes that all attributes are Result objects
-            output_dataset.__setattr__(attribute) = copy.deepcopy(self.__getattr__(attribute).data)
+        output_dataset.setAttributes(self.attributeNames(),self.attributes(return_copy=True))
         if test_stats_collector:
             test_stats_collector.update(output_dataset)
-            for attribute in test_stats_collector.attribute_names():
+            for attribute in test_stats_collector.attributeNames():
                 output_dataset[attribute] = copy.deepcopy(test_stats_collector[attribute])
         return output_dataset

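use() wires the pieces together: it fetches the compiled minibatch function, wraps the input dataset so the function is applied lazily, forces and caches the results, optionally merges the input fields back in, and copies the learner's attributes and test statistics onto the output. The lazy-apply-then-cache step is the part worth a concrete illustration; below is a self-contained sketch of that idea with plain Python iterables (the class names only echo ApplyFunctionDataSet and CachedDataSet, they are not the pylearn classes):

```python
class ApplyFunctionRows:
    """Lazily applies fn to each source row when iterated."""
    def __init__(self, source_rows, fn):
        self.source_rows, self.fn = source_rows, fn
    def __iter__(self):
        return (self.fn(row) for row in self.source_rows)

class CachedRows:
    """Forces the computation once and replays the cached results on later passes."""
    def __init__(self, rows, cache_all_upon_construction=True):
        self.rows = rows
        self.cache = list(rows) if cache_all_upon_construction else None
    def __iter__(self):
        if self.cache is None:
            self.cache = list(self.rows)
        return iter(self.cache)

calls = []
def fn(x):
    calls.append(x)
    return {"output": 2 * x}

virtual = ApplyFunctionRows([1, 2, 3], fn)
cached = CachedRows(virtual, cache_all_upon_construction=True)
assert [r["output"] for r in cached] == [2, 4, 6]
assert [r["output"] for r in cached] == [2, 4, 6]   # second pass reuses the cache
assert len(calls) == 3                              # fn ran only once per example
```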
     def update(self,training_set,train_stats_collector=None):
-
+        self.update_start()
+        for minibatch in training_set.minibatches(self.training_set_input_fields, minibatch_size=self.minibatch_size):
+            self.update_minibatch(minibatch)
+            if train_stats_collector:
+                minibatch_set = minibatch.examples()
+                minibatch_set.setAttributes(self.attributeNames(),self.attributes())
+                train_stats_collector.update(minibatch_set)
+        self.update_end()
+        return self.use

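The previously empty update() now spells out the generic training loop TLearner is converging on: a start hook, one call per minibatch, optional per-minibatch statistics, an end hook, and returning the use method. A self-contained sketch of that skeleton; the hook names follow the diff, while the toy loss and MeanCollector are invented for illustration:

```python
class MeanCollector:
    """Toy stats collector: averages a value reported for each minibatch."""
    def __init__(self):
        self.total, self.count = 0.0, 0
    def update(self, minibatch_stats):
        self.total += minibatch_stats["loss"]
        self.count += 1
    def mean(self):
        return self.total / max(self.count, 1)

class LoopLearner:
    minibatch_size = 2

    def update_start(self): self.step = 0
    def update_minibatch(self, minibatch):
        self.step += 1
        self.last_loss = sum(minibatch) / len(minibatch)   # pretend "training"
    def update_end(self): pass
    def use(self, x): return x

    def update(self, training_set, train_stats_collector=None):
        self.update_start()
        for start in range(0, len(training_set), self.minibatch_size):
            minibatch = training_set[start:start + self.minibatch_size]
            self.update_minibatch(minibatch)
            if train_stats_collector:
                # The real code attaches the learner's attributes to the minibatch
                # view before the collector sees it; here we report a summary dict.
                train_stats_collector.update({"loss": self.last_loss})
        self.update_end()
        return self.use

learner, stats = LoopLearner(), MeanCollector()
use_fn = learner.update([1.0, 2.0, 3.0, 4.0], stats)
print(stats.mean(), use_fn(10))   # 2.5 10
```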
     def __init__(self,lambda=0.,max_memory_use=500):
         """
         @type lambda: float
         @param lambda: regularization coefficient