Mercurial > pylearn
diff gradient_learner.py @ 26:672fe4b23032
Fixed dataset errors so that _test_dataset.py works again.
author | bengioy@grenat.iro.umontreal.ca |
---|---|
date | Fri, 11 Apr 2008 11:14:54 -0400 |
parents | 526e192b0699 |
children | 46c5c90019c2 |
line wrap: on
line diff
--- a/gradient_learner.py Wed Apr 09 18:27:13 2008 -0400 +++ b/gradient_learner.py Fri Apr 11 11:14:54 2008 -0400 @@ -26,7 +26,7 @@ It is assumed that all the inputs are provided in the training set (as dataset fields with the corresponding name), but not necessarily when using the learned function. """ - def __init__(self, inputs, parameters, outputs, example_wise_cost, regularization_term, + def __init__(self, inputs, parameters, outputs, example_wise_cost, regularization_term=astensor(0.0), regularization_coefficient = astensor(1.0)): self.inputs = inputs self.outputs = outputs @@ -48,13 +48,24 @@ def use(self,input_dataset,output_fields=None,copy_inputs=True): # obtain the function that maps the desired inputs to desired outputs input_fields = input_dataset.fieldNames() + # map names of input fields to Theano tensors in self.inputs + input_variables = ??? if output_fields is None: output_fields = [output.name for output in outputs] # handle special case of inputs that are directly copied into outputs - + # map names of output fields to Theano tensors in self.outputs + output_variables = ??? use_function_key = input_fields+output_fields if not self.use_functions.has_key(use_function_key): - self.use_function[use_function_key]=Function(input_fields,output_fields) + self.use_function[use_function_key]=Function(input_variables,output_variables) use_function = self.use_functions[use_function_key] # return a dataset that computes the outputs return input_dataset.applyFunction(use_function,input_fields,output_fields,copy_inputs,compute_now=True) + +class StochasticGradientDescent(object): + def update_parameters(self): + +class StochasticGradientLearner(GradientLearner,StochasticGradientDescent): + def __init__(self,inputs, parameters, outputs, example_wise_cost, regularization_term=astensor(0.0), + regularization_coefficient = astensor(1.0),) + def update()