Mercurial > pylearn
changeset 993:88107ec01ce8
mcRBM - cleaned up hmc_sampler
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 24 Aug 2010 14:54:25 -0400 |
parents | 30b7c4defb6c |
children | 610f563fb24a |
files | pylearn/algorithms/mcRBM.py |
diffstat | 1 files changed, 15 insertions(+), 11 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/algorithms/mcRBM.py Tue Aug 24 14:52:09 2010 -0400 +++ b/pylearn/algorithms/mcRBM.py Tue Aug 24 14:54:25 2010 -0400 @@ -359,22 +359,26 @@ def __setstate__(self, dct): self.__init__(**dct) # calls as_shared on pickled arrays - def hmc_sampler(self, n_particles=100, seed=7823748): + def hmc_sampler(self, n_particles, rng=7823748): + """Return an HMC_sampler that will draw samples from the distribution over visible + units specified by this RBM. + + :param n_particles: this many parallel chains will be simulated. + :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation. + """ + if not hasattr(rng, 'randn'): + rng = np.random.RandomState(rng) return HMC_sampler( positions = [as_shared( - np.random.RandomState(seed^20893).randn( + rng.randn( n_particles, self.n_visible ))], - energy_fn = lambda p : free_energy_given_v(self.params, p[0]), - seed=seed) + energy_fn = lambda p : self.free_energy_given_v(p[0]), + seed=int(rng.randint(2**30))) - def free_energy_given_v(self, v, extra=False): - assert 0 - rval = free_energy_given_v(self.params, v) - if extra: - return rval - else: - return rval[0] + def free_energy_given_v(self, v): + """Return expressions for F.E. of visible configuration `v`""" + return free_energy_given_v(self.params, v) def contrastive_gradient(self, *args, **kwargs): """Return a list of gradient expressions for self.params