# HG changeset patch
# User James Bergstra
# Date 1282233618 14400
# Node ID 492473059b370c57f0a7c7b82b21ec9816a17c41
# Parent  9b0fd89599c7fb14a7b6cdbdf1a3c7c43459cc82
Adding sampling module

diff -r 9b0fd89599c7 -r 492473059b37 pylearn/sampling/README.txt
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sampling/README.txt	Thu Aug 19 12:00:18 2010 -0400
@@ -0,0 +1,2 @@
+See __init__.py
+
diff -r 9b0fd89599c7 -r 492473059b37 pylearn/sampling/__init__.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sampling/__init__.py	Thu Aug 19 12:00:18 2010 -0400
@@ -0,0 +1,17 @@
+"""Sampling
+
+This module [will] contain theano-related code for various sampling algorithms, such as,
+for example:
+
+ - MCMC
+
+ - [Block] Gibbs Sampling
+
+ - Slice sampling
+
+ - HMC
+
+ - Tempering methods
+
+
+"""
diff -r 9b0fd89599c7 -r 492473059b37 pylearn/sampling/hmc.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sampling/hmc.py	Thu Aug 19 12:00:18 2010 -0400
@@ -0,0 +1,203 @@
+"""Hybrid / Hamiltonian Monte Carlo Sampling
+
+This algorithm is described in Radford Neal's PhD Thesis, pages 63--70.
+
+"""
+import sys
+import logging
+import numpy as np
+import theano
+from theano import function, shared
+from theano import tensor as TT
+import theano.sparse  # installs the sparse shared var handler
+
+
+# TODO:
+# Consider how to refactor this code so that the key functions/functionality are provided as
+# theano expressions.
+# It could be limiting that the implementation requires that the position be a list of shared
+# variables, and basically uses procedural logic on top of that.
+
+# TODO:
+# Consider a heuristic for updating the *MASS* of the particles. We might want the mass to be
+# such that the momentum is in the same range as the gradient on the energy. Look at Radford's
+# recent book chapter (2010); maybe there are hints.
+
+class HMC_sampler(object):
+    """Batch-wise Hybrid Monte-Carlo sampler
+
+    The `draw` function advances the Markov chain and returns the current sample by calling
+    `simulate` and `get_position` in sequence.
+
+    :note:
+
+      The 'mass' of the so-called particles is taken to be 1, so that the kinetic energy (K)
+      is half the sum of squared velocities (p):
+
+      :math:`K = \sum_i p_i^2 / 2`
+
+    """
+
+    # Constants taken from Marc'Aurelio's 'train_mcRBM.py' file found in the code online for
+    # his paper.
+    stepsize_dec = 0.98
+    stepsize_min = 0.001
+    stepsize_max = 0.25
+    stepsize_inc = 1.02
+    avg_acceptance_slowness = 0.9  # weight in the geometric moving average; 1.0 would mean never updating
+    n_steps = 20
+
+    def __init__(self, positions, energy_fn,
+            velocity=None,
+            initial_stepsize=0.01,
+            target_acceptance_rate=0.9,
+            seed=12345,
+            dtype=theano.config.floatX):
+        """
+        :param positions: list of shared ndarray variables.
+
+        :param energy_fn:
+
+            callable such that energy_fn(positions)
+            returns a theano vector of energies.
+            The length of this vector is the batchsize.
+
+            The sum of this energy vector must be differentiable (with theano.tensor.grad)
+            with respect to the positions for HMC sampling to work.
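+
+            A minimal usage sketch (hypothetical, using this module's imports; it mirrors
+            the 2-D Gaussian test in pylearn/sampling/tests/test_hmc.py, but with a simple
+            quadratic energy):
+
+                batchsize = 3
+                position = shared(np.zeros((batchsize, 2)).astype(theano.config.floatX))
+                def quadratic_energy(pos_list):
+                    x, = pos_list
+                    return 0.5 * (x ** 2).sum(axis=1)
+                sampler = HMC_sampler([position], quadratic_energy)
+                sample = sampler.draw()  # list with one (batchsize, 2) ndarray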
+
+        """
+
+        self.stepsize = initial_stepsize
+        batchsize = positions[0].value.shape[0]
+        self.target_acceptance_rate = target_acceptance_rate
+        self.avg_acceptance_rate = self.target_acceptance_rate
+        self.s_rng = TT.shared_randomstreams.RandomStreams(seed)
+        self.positions = positions
+        if velocity is None:
+            self.velocity = [shared(np.zeros_like(q.value)) for q in self.positions]
+        else:
+            self.velocity = velocity
+
+        self.start_positions = [shared(np.zeros_like(q.value)) for q in self.positions]
+        self.initial_hamiltonian = shared(np.zeros(batchsize).astype(dtype) + float('inf'))
+
+        energy = energy_fn(positions)
+        # assuming mass is 1
+        kinetic_energy = 0.5 * sum([(p**2).sum(axis=1) for p in self.velocity])
+        hamiltonian = energy + kinetic_energy
+
+        dE_dpos = TT.grad(energy.sum(), self.positions)
+
+        s_stepsize = TT.scalar('stepsize')
+
+        self.velocity_step = function([s_stepsize],
+                [dE_dpos[0].norm(2)],
+                updates=[(p, p - s_stepsize*dE_dq) for p, dE_dq in zip(self.velocity, dE_dpos)])
+        self.position_step = function([s_stepsize], [],
+                updates=[(q, q + s_stepsize*p) for (q, p) in zip(self.positions, self.velocity)])
+
+        self.save_start_positions = function([], [],
+                updates=[(self.initial_hamiltonian, hamiltonian)] + \
+                        [(sq, q) for sq, q in zip(self.start_positions, self.positions)])
+
+        # sample the initial velocity from a gaussian with mean 0.
+        # Note: the fact that this distribution is symmetric about zero justifies not
+        # sampling forward versus backward dynamics.
+        self.randomize_velocity = function([], [],
+                updates=[(p, self.s_rng.normal(size=p.value.shape)) for p in self.velocity])
+
+        # accept-reject according to the Metropolis rule: accept a proposal with
+        # probability min(1, exp(H_initial - H_final))
+        accept = TT.exp(self.initial_hamiltonian - hamiltonian) - self.s_rng.uniform(size=(batchsize,)) >= 0
+
+        self.accept_reject_positions = function([], accept.mean(),
+                updates=[
+                    (self.initial_hamiltonian,
+                        TT.switch(accept, hamiltonian, self.initial_hamiltonian))] + [
+                    (q,
+                        TT.switch(accept.dimshuffle(0, *(('x',)*(sq.ndim-1))), q, sq))
+                    for sq, q in zip(self.start_positions, self.positions)])
+
+    def simulate(self, n_steps=None):
+        if n_steps is None:
+            n_steps = self.n_steps
+
+        # updates self.velocity with new random numbers
+        self.randomize_velocity()
+        self.save_start_positions()
+
+        if 0:
+            # randomizing the sign of the stepsize is not necessary, given the random
+            # initial velocity
+            if np.random.rand() > 0.5:
+                step_amount = self.stepsize
+            else:
+                step_amount = -self.stepsize
+        else:
+            step_amount = self.stepsize
+
+        if 0:
+            # dead debugging code (self.initial_hamiltonian holds the pre-simulation
+            # Hamiltonian)
+            print "initial",
+            print "kinetic E", self.initial_hamiltonian.value,
+            print (self.velocity[0].value**2).sum(axis=1),
+            print (self.velocity[0].value**2).sum(axis=1) + self.initial_hamiltonian.value
+
+        # Note on the order of leap-frog steps:
+        #
+        # the leap-frog algorithm can start with either position or velocity updates first,
+        # but one of them has to run an extra time (because of the two half-steps).
+        # The position_step is cheap to evaluate, whereas the velocity_step is expensive,
+        # so we evaluate the position_step the extra time.
+        #
+        # At the same time, we cannot drop the last half-step update of the position,
+        # because the positions are what we actually care about.
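+        #
+        # As a summary of the compiled steps below (eps is step_amount, q the positions,
+        # p the velocities, E the energy):
+        #
+        #     q <- q + (eps/2) * p              (opening half-step)
+        #     repeat n_steps times:
+        #         p <- p - eps * dE/dq
+        #         q <- q + eps * p              (eps/2 on the last iteration)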
+
+        # opening half-step in the leap-frog algorithm
+        self.position_step(step_amount/2.0)
+
+        # full leap-frog steps
+        for ss in range(n_steps):
+            self.velocity_step(step_amount)
+            if ss == n_steps-1:
+                # closing half-step in the leap-frog algorithm
+                self.position_step(step_amount/2.0)
+            else:
+                self.position_step(step_amount)
+
+        acceptance_rate = self.accept_reject_positions()
+        self.avg_acceptance_rate = self.avg_acceptance_slowness * self.avg_acceptance_rate \
+                + (1.0 - self.avg_acceptance_slowness) * acceptance_rate
+
+        if self.avg_acceptance_rate < self.target_acceptance_rate:
+            self.stepsize = max(self.stepsize*self.stepsize_dec, self.stepsize_min)
+        else:
+            self.stepsize = min(self.stepsize*self.stepsize_inc, self.stepsize_max)
+
+        if 0:
+            # dead debugging code, as above
+            print "final kinetic E", self.initial_hamiltonian.value,
+            print (self.velocity[0].value**2).sum(axis=1),
+            print (self.velocity[0].value**2).sum(axis=1) + self.initial_hamiltonian.value
+
+        # post-condition:
+        # self.positions contains a new sample from our Markov chain
+
+        # it is not returned from this function because accessing the .value of a shared
+        # variable can require making a copy
+        # see `draw()` or `get_position()` for that behaviour.
+
+    def get_position(self):
+        return [q.value for q in self.positions]
+
+    def draw(self, n_steps=None):
+        """Return the current sample in the Markov chain as a list of numpy arrays
+
+        The size of the arrays will match the size of the `positions` argument to
+        __init__.
+        """
+        self.simulate(n_steps=n_steps)
+        return self.get_position()
+
diff -r 9b0fd89599c7 -r 492473059b37 pylearn/sampling/mcmc.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sampling/mcmc.py	Thu Aug 19 12:00:18 2010 -0400
@@ -0,0 +1,103 @@
+"""Metropolis-Hastings Monte Carlo
+
+See also Hamiltonian Monte Carlo (aka Hybrid Monte Carlo) in the hmc.py file.
+
+"""
+import sys
+import logging
+import numpy as np
+import theano
+from theano import function, shared
+from theano import tensor as TT
+import theano.sparse  # installs the sparse shared var handler
+from theano.printing import Print
+
+# TODO:
+# Consider how to refactor this code so that the key functions/functionality are provided as
+# theano expressions.
+# It could be limiting that the implementation requires that the position be a list of shared
+# variables, and basically uses procedural logic on top of that.
+#
+# Theano can only do function optimization on one function at a time, so until these are
+# written as update expressions, they must be compiled as their own functions and cannot be
+# rolled into things the user is doing anyway.
+
+class MCMC_sampler(object):
+    """Generic Metropolis-Hastings Monte-Carlo sampler
+
+    The `draw` function advances the Markov chain and returns the current sample by calling
+    `simulate` and `get_position` in sequence.
+
+    The stepsize is updated dynamically to track the target acceptance rate.
+
+    """
+
+    # TODO: does anyone have better guesses for these constants?
+    stepsize_dec = 0.98
+    stepsize_min = 0.001
+    stepsize_max = 0.25
+    stepsize_inc = 1.02
+    target_acceptance_rate = 0.7
+
+    avg_acceptance_slowness = 0.9  # weight in the geometric moving average; 1.0 would mean never updating
+    n_steps = 1
+
+    def __init__(self, positions, energy_fn,
+            initial_stepsize=0.01,
+            seed=12345):
+        """
+        :param positions: list of shared ndarray variables.
+
+        :param energy_fn:
+
+            callable such that energy_fn(positions)
+            returns a theano vector of energies.
+            The length of this vector is the batchsize.
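+
+            A hypothetical usage sketch (mirroring pylearn/sampling/tests/test_mcmc.py;
+            `some_energy_fn` stands in for any energy callable with the signature above):
+
+                position = shared(rng.randn(batchsize, 2).astype(theano.config.floatX))
+                sampler = MCMC_sampler([position], some_energy_fn)
+                sample = sampler.draw()  # list with one (batchsize, 2) ndarray
+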
+        """
+
+        batchsize = positions[0].value.shape[0]
+        self.s_rng = TT.shared_randomstreams.RandomStreams(seed)
+        self.positions = positions
+        self.prev_energy = shared(np.zeros(batchsize) + float('inf'))
+        self.avg_acceptance_rate = 0.5
+        self.stepsize = initial_stepsize
+
+        s_stepsize = TT.scalar('stepsize')
+
+        new_positions = [p + s_stepsize * self.s_rng.normal(size=p.value.shape)
+                for p in self.positions]
+
+        # accept-reject according to Metropolis-Hastings (the Gaussian proposal is
+        # symmetric, so the proposal ratio cancels)
+        energy = energy_fn(new_positions)
+        accept = TT.exp(self.prev_energy - energy) - self.s_rng.uniform(size=(batchsize,)) >= 0
+
+        self.accept_reject_positions = function([s_stepsize], accept.mean(),
+                updates=[(self.prev_energy, TT.switch(accept, energy, self.prev_energy))] + [
+                    (q, TT.switch(accept.dimshuffle(0, *(('x',)*(q.ndim-1))), new_q, q))
+                    for q, new_q in zip(self.positions, new_positions)])
+
+    def simulate(self, n_steps=None):
+        if n_steps is None:
+            n_steps = self.n_steps
+        for ss in range(n_steps):
+            acceptance_rate = self.accept_reject_positions(self.stepsize)
+            self.avg_acceptance_rate = self.avg_acceptance_slowness * self.avg_acceptance_rate \
+                    + (1.0 - self.avg_acceptance_slowness) * acceptance_rate
+            if self.avg_acceptance_rate < self.target_acceptance_rate:
+                self.stepsize = max(self.stepsize*self.stepsize_dec, self.stepsize_min)
+            else:
+                self.stepsize = min(self.stepsize*self.stepsize_inc, self.stepsize_max)
+
+    def get_position(self):
+        return [q.value for q in self.positions]
+
+    def draw(self, n_steps=None):
+        """Return the current sample in the Markov chain as a list of numpy arrays
+
+        The size of the arrays will match the size of the `positions` argument to
+        __init__.
+        """
+        self.simulate(n_steps=n_steps)
+        return self.get_position()
+
diff -r 9b0fd89599c7 -r 492473059b37 pylearn/sampling/tests/test_hmc.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sampling/tests/test_hmc.py	Thu Aug 19 12:00:18 2010 -0400
@@ -0,0 +1,65 @@
+from pylearn.sampling.hmc import *
+
+def _sampler_on_2d_gaussian(sampler_cls, burnin, n_samples):
+    batchsize = 3
+
+    rng = np.random.RandomState(234)
+
+    #
+    # Define a covariance and mu for a gaussian
+    #
+    tmp = rng.randn(2,2).astype(theano.config.floatX)
+    tmp[0] += tmp[1]  # induce some covariance
+    cov = np.dot(tmp, tmp.T)
+    cov_inv = np.linalg.inv(cov).astype(theano.config.floatX)
+    mu = np.asarray([5, 9.5], dtype=theano.config.floatX)
+
+    def gaussian_energy(xlist):
+        x, = xlist
+        return 0.5 * (TT.dot((x-mu), cov_inv)*(x-mu)).sum(axis=1)
+
+    position = shared(rng.randn(batchsize, 2).astype(theano.config.floatX))
+    sampler = sampler_cls([position], gaussian_energy)
+
+    print 'initial position', position.value
+    print 'initial stepsize', sampler.stepsize
+
+    # DRAW SAMPLES
+
+    samples = [sampler.draw() for r in xrange(burnin)]  # burn-in
+    samples = np.asarray([sampler.draw() for r in xrange(n_samples)])
+
+    assert sampler.avg_acceptance_rate > 0
+    assert sampler.avg_acceptance_rate < 1
+
+    # TEST THAT THEY ARE FROM THE RIGHT DISTRIBUTION
+
+    # samples.shape == (n_samples, 1, batchsize, 2)
+
+    print 'target mean:', mu
+    print 'empirical mean: ', samples.mean(axis=0)[0]
+    #assert np.all(abs(mu - samples.mean(axis=0)) < 1)
+
+    print 'final stepsize', sampler.stepsize
+    print 'final acceptance_rate', sampler.avg_acceptance_rate
+
+    # compare the target covariance with the empirical covariance of each of the
+    # batchsize independent chains
+    print 'target cov', cov
+    empirical_cov = np.cov(samples[:,0,0,:].T)
+    print ''
+    print 'cov/empirical_cov', cov/empirical_cov
+    empirical_cov = np.cov(samples[:,0,1,:].T)
+    print 'cov/empirical_cov', cov/empirical_cov
+    empirical_cov = np.cov(samples[:,0,2,:].T)
+    print 'cov/empirical_cov', cov/empirical_cov
+    return sampler
+
+def test_hmc():
+    print 'HMC'
+    sampler = _sampler_on_2d_gaussian(HMC_sampler, burnin=3000/20, n_samples=90000/20)
+    assert abs(sampler.avg_acceptance_rate - sampler.target_acceptance_rate) < .1
+    assert sampler.stepsize >= sampler.stepsize_min
+    assert sampler.stepsize <= sampler.stepsize_max
+
diff -r 9b0fd89599c7 -r 492473059b37 pylearn/sampling/tests/test_mcmc.py
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sampling/tests/test_mcmc.py	Thu Aug 19 12:00:18 2010 -0400
@@ -0,0 +1,61 @@
+from pylearn.sampling.mcmc import *
+
+def _sampler_on_2d_gaussian(sampler_cls, burnin, n_samples):
+    batchsize = 3
+
+    rng = np.random.RandomState(234)
+
+    #
+    # Define a covariance and mu for a gaussian
+    #
+    tmp = rng.randn(2,2)
+    tmp[0] += tmp[1]  # induce some covariance
+    cov = np.dot(tmp, tmp.T)
+    cov_inv = np.linalg.inv(cov)
+    mu = np.asarray([5, 9.5], dtype=theano.config.floatX)
+
+    def gaussian_energy(xlist):
+        x, = xlist
+        return 0.5 * (TT.dot((x-mu), cov_inv)*(x-mu)).sum(axis=1)
+
+    position = shared(rng.randn(batchsize, 2).astype(theano.config.floatX))
+    sampler = sampler_cls([position], gaussian_energy)
+
+    print 'initial position', position.value
+    print 'initial stepsize', sampler.stepsize
+
+    # DRAW SAMPLES
+
+    samples = [sampler.draw() for r in xrange(burnin)]  # burn-in
+    samples = np.asarray([sampler.draw() for r in xrange(n_samples)])
+
+    assert sampler.avg_acceptance_rate > 0
+    assert sampler.avg_acceptance_rate < 1
+
+    # TEST THAT THEY ARE FROM THE RIGHT DISTRIBUTION
+
+    # samples.shape == (n_samples, 1, batchsize, 2)
+
+    print 'target mean:', mu
+    print 'empirical mean: ', samples.mean(axis=0)[0]
+    #assert np.all(abs(mu - samples.mean(axis=0)) < 1)
+
+    print 'final stepsize', sampler.stepsize
+    print 'final acceptance_rate', sampler.avg_acceptance_rate
+
+    # compare the target covariance with the empirical covariance of each of the
+    # batchsize independent chains
+    print 'target cov', cov
+    empirical_cov = np.cov(samples[:,0,0,:].T)
+    print ''
+    print 'cov/empirical_cov', cov/empirical_cov
+    empirical_cov = np.cov(samples[:,0,1,:].T)
+    print 'cov/empirical_cov', cov/empirical_cov
+    empirical_cov = np.cov(samples[:,0,2,:].T)
+    print 'cov/empirical_cov', cov/empirical_cov
+    return sampler
+
+def test_mcmc():
+    print 'MCMC'
+    sampler = _sampler_on_2d_gaussian(MCMC_sampler, burnin=3000, n_samples=90000)