gradient_learner.py @ 22:b6b36f65664f

Created virtual sub-classes of DataSet: {Finite{Length,Width},Sliceable}DataSet, removed .field ability from LookupList (because of setattr problems), removed fieldNames() from DataSet (but is in FiniteWidthDataSet, where it makes sense), and added hasFields() instead. Fixed problems in asarray, and tested previous functionality in _test_dataset.py, but not yet new functionality.
author bengioy@esprit.iro.umontreal.ca
date Mon, 07 Apr 2008 20:44:37 -0400
parents 266c68cb6136
children 526e192b0699
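The abstract ("virtual") DataSet sub-classes named in this commit message could look roughly like the sketch below. This is hypothetical and is not the project's dataset module. Only the class names and the split between `hasFields()` and `fieldNames()` come from the message. The method bodies and the toy `ListDataSet` class are assumptions made for illustration. The sketch shows why `fieldNames()` lives on `FiniteWidthDataSet`: only a dataset with a known, finite set of fields can list them, while any dataset can be asked whether it has particular fields.

```python
# Hypothetical sketch: class names follow the commit message; bodies are illustrative.
class DataSet(object):
    """Any dataset: can be asked whether it has some fields, but need not list them."""
    def hasFields(self, *field_names):
        raise NotImplementedError


class FiniteLengthDataSet(DataSet):
    """A dataset with a known number of examples."""
    def __len__(self):
        raise NotImplementedError


class FiniteWidthDataSet(DataSet):
    """A dataset with a known, finite set of fields, so fieldNames() makes sense."""
    def fieldNames(self):
        raise NotImplementedError

    def hasFields(self, *field_names):
        # With a finite field list, hasFields() can be derived from fieldNames().
        names = set(self.fieldNames())
        return all(name in names for name in field_names)


class SliceableDataSet(DataSet):
    """A dataset that supports slicing, e.g. dataset[i:j]."""
    def __getitem__(self, index):
        raise NotImplementedError


class ListDataSet(FiniteLengthDataSet, FiniteWidthDataSet, SliceableDataSet):
    """Hypothetical toy dataset: a list of rows with one value per named field."""
    def __init__(self, rows, fields):
        self.rows = list(rows)
        self.fields = list(fields)

    def __len__(self):
        return len(self.rows)

    def fieldNames(self):
        return list(self.fields)

    def __getitem__(self, index):
        if isinstance(index, slice):
            # Slicing returns another dataset with the same fields.
            return ListDataSet(self.rows[index], self.fields)
        return self.rows[index]


ds = ListDataSet([(0.0, 1.0), (2.0, 3.0), (4.0, 5.0)], ["x", "y"])
print(ds.hasFields("x", "y"), ds.hasFields("z"), len(ds[0:2]))  # prints: True False 2
```

Under this reading, `GradientLearner.use()` below calls `input_dataset.fieldNames()`. It therefore assumes that its input is a finite-width dataset.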


from learner import *
from tensor import *
import gradient
from compile import Function

class GradientLearner(Learner):
    """
    Base class for gradient-based optimization of a training criterion
    that consists of two parts: an additive part over examples, and
    an example-independent part (usually called the regularizer).
    The user provides a Theano formula that maps the fields of a training example
    and the parameters to output fields (for the use function), one of which must be
    the cost to be minimized. Subclasses implement
    a training strategy that uses the Theano formula to compute gradients and
    to compute outputs in the update method.
    The inputs, parameters, and outputs are lists of Theano tensors,
    while the example_wise_cost and regularization_term are Theano tensors.
    The user can specify a regularization coefficient that multiplies the regularization term.
    The training algorithm looks for parameters that minimize
       regularization_coefficient * regularization_term(parameters) +
       sum_{inputs in training_set} example_wise_cost(inputs,parameters)
    i.e. the regularization_term must depend only on the parameters, not on the inputs.
    The learned function can map a subset of inputs to a subset of outputs (as long as the inputs subset
    includes all the inputs required in the Theano expression for the selected outputs).
    It is assumed that all the inputs are provided in the training set (as dataset fields
    with the corresponding name), but not necessarily when using the learned function.
    """
    def __init__(self, inputs, parameters, outputs, example_wise_cost, regularization_term,
                 regularization_coefficient = astensor(1.0)):
        self.inputs = inputs
        # copy the outputs so that appending the costs below does not modify the caller's list
        self.outputs = list(outputs)
        self.parameters = parameters
        self.example_wise_cost = example_wise_cost
        self.regularization_term = regularization_term
        self.regularization_coefficient = regularization_coefficient
        self.parameters_example_wise_gradient = gradient.grad(example_wise_cost, parameters)
        self.parameters_regularization_gradient = gradient.grad(self.regularization_coefficient * regularization_term, parameters)
        # make both costs available as outputs of the use function
        if example_wise_cost not in self.outputs:
            self.outputs.append(example_wise_cost)
        if regularization_term not in self.outputs:
            self.outputs.append(regularization_term)
        # outputs: the example-wise gradients followed by the regularization gradients
        # (a flat list of outputs, not a list nested inside another list)
        self.example_wise_gradient_fn = Function(inputs + parameters,
                                                 self.parameters_example_wise_gradient +
                                                 self.parameters_regularization_gradient)
        # compiled use functions, keyed by (input names, output names)
        self.use_functions = {(tuple([input.name for input in inputs]),
                               tuple([output.name for output in self.outputs])):
                              Function(inputs, self.outputs)}

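    def use(self,input_dataset,output_fields=None,copy_inputs=True):
        # obtain the function that maps the desired inputs to desired outputs
        input_fields = input_dataset.fieldNames()
        if output_fields is None: output_fields = [output.name for output in self.outputs]
        # TODO: handle the special case of inputs that are copied directly into outputs

        # lists are not hashable, so the cache key is a pair of tuples
        use_function_key = (tuple(input_fields), tuple(output_fields))
        if use_function_key not in self.use_functions:
            # Function expects Theano tensors, not field names: look them up by name
            tensors_by_name = dict([(v.name, v) for v in self.inputs + self.outputs])
            self.use_functions[use_function_key] = Function([tensors_by_name[name] for name in input_fields],
                                                            [tensors_by_name[name] for name in output_fields])
        use_function = self.use_functions[use_function_key]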
        # return a virtual dataset that computes the outputs on demand
        return input_dataset.apply_function(use_function,input_fields,output_fields,copy_inputs,accept_minibatches=???)
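The docstring defines a training criterion. `__init__` then splits its gradient into an example-wise part and a regularization part. Both steps can be illustrated with plain NumPy. This is a minimal sketch, not the Theano code above. It assumes a hypothetical linear model, with squared error as the example-wise cost and an L2 penalty as the regularization term. It also writes the gradients by hand instead of calling `gradient.grad`, and checks them with central finite differences.

```python
import numpy as np

# Hypothetical model (not from the file): linear predictor, squared-error
# example-wise cost, L2 regularization term.
def example_wise_cost(x, y, w):
    return 0.5 * (np.dot(w, x) - y) ** 2

def example_wise_gradient(x, y, w):
    # d/dw of 0.5 * (w.x - y)^2
    return (np.dot(w, x) - y) * x

def regularization_term(w):
    return 0.5 * np.dot(w, w)

def regularization_gradient(w):
    return w

def training_criterion(X, Y, w, coef):
    # coef * regularization_term(w) + sum over examples of example_wise_cost
    return coef * regularization_term(w) + sum(example_wise_cost(x, y, w) for x, y in zip(X, Y))

def criterion_gradient(X, Y, w, coef):
    # The regularizer's gradient is added once, not once per example.
    g = coef * regularization_gradient(w)
    for x, y in zip(X, Y):
        g = g + example_wise_gradient(x, y, w)
    return g

rng = np.random.RandomState(0)
X = rng.randn(5, 3)
Y = rng.randn(5)
w = rng.randn(3)
coef = 0.1

# Central finite differences as a sanity check of the analytic gradient.
eps = 1e-6
numeric = np.array([
    (training_criterion(X, Y, w + eps * e, coef) - training_criterion(X, Y, w - eps * e, coef)) / (2 * eps)
    for e in np.eye(3)
])
print(np.allclose(numeric, criterion_gradient(X, Y, w, coef), atol=1e-5))  # prints True
```

The regularizer does not depend on the inputs, so its gradient should be counted once per evaluation of the full criterion, not once per example. This matters for a subclass that calls `example_wise_gradient_fn` once per example, because that function returns both gradient lists on every call.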