view pylearn/shared/layers/sgd.py @ 942:fdd648c7c583

shared/layers/sgd removed cruft
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 16 Jul 2010 14:20:48 -0400
parents 580087712f69
children 0181459b53a1
line wrap: on
line source

"""
Provides StochasticGradientDescent, HalflifeStopper
"""
import numpy
import theano
from theano import tensor
from theano.compile.sandbox import shared

class StochasticGradientDescent(object):
    """Fixed stepsize gradient descent

    For given inputs, the outputs of this object are the new values that the inputs should take
    in order to perform stochastic gradient descent, i.e.
    ``outputs[k] = inputs[k] - stepsize * gradients[k]``.

    The updates attribute is a list of (p, new_p) pairs for all inputs `p` that are
    SharedVariables (detected by the presence of a `value` attribute).

    """
    def __init__(self, inputs, cost, stepsize, gradients, params):
        """
        :param inputs: the quantities being optimized, one per entry of `gradients`
        :param cost: the cost the gradients were taken of; accepted for symmetry with
            `new` but not used by this constructor
        :param stepsize: the step to take in (negative) gradient direction
        :type stepsize: scalar value, or scalar TensorVariable
        :param gradients: gradient of the cost w.r.t. each input, in the same order
            as `inputs`
        :param params: stored as ``self.params``; `new` passes either ``[]`` or a
            one-element list holding the shared stepsize it allocated

        :raises ValueError: if `inputs` and `gradients` have different lengths
        """
        if len(inputs) != len(gradients):
            raise ValueError('inputs list and gradients list must have same len')

        self.inputs = inputs
        self.gradients = gradients
        self.params = params # contains either nothing or the learning rate
        self.outputs = [i - stepsize*g for (i,g) in zip(inputs, gradients)]
        # Only inputs exposing a `value` attribute (shared variables) can be
        # updated in place, so only those get an (input, new_value) pair.
        self.updates = [(p, new_p)
            for (p, new_p) in zip(self.inputs, self.outputs)
            if hasattr(p, 'value')]

    @classmethod
    def new(cls, inputs, cost, stepsize, dtype=None):
        """Build an instance whose gradients are derived symbolically from `cost`.

        :param inputs: list of variables to optimize; gradients are taken w.r.t. these
        :param cost: the symbolic cost to minimize
        :param stepsize: either a theano Variable (used as-is), or a plain value that
            is wrapped in a newly allocated shared variable of dtype `dtype`
        :param dtype: dtype for an allocated shared stepsize; defaults to ``cost.dtype``

        :raises TypeError: if `stepsize` is None or is not a scalar

        If a shared stepsize was allocated, it is also exposed as the `stepsize`
        attribute of the returned instance (the attribute is absent otherwise).
        """
        # Reject None up front: numpy.asarray(None, dtype=<float dtype>) yields NaN,
        # which would silently turn every update into NaN.
        if stepsize is None:
            raise TypeError('stepsize must be a scalar, not None', stepsize)

        if dtype is None:
            dtype = cost.dtype

        ginputs = tensor.grad(cost, inputs)

        if isinstance(stepsize, theano.Variable):
            _stepsize = stepsize
            params = []
        else:
            _stepsize = shared(numpy.asarray(stepsize, dtype=dtype))
            params = [_stepsize]

        if _stepsize.type.ndim != 0:
            raise TypeError('stepsize must be a scalar', stepsize)

        rval = cls(inputs, cost, _stepsize, ginputs, params)

        # if we allocated a shared variable for the stepsize,
        # put it into the stepsize attribute.
        if params:
            rval.stepsize = _stepsize

        return rval


class HalflifeStopper(object):
    """An early-stopping criterion.

    This object tracks the progress of a noisy score (lower is better) along some
    U-shaped trajectory.

    The heuristic is to first iterate at least `initial_wait` times. After that, the run is
    declared *not*-`promising` as soon as the current iteration count reaches
    `patience_factor` times the iteration of the last *significant* improvement. With the
    default `patience_factor` of 2, this means no significant improvement happened in the
    second half of the run so far.

    An improvement is significant when the new score beats the last significant score by a
    relative margin of ``1 - progress_thresh``. For a non-negative previous score ``s`` the
    new score must be below ``progress_thresh * s``. For a negative ``s`` it must be below
    ``(2 - progress_thresh) * s``, which is the same relative margin in the improving
    direction.

    .. code-block:: python

        stopper = HalflifeStopper()
        ...
        while (...):
            stopper.step(score)
            if stopper.best_updated:
                pass  # this is the best score we've seen yet
            if not stopper.promising:
                # we haven't seen a good score in a long time,
                # and the stopper recommends giving up.
                break

    """
    def __init__(self,
            initial_wait=20,
            patience_factor=2.0,
            progress_thresh=0.99 ):
        """
        :param initial_wait: number of initial steps during which the run is always
            considered `promising`
        :param patience_factor: the run stays promising while the current iteration is
            below ``patience_factor * halflife_iter``
        :param progress_thresh: relative factor a new score must reach, compared with
            the last significant score, to count as a significant improvement
        """
        #constants
        self.progress_thresh = progress_thresh
        self.patience_factor = patience_factor
        self.initial_wait = initial_wait

        #dynamic variables
        self.iter = 0
        self.promising = True

        # last *significant* improvement (see class docstring)
        self.halflife_iter = -1
        self.halflife_value = float('inf')
        self.halflife_updated = False

        # best score seen so far (any improvement counts)
        self.best_iter = -1
        self.best_value = float('inf')
        self.best_updated = False


    def step(self, value):
        """Record one new score and refresh the stopping state.

        Sets `halflife_updated`, `best_updated` and `promising` for this step, updates
        the corresponding `*_value`/`*_iter` attributes when the score improves, and
        then advances `iter`.

        :param value: the latest score; lower is better.
        """
        # Multiplying a negative score by progress_thresh (< 1) would move the bar
        # towards *worse* scores, so negative scores use (2 - progress_thresh) to
        # demand the same relative margin in the improving direction.
        if self.halflife_value >= 0:
            threshold = self.halflife_value * self.progress_thresh
        else:
            threshold = self.halflife_value * (2.0 - self.progress_thresh)

        if value < threshold:
            self.halflife_updated = True
            self.halflife_value = value
            self.halflife_iter = self.iter
        else:
            self.halflife_updated = False

        if value < self.best_value:
            self.best_updated = True
            self.best_value = value
            self.best_iter = self.iter
        else:
            self.best_updated = False

        self.promising = (self.iter < self.initial_wait) \
                or (self.iter < (self.halflife_iter * self.patience_factor))
        self.iter += 1