"""Hybrid / Hamiltonian Monte Carlo Sampling

This algorithm is described in Radford Neal's PhD Thesis, pages 63--70.

:note:

  The 'mass' of the so-called particles is taken to be 1, so that the kinetic energy (K)
  is half the sum of squared velocities (p):

    :math:`K = \sum_i p_i^2 / 2`


The 'leap-frog' algorithm, which advances the velocity and the position in alternation,
is implemented symbolically via theano's Scan op inside simulate_dynamics().

An HMC transition is assembled from a few pieces:

1. mcmc_move() samples a random velocity for each particle, simulates the dynamics with
   simulate_dynamics(), and computes a per-particle Metropolis-Hastings accept/reject
   decision by comparing the Hamiltonian before and after the simulation.
2. mcmc_updates() packages those results as updates to shared variables: the accepted
   positions, an adapted stepsize, and a running average of the acceptance rate.
3. HMC_sampler compiles these updates into a single `simulate` theano function; its
   `draw` method advances the Markov chain and returns the current sample.


"""
import numpy
from theano import function, shared
from theano import tensor as TT
import theano
from theano.printing import Print

# Debug printing is disabled: shadow theano.printing.Print with a no-op
# factory that returns the identity function.
def Print(msg):
    return lambda x: x

def kinetic_energy(velocities, masses):
    """Return the kinetic energy of each particle: 0.5 * sum of squared velocities (unit mass)."""
    if masses != 1:
        raise NotImplementedError()
    return 0.5 * (velocities**2).sum(axis=1)

def metropolis_hastings_accept(energy_prev, energy_next, s_rng, shape=None):
    """
    Return acceptance of moves - energy_prev and energy_next are vectors of the energy of each
    particle.
    """
    if shape is None:
        shape = energy_prev.shape
    ediff = Print('diff')(energy_prev - energy_next)
    return (TT.exp(ediff) - s_rng.uniform(size=shape)) >= 0
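
# Note on the rule above: ediff = energy_prev - energy_next, so a move to lower
# energy has exp(ediff) >= 1 and is always accepted, while an uphill move is
# accepted with probability exp(-(energy_next - energy_prev)).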

def hamiltonian(pos, vel, energy_fn):
    """Return a vector of energies - sum of kinetic and potential energy
    """
    # assuming mass is 1
    return energy_fn(pos) + kinetic_energy(vel, 1)

def simulate_dynamics(initial_p, initial_v, stepsize, n_steps, energy_fn):
    """
    Return final (position, velocity) obtained after an `n_steps` leapfrog updates, using
    Hamiltonian dynamics.

    Parameters
    ----------
    initial_p: shared theano matrix
        Initial position at which to start the simulation
    initial_v: shared theano matrix
        Initial velocity of particles
    stepsize: shared theano scalar
        Scalar value controlling amount by which to move
    energy_fn: python function
        Python function, operating on symbolic theano variables, used to compute the potential
        energy at a given position.

    Returns
    -------
    rval1: theano matrix
        Final positions obtained after simulation
    rval2: theano matrix
        Final velocity obtained after simulation
    """

    def leapfrog(pos, vel, step):
        """
        Inside loop of Scan. Performs one step of leapfrog update, using Hamiltonian dynamics.

        Parameters
        ----------
        pos: theano matrix
            in leapfrog update equations, represents pos(t), position at time t
        vel: theano matrix
            in leapfrog update equations, represents vel(t - stepsize/2),
            velocity at time (t - stepsize/2)
        step: theano scalar
            scalar value controlling amount by which to move

        Returns
        -------
        rval1: [theano matrix, theano matrix]
            Symbolic theano matrices for new position pos(t + stepsize), and velocity
            vel(t + stepsize/2)
        rval2: dictionary
            Dictionary of updates for the Scan Op
        """
        # from pos(t) and vel(t-sigma/2), compute vel(t+sigma/2)
        dE_dpos = TT.grad(energy_fn(pos).sum(), pos)
        new_vel = vel - step * dE_dpos
        # from vel(t+sigma/2) compute pos(t+sigma)
        new_pos = pos + step * new_vel
        return [new_pos, new_vel], {}

    # compute velocity at time-step: t + sigma/2
    initial_energy = energy_fn(initial_p)
    dE_dpos = TT.grad(initial_energy.sum(), initial_p)
    v_half_step = initial_v - 0.5*stepsize*dE_dpos

    p_full_step = initial_p + stepsize * v_half_step

    # perform leapfrog updates: the scan op is used to repeatedly compute pos(t_1 + n*sigma)
    # and vel(t_1 + (n + 1/2)*sigma) for n in [0, n_steps-2].
    (all_p, all_v), scan_updates = theano.scan(leapfrog,
            outputs_info=[
                dict(initial=p_full_step),
                dict(initial=v_half_step),
                ],
            non_sequences=[stepsize],
            n_steps=n_steps-1)

    final_p = all_p[-1]
    final_v = all_v[-1]

    # NOTE: Scan always returns an updates dictionary, in case the scanned function draws
    # samples from a RandomStream. These updates must then be used when compiling the Theano
    # function, to avoid drawing the same random numbers each time the function is called. In
    # this case however, we consciously ignore "scan_updates" because we know it is empty.
    assert not scan_updates

    # The last velocity returned by the scan op is at time-step t + (n_steps - 1/2)*stepsize.
    # We therefore perform one more half-step to return vel(t + n_steps*stepsize).
    energy = energy_fn(final_p)
    final_v = final_v - 0.5 * stepsize * TT.grad(energy.sum(), final_p)
    return final_p, final_v
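
# A minimal sketch (not part of the original module) of exercising
# simulate_dynamics on its own, assuming an isotropic Gaussian potential
# E(x) = 0.5 * ||x||^2 per particle:
#
#   pos = TT.matrix('pos')
#   vel = TT.matrix('vel')
#   step = TT.scalar('step')
#   final_p, final_v = simulate_dynamics(pos, vel, step, 10,
#           lambda x: 0.5 * (x ** 2).sum(axis=1))
#   f = function([pos, vel, step], [final_p, final_v])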


def mcmc_move(s_rng, positions, energy_fn, stepsize, n_steps, positions_shape=None):
    """Return new position
    """
    if positions_shape is None:
        positions_shape = positions.shape

    batchsize = positions_shape[0]

    initial_v = s_rng.normal(size=positions_shape)

    final_p, final_v = simulate_dynamics(
            initial_p = positions,
            initial_v = initial_v,
            stepsize = stepsize,
            n_steps = n_steps,
            energy_fn = energy_fn)

    accept = metropolis_hastings_accept(
            energy_prev = Print('ep')(hamiltonian(positions, initial_v, energy_fn)),
            energy_next = Print('en')(hamiltonian(final_p, final_v, energy_fn)),
            s_rng=s_rng, shape=(batchsize,))

    return Print('accept')(accept), final_p

def mcmc_updates(shrd_pos, shrd_stepsize, shrd_avg_acceptance_rate, final_p, accept,
        target_acceptance_rate,
        stepsize_inc,
        stepsize_dec,
        stepsize_min,
        stepsize_max,
        avg_acceptance_slowness
        ):
    """Return the list of shared-variable updates for one step of the chain: positions
    move to `final_p` where `accept` is true, the stepsize grows or shrinks depending on
    whether the running acceptance rate is above or below target (clipped to
    [stepsize_min, stepsize_max]), and the running acceptance rate is updated as a
    geometric moving average with coefficient `avg_acceptance_slowness`.
    """
    return [
            (shrd_pos,
                TT.switch(
                    accept.dimshuffle(0, *(('x',)*(final_p.ndim-1))),
                    final_p,
                    shrd_pos)),
            (shrd_stepsize,
                TT.clip(
                    TT.switch(
                        shrd_avg_acceptance_rate > target_acceptance_rate,
                        shrd_stepsize * stepsize_inc,
                        shrd_stepsize * stepsize_dec,
                        ),
                    stepsize_min,
                    stepsize_max)),
            (shrd_avg_acceptance_rate,
                Print('arate')(TT.add(
                    avg_acceptance_slowness * shrd_avg_acceptance_rate,
                    (1.0 - avg_acceptance_slowness) * accept.mean()))),
            ]
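
# For example, with the defaults stepsize_inc=1.02 and stepsize_dec=0.98 used
# below, a chain whose running acceptance rate sits above the target grows its
# stepsize by 2% per update and shrinks it by 2% otherwise, always clipped to
# [stepsize_min, stepsize_max].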

class HMC_sampler(object):
    """Convenience wrapper for HMC

    The `draw` function advances the Markov chain and returns the current sample by calling
    `simulate` and then reading the `positions` shared variable.

    """

    # The default arguments of new_from_shared_positions below are taken from
    # Marc'Aurelio Ranzato's 'train_mcRBM.py' script, found in the code published
    # online for his paper.

    def __init__(self, **kwargs):
        # add things to __dict__
        self.__dict__.update(kwargs)

    @classmethod
    def new_from_shared_positions(cls, shared_positions, energy_fn,
            initial_stepsize=0.01, target_acceptance_rate=.9, n_steps=20,
            stepsize_dec = 0.98,
            stepsize_min = 0.001,
            stepsize_max = 0.25,
            stepsize_inc = 1.02,
            avg_acceptance_slowness = 0.9, # used in the geometric average; 1.0 would mean no updating at all
            seed=12345, dtype=theano.config.floatX,
            shared_positions_shape=None,
            compile_simulate=True):
        """
        :param shared_positions: theano ndarray shared variable holding many particle
            [initial] positions
        :param energy_fn:
            callable such that energy_fn(positions)
            returns a theano vector of energies.
            The length of this vector is the batchsize.

            The sum of this energy vector must be differentiable (with theano.tensor.grad) with
            respect to the positions for HMC sampling to work.
        """
        # allocate shared vars

        if shared_positions_shape is None:
            shared_positions_shape = shared_positions.get_value(borrow=True).shape
        batchsize = shared_positions_shape[0]

        stepsize = shared(numpy.asarray(initial_stepsize).astype(theano.config.floatX), 'hmc_stepsize')
        avg_acceptance_rate = shared(target_acceptance_rate, 'avg_acceptance_rate')
        s_rng = TT.shared_randomstreams.RandomStreams(seed)

        accept, final_p = mcmc_move(
                s_rng,
                shared_positions,
                energy_fn,
                stepsize,
                n_steps,
                shared_positions_shape)
        simulate_updates = mcmc_updates(
                shared_positions,
                stepsize,
                avg_acceptance_rate,
                final_p=final_p,
                accept=accept,
                stepsize_min=stepsize_min,
                stepsize_max=stepsize_max,
                stepsize_inc=stepsize_inc,
                stepsize_dec=stepsize_dec,
                target_acceptance_rate=target_acceptance_rate,
                avg_acceptance_slowness=avg_acceptance_slowness)
        if compile_simulate:
            simulate = function([], [], updates=simulate_updates)
        else:
            simulate = None
        return cls(
                positions=shared_positions,
                stepsize=stepsize,
                stepsize_min=stepsize_min,
                stepsize_max=stepsize_max,
                avg_acceptance_rate=avg_acceptance_rate,
                target_acceptance_rate=target_acceptance_rate,
                s_rng=s_rng,
                _updates=simulate_updates,
                simulate=simulate)

    def draw(self, **kwargs):
        """Return the current sample in the Markov chain as a list of numpy arrays

        The size of the arrays will match the size of the `initial_position` argument to
        __init__.

        The `kwargs` dictionary is passed to the shared variable (self.positions) `get_value()`
        function.  So for example, to avoid copying the shared variable value, consider passing
        `borrow=True`.
        """
        self.simulate()
        return self.positions.get_value(borrow=False)

    def updates(self):
        """Returns the update expressions required to simulate the Markov Chain

        :TODO: :WRITEME: *prescriptive* definition of what this method does (API)
        """
        return list(self._updates)
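
    # A hedged usage note: with compile_simulate=False, the update expressions
    # returned by updates() can be folded into a larger theano function of the
    # caller's own, e.g.
    #
    #   sampler = HMC_sampler.new_from_shared_positions(pos, energy_fn,
    #           compile_simulate=False)
    #   step = function([], [], updates=sampler.updates())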

#TODO:
# Consider a heuristic for updating the *MASS* of the particles.  We might want the mass to
# be such that the momentum is in the same range as the gradient of the energy.  Look at
# Radford Neal's recent book chapter (2010); maybe there are hints.
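

if __name__ == '__main__':
    # A minimal usage sketch (an illustration added here, not part of the
    # original API): sample from a 2-D standard Gaussian, whose negative
    # log-density (up to an additive constant) serves as the potential energy.
    rng = numpy.random.RandomState(123)
    positions = shared(rng.randn(100, 2).astype(theano.config.floatX))

    def gaussian_energy(x):
        # E(x) = 0.5 * ||x||^2 per particle, so grad E = x.
        return 0.5 * (x ** 2).sum(axis=1)

    sampler = HMC_sampler.new_from_shared_positions(positions, gaussian_energy)
    for i in range(100):  # burn-in: let the chain mix and the stepsize adapt
        sampler.draw()
    samples = numpy.asarray([sampler.draw() for i in range(100)])
    # The empirical mean should be near 0 and the variance near 1.
    print 'mean:', samples.reshape((-1, 2)).mean(axis=0)
    print 'var: ', samples.reshape((-1, 2)).var(axis=0)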