changeset 1341:fbe4a6383441

HMC: perform half-step on velocity first (instead of position). On a 5D gaussian this led to better performance than the previous implementation.
author gdesjardins
date Fri, 22 Oct 2010 10:58:57 -0400
parents 04b988fb00b6
children 4ac393ec2eb7
files pylearn/sampling/hmc.py
diffstat 1 files changed, 64 insertions(+), 14 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sampling/hmc.py	Thu Oct 21 16:18:52 2010 -0400
+++ b/pylearn/sampling/hmc.py	Fri Oct 22 10:58:57 2010 -0400
@@ -62,33 +62,83 @@
     return energy_fn(pos) + kinetic_energy(vel, 1)
 
 def simulate_dynamics(initial_p, initial_v, stepsize, n_steps, energy_fn):
-    """Return final (position, velocity) of `n_step` trajectory
     """
+    Return final (position, velocity) obtained after `n_steps` leapfrog updates, using
+    Hamiltonian dynamics.
+
+    Parameters
+    ----------
+    initial_p: shared theano matrix
+        Initial position at which to start the simulation
+    initial_v: shared theano matrix
+        Initial velocity of particles
+    stepsize: shared theano scalar
+        Scalar value controlling amount by which to move
+    energy_fn: python function
+        Python function, operating on symbolic theano variables, used to compute the potential
+        energy at a given position.
+
+    Returns
+    -------
+    rval1: theano matrix
+        Final positions obtained after simulation
+    rval2: theano matrix
+        Final velocity obtained after simulation
+    """
+
     def leapfrog(pos, vel, step):
-        egy = energy_fn(pos)
-        dE_dpos = TT.grad(egy.sum(), pos)
+        """
+        Inside loop of Scan. Performs one step of leapfrog update, using Hamiltonian dynamics.
+
+        Parameters
+        ----------
+        pos: theano matrix
+            in leapfrog update equations, represents pos(t), position at time t
+        vel: theano matrix
+            in leapfrog update equations, represents vel(t - stepsize/2),
+            velocity at time (t - stepsize/2)
+        step: theano scalar
+            scalar value controlling amount by which to move
+        
+        Returns
+        -------
+        rval1: [theano matrix, theano matrix]
+            Symbolic theano matrices for new position pos(t + stepsize), and velocity
+            vel(t + stepsize/2)
+        rval2: dictionary
+            Dictionary of updates for the Scan Op
+        """
+        # from pos(t) and vel(t-sigma/2), compute vel(t+sigma/2)
+        dE_dpos = TT.grad(energy_fn(pos).sum(), pos)
         new_vel = vel - step * dE_dpos
+        # from vel(t+sigma/2) compute pos(t+sigma)
         new_pos = pos + step * new_vel
         return [new_pos, new_vel],{}
 
+    # compute velocity at time-step: t + sigma/2
+    initial_energy = energy_fn(initial_p)
+    dE_dpos = TT.grad(initial_energy.sum(), initial_p)
+    v_half_step = initial_v - 0.5*stepsize*dE_dpos
+
+    p_full_step = initial_p + stepsize * v_half_step
+
+    # perform leapfrog updates: the scan op is used to repeatedly compute pos(t + m*stepsize)
+    # and vel(t + (m-1/2)*stepsize) for m in [2,n_steps].
     (final_p, final_v), scan_updates = theano.scan(leapfrog, 
             outputs_info=[
-                dict(initial=initial_p+ 0.5*stepsize*initial_v,
-                    return_steps=1),
-                dict(initial=initial_v,
-                    return_steps=1),
+                dict(initial=p_full_step, return_steps=1),
+                dict(initial=v_half_step, return_steps=1),
                 ],
             non_sequences=[stepsize],
-            n_steps=n_steps)
+            n_steps=n_steps-1)
 
-    if scan_updates:
-        raise NotImplementedError((
-                'TODO: check the scan updates to make sure that the s_rng is'
-                ' not being updated incorrectly'))
-    # undo half of the last leap-frog step
-    final_p = final_p - 0.5* stepsize * final_v
+    # The last velocity returned by the scan op is at time-step: t + (n_steps - 1/2)*stepsize
+    # We therefore perform one more half-step to return vel(t + n_steps*stepsize)
+    energy = energy_fn(final_p)
+    final_v = final_v - 0.5 * stepsize * TT.grad(energy.sum(), final_p)
     return final_p, final_v
 
+
 def mcmc_move(s_rng, positions, energy_fn, stepsize, n_steps, positions_shape=None):
     """Return new position
     """