# HG changeset patch # User James Bergstra # Date 1282676065 14400 # Node ID 88107ec01ce89684b6a99b17487cba4d62fa1497 # Parent 30b7c4defb6ce8a047b798f2644d76790d10345b mcRBM - cleaned up hmc_sampler diff -r 30b7c4defb6c -r 88107ec01ce8 pylearn/algorithms/mcRBM.py --- a/pylearn/algorithms/mcRBM.py Tue Aug 24 14:52:09 2010 -0400 +++ b/pylearn/algorithms/mcRBM.py Tue Aug 24 14:54:25 2010 -0400 @@ -359,22 +359,26 @@ def __setstate__(self, dct): self.__init__(**dct) # calls as_shared on pickled arrays - def hmc_sampler(self, n_particles=100, seed=7823748): + def hmc_sampler(self, n_particles, rng=7823748): + """Return an HMC_sampler that will draw samples from the distribution over visible + units specified by this RBM. + + :param n_particles: this many parallel chains will be simulated. + :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation. + """ + if not hasattr(rng, 'randn'): + rng = np.random.RandomState(rng) return HMC_sampler( positions = [as_shared( - np.random.RandomState(seed^20893).randn( + rng.randn( n_particles, self.n_visible ))], - energy_fn = lambda p : free_energy_given_v(self.params, p[0]), - seed=seed) + energy_fn = lambda p : self.free_energy_given_v(p[0]), + seed=int(rng.randint(2**30))) - def free_energy_given_v(self, v, extra=False): - assert 0 - rval = free_energy_given_v(self.params, v) - if extra: - return rval - else: - return rval[0] + def free_energy_given_v(self, v): + """Return expressions for F.E. of visible configuration `v`""" + return free_energy_given_v(self.params, v) def contrastive_gradient(self, *args, **kwargs): """Return a list of gradient expressions for self.params