# HG changeset patch
# User James Bergstra
# Date 1238699302 14400
# Node ID fec0ba6f8c8f86e1390635ea643c96721af6bee6
# Parent  a5a41b7ddd266d17ae1b9a6f2408133df30f5566
added updates parameter to sgd

diff -r a5a41b7ddd26 -r fec0ba6f8c8f pylearn/algorithms/sgd.py
--- a/pylearn/algorithms/sgd.py	Wed Apr 01 19:48:47 2009 -0400
+++ b/pylearn/algorithms/sgd.py	Thu Apr 02 15:08:22 2009 -0400
@@ -5,10 +5,14 @@
 
 class StochasticGradientDescent(theano.Module):
     """Fixed stepsize gradient descent"""
-    def __init__(self, args, cost, params, gradients=None, stepsize=None):
+    def __init__(self, args, cost, params, gradients=None, stepsize=None, updates=None):
         """
         :param stepsize: the step to take in (negative) gradient direction
         :type stepsize: None, scalar value, or scalar TensorVariable
+
+        :param updates: extra symbolic updates to make when evaluating either step or step_cost
+            (these override the gradients if necessary)
+        :type updates: dict Variable -> Variable
         """
         super(StochasticGradientDescent, self).__init__()
         self.stepsize_init = None
@@ -26,14 +30,19 @@
         self.params = params
         self.gparams = theano.tensor.grad(cost, self.params) if gradients is None else gradients
 
-        self.updates = dict((p, p - self.stepsize * g) for p, g in zip(self.params, self.gparams))
+        self._updates = dict((p, p - self.stepsize * g) for p, g in zip(self.params, self.gparams))
+        if updates is not None:
+            self._updates.update(updates)
+
 
         self.step = theano.Method(
                 args, [],
-                updates=self.updates)
+                updates=self._updates)
         self.step_cost = theano.Method(
                 args, cost,
-                updates=self.updates)
+                updates=self._updates)
+
+    updates = property(lambda self: self._updates.copy())
 
     def _instance_initialize(self, obj):
         pass
@@ -43,6 +52,6 @@
 
     :returns: standard minimizer constructor f(args, cost, params, gradient=None)
     """
-    def f(args, cost, params, gradient=None):
-        return StochasticGradientDescent(args, cost, params, gradient, stepsize)
+    def f(args, cost, params, gradient=None, updates=None):
+        return StochasticGradientDescent(args, cost, params, gradient, stepsize, updates=updates)
     return f
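
The "(these override the gradients if necessary)" behaviour in the docstring follows from the merge order in __init__: the gradient-descent dict is built first, and the caller's updates dict is applied on top with dict.update, so a caller entry keyed on a parameter replaces the default step for that parameter. A runnable pure-Python sketch of just the merge semantics (strings stand in for Theano Variables; all names below are hypothetical and not part of the patch):

    # Stand-ins for Theano Variables: each parameter maps to its update expression.
    params = ['w', 'b']
    default_updates = dict((p, p + ' - stepsize * grad_' + p) for p in params)

    # Caller-supplied extras: one brand-new update ('t'), plus an override for 'w'.
    caller_updates = {'t': 't + 1', 'w': 'clip(w - stepsize * grad_w, -1, 1)'}

    merged = dict(default_updates)   # plays the role of self._updates in the patch
    merged.update(caller_updates)    # caller's entries win on key collisions

    assert merged['b'] == 'b - stepsize * grad_b'   # default update kept
    assert merged['w'].startswith('clip(')          # default overridden by caller
    assert 't' in merged                            # extra update added

Note also that the new updates property returns self._updates.copy(), so callers can inspect the merged dict without mutating the update rules already bound into the step and step_cost Methods.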