changeset 1265:d6665a5af743

hmc - replaced with new refactored code
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 01 Sep 2010 17:40:21 -0400
parents 34512d1d4e9c
children bc4d98995bad
files pylearn/sampling/hmc.py pylearn/sampling/tests/test_hmc.py
diffstat 2 files changed, 210 insertions(+), 176 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sampling/hmc.py	Wed Sep 01 17:39:39 2010 -0400
+++ b/pylearn/sampling/hmc.py	Wed Sep 01 17:40:21 2010 -0400
@@ -1,203 +1,237 @@
+
 """Hybrid / Hamiltonian Monte Carlo Sampling
 
 This algorithm is described in Radford Neal's PhD Thesis, pages 63--70.
 
+:note:
+
+  The 'mass' of so-called particles is taken to be 1, so that kinetic energy (K) is the sum
+  of squared velocities (p).
+
+    :math:`K = \sum_i p_i^2 / 2`
+
+
+The 'leap-frog' algorithm that advances by turns the velocity and the position is
+currently implemented via several theano functions, rather than one complete theano
+expression graph.
+
+The initialize_dynamics() theano-function does several things:
+1. samples a random velocity for each particle (saving it to self.velocities) 
+2. calculates the initial hamiltonian based on those velocities (saving it to
+    self.initial_hamiltonian)
+3. saves self.positions to self.initial_positions.
+
+The finalize_dynamics() theano-function re-calculates the Hamiltonian for each particle
+based on the self.positions and self.velocities, and then implements the
+Metropolis-Hastings accept/reject for each particle in the simulation by consulting the
+self.initial_hamiltonian storing the result to self.
+
+
 """
 import sys
 import logging
 import numpy as np
 from theano import function, shared
 from theano import tensor as TT
-import theano.sparse #installs the sparse shared var handler
+import theano
+from theano.printing import Print
+
+def Print(msg):
+    return lambda x: x
+
+def kinetic_energy(velocities, masses):
+    if masses != 1:
+        raise NotImplementedError()
+    return 0.5 * (velocities**2).sum(axis=1)
+
+def metropolis_hastings_accept(energy_prev, energy_next, s_rng, shape=None):
+    """
+    Return acceptance of moves - energy_prev and energy_next are vectors of the energy of each
+    particle.
+    """
+    if shape is None:
+        shape = energy_prev.shape
+    ediff = Print('diff')(energy_prev - energy_next)
+    return (TT.exp(ediff) - s_rng.uniform(size=shape)) >= 0
 
+def hamiltonian(pos, vel, energy_fn):
+    """Return a vector of energies - sum of kinetic and potential energy
+    """
+    # assuming mass is 1
+    return energy_fn(pos) + kinetic_energy(vel, 1)
+
+def simulate_dynamics(initial_p, initial_v, stepsize, n_steps, energy_fn):
+    """Return final (position, velocity) of `n_step` trajectory
+    """
+    def leapfrog(pos, vel, step):
+        print 'LEAPFROG', vel, pos, step
+        egy = energy_fn(pos)
+        dE_dpos = TT.grad(egy.sum(), pos)
+        new_vel = vel - step * dE_dpos
+        new_pos = pos + step * new_vel
+        return [new_pos, new_vel],{}
+
+    (final_p, final_v), scan_updates = theano.scan(leapfrog, 
+            outputs_info=[
+                dict(initial=initial_p+ 0.5*stepsize*initial_v,
+                    return_steps=1),
+                dict(initial=initial_v,
+                    return_steps=1),
+                ],
+            non_sequences=[stepsize],
+            n_steps=n_steps)
+    print final_p.type
+    print final_v.type
 
-#TODO: 
-#  Consider how to refactor this code so that the key functions /functionality are provided as
-#  theano expressions??
-#  It could be limiting that the implementation requires that the position be a list of shared
-#  variables, and basically uses procedural logic on top of that.
+    if scan_updates:
+        raise NotImplementedError((
+                'TODO: check the scan updates to make sure that the s_rng is'
+                ' not being updated incorrectly'))
+    # undo half of the last leap-frog step
+    final_p = final_p - 0.5* stepsize * final_v
+    return final_p, final_v
+
+def mcmc_move(s_rng, positions, energy_fn, stepsize, n_steps, positions_shape=None):
+    """Return new position
+    """
+    if positions_shape is None:
+        positions_shape = positions.shape
+
+    batchsize = positions_shape[0]
+
+    initial_v = s_rng.normal(size=positions_shape)
+
+    final_p, final_v = simulate_dynamics(
+            initial_p = positions, 
+            initial_v = initial_v,
+            stepsize = stepsize,
+            n_steps = n_steps,
+            energy_fn = energy_fn)
 
-#TODO: 
-# Consider a heuristic for updating the *MASS* of the particles.  We might want the mass to be
-# such that the momentum is in the same range as the gradient on the energy.  Look at Radford's
-# recent book chapter, maybe there are hints. (2010).
+    accept = metropolis_hastings_accept(
+            energy_prev = Print('ep')(hamiltonian(positions, initial_v, energy_fn)),
+            energy_next = Print('en')(hamiltonian(final_p, final_v, energy_fn)),
+            s_rng=s_rng, shape=(batchsize,))
+    
+    return Print('accept')(accept), final_p
+
+def mcmc_updates(shrd_pos, shrd_stepsize, shrd_avg_acceptance_rate, final_p, accept, 
+        target_acceptance_rate,
+        stepsize_inc,
+        stepsize_dec,
+        stepsize_min,
+        stepsize_max,
+        avg_acceptance_slowness
+        ):
+    return [
+            (shrd_pos,
+                TT.switch(
+                    accept.dimshuffle(0, *(('x',)*(final_p.ndim-1))),
+                    final_p,
+                    shrd_pos)),
+            (shrd_stepsize, 
+                TT.clip(
+                    TT.switch( 
+                        shrd_avg_acceptance_rate > target_acceptance_rate,
+                        shrd_stepsize * stepsize_inc,
+                        shrd_stepsize * stepsize_dec,
+                        ),
+                    stepsize_min,
+                    stepsize_max)),
+            (shrd_avg_acceptance_rate,
+                Print('arate')(TT.add(
+                    avg_acceptance_slowness * shrd_avg_acceptance_rate,
+                    (1.0 - avg_acceptance_slowness) * accept.mean()))),
+            ]
 
 class HMC_sampler(object):
-    """Batch-wise Hybrid Monte-Carlo sampler
-
+    """Convenience wrapper for HMC
 
     The `draw` function advances the markov chain and returns the current sample by calling
     `simulate` and `get_position` in sequence.
 
-
-    :note:
-
-      The 'mass' of so-called particles is taken to be 1, so that kinetic energy (K) is the sum
-      of squared velocities (p).
-
-        :math:`K = \sum_i p_i^2 / 2`
-
     """
 
     # Constants taken from Marc'Aurelio's 'train_mcRBM.py' file found in the code online for his
     # paper.
-    stepsize_dec = 0.98
-    stepsize_min = 0.001
-    stepsize_max = 0.25
-    stepsize_inc = 1.02
-    avg_acceptance_slowness = 0.9  # used in geometric avg. 1.0 would be not moving at all
-    n_steps=20
+
+    def __init__(self, **kwargs):
+        # add things to __dict__
+        self.__dict__.update(kwargs)
 
-    def __init__(self, positions, energy_fn, 
-            velocity=None,
-            initial_stepsize=0.01,
-            target_acceptance_rate=0.9,
-            seed=12345,
-            dtype=theano.config.floatX):
+    @classmethod
+    def new_from_shared_positions(cls, shared_positions, energy_fn, 
+            initial_stepsize=0.01, target_acceptance_rate=.9, n_steps=20,
+            stepsize_dec = 0.98,
+            stepsize_min = 0.001,
+            stepsize_max = 0.25,
+            stepsize_inc = 1.02,
+            avg_acceptance_slowness = 0.9, # used in geometric avg. 1.0 would be not moving at all
+            seed=12345, dtype=theano.config.floatX):
         """
-        :param positions: list of shared ndarray variables.
-
-        :param energy: 
-
+        :param shared_positions: theano ndarray shared var with many particle [initial] positions
+        :param energy_fn:
             callable such that energy_fn(positions) 
             returns theano vector of energies.  
             The len of this vector is the batchsize.
 
             The sum of this energy vector must be differentiable (with theano.tensor.grad) with
             respect to the positions for HMC sampling to work.
-
         """
-
-        self.stepsize = initial_stepsize
-        batchsize = positions[0].value.shape[0]
-        self.target_acceptance_rate = target_acceptance_rate
-        self.avg_acceptance_rate = self.target_acceptance_rate
-        self.s_rng = TT.shared_randomstreams.RandomStreams(seed)
-        self.positions = positions
-        if velocity is None:
-            self.velocity = [shared(np.zeros_like(q.value)) for q in self.positions]
-        else:
-            self.velocity = velocity
-
-        self.start_positions = [shared(np.zeros_like(q.value)) for q in self.positions]
-        self.initial_hamiltonian = shared(np.zeros(batchsize).astype(dtype) + float('inf'))
-
-        energy = energy_fn(positions)
-        # assuming mass is 1
-        kinetic_energy = 0.5 * sum([(p**2).sum(axis=1) for p in self.velocity])
-        hamiltonian = energy + kinetic_energy
-
-        dE_dpos = TT.grad(energy.sum(), self.positions)
-
-        s_stepsize = TT.scalar('stepsize')
-
-        self.velocity_step = function([s_stepsize],
-                [dE_dpos[0].norm(2)],
-                updates=[(p, p - s_stepsize*dE_dq) for p,dE_dq in zip(self.velocity,dE_dpos)])
-        self.position_step = function([s_stepsize], [],
-                updates=[(q, q + s_stepsize*p) for (q,p) in zip(self.positions, self.velocity)])
+        # allocate shared vars
 
-        self.save_start_positions = function([], [],
-                updates=[(self.initial_hamiltonian, hamiltonian)] + \
-                        [(sq, q) for sq, q in zip(self.start_positions, self.positions)])
-
-        # sample the initial velocity from a
-        # gaussian with mean 0.
-        # Note: I think the fact that this distribution is symmetric about zero justifies not
-        # sampling forward versus backward dynamics.
-        self.randomize_velocity = function([], [],
-                updates=[(p, self.s_rng.normal(size=p.value.shape)) for p in self.velocity])
-
-
-        # accept-reject according to Metropolis algo
-
-        accept = TT.exp(self.initial_hamiltonian - hamiltonian) - self.s_rng.uniform(size=(batchsize,)) >= 0
+        positions_shape = shared_positions.value.shape
+        batchsize = shared_positions.value.shape[0]
 
-        self.accept_reject_positions = function([], accept.mean(),
-                updates=[
-                    (self.initial_hamiltonian, 
-                        TT.switch(accept, hamiltonian, self.initial_hamiltonian))] + [
-                    (q, 
-                        TT.switch(accept.dimshuffle(0, *(('x',)*(sq.ndim-1))), q, sq)) 
-                        for sq, q in zip(self.start_positions, self.positions)])
-
-    def simulate(self, n_steps=None):
-        if n_steps is None:
-            n_steps = self.n_steps
-
-        # updates self.velocity with new random numbers
-        self.randomize_velocity()
-        self.save_start_positions()
+        stepsize = shared(initial_stepsize, 'hmc_stepsize')
+        avg_acceptance_rate = shared(target_acceptance_rate, 'avg_acceptance_rate')
+        s_rng = TT.shared_randomstreams.RandomStreams(seed)
 
-        if 0:
-            # not necessary for random initial direction
-            if np.random.rand() > 0.5:
-                step_amount = self.stepsize
-            else:
-                step_amount = -self.stepsize
-        else:
-            step_amount = self.stepsize
-
-        if 0:
-            print "initial",
-            print "kinetic E", self.prev_energy.value ,
-            print (self.velocity[0].value**2).sum(axis=1),
-            print (self.velocity[0].value**2).sum(axis=1) + self.prev_energy.value
-
-
-        # Note on the order of leap-frog steps:
-        #
-        # the leap-frog algorithm can start with either position or velocity updates first,
-        # but one of them has to run an extra time (because of two half-steps).
-        # The position_step is cheap to evaluate, whereas the velocity_step is expensive,
-        # so we evaluate the position_step the extra time.
-        #
-        # At the same time, we cannot drop the last half-step update of the position, because
-        # the position is actually the terms we care about.
-
-        #opening half-step in leap-frog algorithm
-        self.position_step(step_amount/2.0)
+        accept, final_p = mcmc_move(
+                s_rng, 
+                shared_positions, 
+                energy_fn,
+                stepsize, 
+                n_steps,
+                positions_shape)
+        simulate = function([], [],
+                updates=mcmc_updates(
+                    shared_positions,
+                    stepsize,
+                    avg_acceptance_rate, 
+                    final_p=final_p, 
+                    accept=accept,
+                    stepsize_min=stepsize_min,
+                    stepsize_max=stepsize_max,
+                    stepsize_inc=stepsize_inc,
+                    stepsize_dec=stepsize_dec,
+                    target_acceptance_rate=target_acceptance_rate,
+                    avg_acceptance_slowness=avg_acceptance_slowness))
+        return cls(
+                positions=shared_positions,
+                stepsize=stepsize,
+                stepsize_min=stepsize_min,
+                stepsize_max=stepsize_max,
+                avg_acceptance_rate=avg_acceptance_rate,
+                target_acceptance_rate=target_acceptance_rate,
+                s_rng=s_rng,
+                simulate=simulate)
 
-        # full leap-frog steps
-        for ss in range(n_steps):
-            self.velocity_step(step_amount)
-            if ss == n_steps-1:
-                # closing half-step in the leap-frog algorithm
-                self.position_step(step_amount/2.0)
-            else:
-                self.position_step(step_amount)
-
-        acceptance_rate = self.accept_reject_positions()
-        self.avg_acceptance_rate = self.avg_acceptance_slowness * self.avg_acceptance_rate \
-                + (1.0 - self.avg_acceptance_slowness) * acceptance_rate
-
-        if self.avg_acceptance_rate < self.target_acceptance_rate:
-            self.stepsize = max(self.stepsize*self.stepsize_dec,self.stepsize_min)
-        else:
-            self.stepsize = min(self.stepsize*self.stepsize_inc,self.stepsize_max)
-
-        if 0:
-            print "final kinetic E", self.prev_energy.value ,
-            print (self.velocity[0].value**2).sum(axis=1),
-            print (self.velocity[0].value**2).sum(axis=1) + self.prev_energy.value
-
-
-        # post-condition: 
-        # self.positions contains a new sample from our markov chain
-
-        # it is not returned from this function because accessing the .value of a shared
-        # variable can require making a copy
-        # see `draw()` or `get_position` for that behaviour.
-
-    def get_position(self):
-        return [q.value.copy() for q in self.positions]
-
-    def draw(self, n_steps=None):
+    def draw(self, **kwargs):
         """Return the current sample in the Markov chain as a list of numpy arrays
 
         The size of the arrays will match the size of the `initial_position` argument to
         __init__.
+
+        The `kwargs` dictionary is passed to the shared variable (self.positions) `get_value()`
+        function.  So for example, to avoid copying the shared variable value, consider passing
+        `borrow=True`.
         """
-        self.simulate(n_steps=n_steps)
-        return self.get_position()
+        self.simulate()
+        return self.positions.value.copy()
 
+#TODO: 
+# Consider a heuristic for updating the *MASS* of the particles.  We might want the mass to be
+# such that the momentum is in the same range as the gradient on the energy.  Look at Radford's
+# recent book chapter, maybe there are hints. (2010).
+
--- a/pylearn/sampling/tests/test_hmc.py	Wed Sep 01 17:39:39 2010 -0400
+++ b/pylearn/sampling/tests/test_hmc.py	Wed Sep 01 17:40:21 2010 -0400
@@ -1,4 +1,4 @@
-from pylearn.sampling.hmc import *
+from pylearn.sampling.hmc2 import *
 
 def _sampler_on_2d_gaussian(sampler_cls, burnin, n_samples):
     batchsize=3
@@ -15,15 +15,15 @@
     mu = np.asarray([5, 9.5], dtype=theano.config.floatX)
 
     def gaussian_energy(xlist):
-        x, = xlist
+        x = xlist
         return 0.5 * (TT.dot((x-mu),cov_inv)*(x-mu)).sum(axis=1)
 
 
     position = shared(rng.randn(batchsize, 2).astype(theano.config.floatX))
-    sampler = sampler_cls([position], gaussian_energy)
+    sampler = sampler_cls(position, gaussian_energy)
 
     print 'initial position', position.value
-    print 'initial stepsize', sampler.stepsize
+    print 'initial stepsize', sampler.stepsize.value
 
     # DRAW SAMPLES
 
@@ -35,31 +35,31 @@
 
     # TEST THAT THEY ARE FROM THE RIGHT DISTRIBUTION
 
-    # samples.shape == (1000, 1, 3, 2)
+    # samples.shape == (1000, 3, 2)
 
     print 'target mean:', mu
-    print 'empirical mean: ', samples.mean(axis=0)[0]
+    print 'empirical mean: ', samples.mean(axis=0)
     #assert np.all(abs(mu - samples.mean(axis=0)) < 1)
 
 
-    print 'final stepsize', sampler.stepsize
-    print 'final acceptance_rate', sampler.avg_acceptance_rate
+    print 'final stepsize', sampler.stepsize.value
+    print 'final acceptance_rate', sampler.avg_acceptance_rate.value
 
     print 'target cov', cov
-    s = samples[:,0,0,:]
-    empirical_cov = np.cov(samples[:,0,0,:].T)
+    s = samples[:,0,:]
+    empirical_cov = np.cov(samples[:,0,:].T)
     print ''
     print 'cov/empirical_cov', cov/empirical_cov
-    empirical_cov = np.cov(samples[:,0,1,:].T)
+    empirical_cov = np.cov(samples[:,1,:].T)
     print 'cov/empirical_cov', cov/empirical_cov
-    empirical_cov = np.cov(samples[:,0,2,:].T)
+    empirical_cov = np.cov(samples[:,2,:].T)
     print 'cov/empirical_cov', cov/empirical_cov
     return sampler
 
 def test_hmc():
     print ('HMC')
-    sampler = _sampler_on_2d_gaussian(HMC_sampler, burnin=3000/20, n_samples=90000/20)
+    sampler = _sampler_on_2d_gaussian(HMC_sampler.new_from_shared_positions, burnin=3000/20, n_samples=90000/20)
     assert abs(sampler.avg_acceptance_rate - sampler.target_acceptance_rate) < .1
-    assert sampler.stepsize >= sampler.stepsize_min
-    assert sampler.stepsize <= sampler.stepsize_max
+    assert sampler.stepsize.value >= sampler.stepsize_min
+    assert sampler.stepsize.value <= sampler.stepsize_max