changeset 993:88107ec01ce8

mcRBM - cleaned up hmc_sampler
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 14:54:25 -0400
parents 30b7c4defb6c
children 610f563fb24a
files pylearn/algorithms/mcRBM.py
diffstat 1 files changed, 15 insertions(+), 11 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/mcRBM.py	Tue Aug 24 14:52:09 2010 -0400
+++ b/pylearn/algorithms/mcRBM.py	Tue Aug 24 14:54:25 2010 -0400
@@ -359,22 +359,26 @@
     def __setstate__(self, dct):
         self.__init__(**dct) # calls as_shared on pickled arrays
 
-    def hmc_sampler(self, n_particles=100, seed=7823748):
+    def hmc_sampler(self, n_particles, rng=7823748):
+        """Return an HMC_sampler that will draw samples from the distribution over visible
+        units specified by this RBM.
+
+        :param n_particles: this many parallel chains will be simulated.
+        :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation.
+        """
+        if not hasattr(rng, 'randn'):
+            rng = np.random.RandomState(rng)
         return HMC_sampler(
             positions = [as_shared(
-                np.random.RandomState(seed^20893).randn(
+                rng.randn(
                     n_particles,
                     self.n_visible ))],
-            energy_fn = lambda p : free_energy_given_v(self.params, p[0]),
-            seed=seed)
+            energy_fn = lambda p : self.free_energy_given_v(p[0]),
+            seed=int(rng.randint(2**30)))
 
-    def free_energy_given_v(self, v, extra=False):
-        assert 0
-        rval = free_energy_given_v(self.params, v)
-        if extra:
-            return rval
-        else:
-            return rval[0]
+    def free_energy_given_v(self, v):
+        """Return expressions for F.E. of visible configuration `v`"""
+        return free_energy_given_v(self.params, v)
 
     def contrastive_gradient(self, *args, **kwargs):
         """Return a list of gradient expressions for self.params