changeset 996:60f279ec0f7f

mcRBM - made sampler a global function
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 15:50:14 -0400
parents 68ca3ea34e72
children 71b0132b694a
files pylearn/algorithms/mcRBM.py
diffstat 1 files changed, 34 insertions(+), 9 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/mcRBM.py	Tue Aug 24 15:28:13 2010 -0400
+++ b/pylearn/algorithms/mcRBM.py	Tue Aug 24 15:50:14 2010 -0400
@@ -307,6 +307,39 @@
     g = nnet.sigmoid(c + dot(v,W))
     return (h, g)
 
+def n_visible_units(rbm):
+    """Return the number of visible units of this RBM
+
+    For an RBM made from shared variables, this will return an integer,
+    for a purely symbolic RBM this will return a theano expression.
+    
+    """
+    W = rbm[1]
+    try:
+        return W.value.shape[0]
+    except AttributeError:
+        return W.shape[0]
+
+
+def sampler(rbm, n_particles, rng=7823748):
+    """Return an `HMC_sampler` that will draw samples from the distribution over visible
+    units specified by this RBM.
+
+    :param n_particles: this many parallel chains will be simulated.
+    :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation.
+    """
+    if not hasattr(rng, 'randn'):
+        rng = np.random.RandomState(rng)
+
+    rval = HMC_sampler(
+        positions = [as_shared(
+            rng.randn(
+                n_particles,
+                n_visible_units(rbm)))],
+        energy_fn = lambda p : free_energy_given_v(rbm, p[0]),
+        seed=int(rng.randint(2**30)))
+    return rval
+
 class MeanCovRBM(object):
     """Container for mcRBM parameters that gives more convenient access to mcRBM methods.
 
@@ -374,15 +407,7 @@
         :param n_particles: this many parallel chains will be simulated.
         :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation.
         """
-        if not hasattr(rng, 'randn'):
-            rng = np.random.RandomState(rng)
-        return HMC_sampler(
-            positions = [as_shared(
-                rng.randn(
-                    n_particles,
-                    self.n_visible ))],
-            energy_fn = lambda p : self.free_energy_given_v(p[0]),
-            seed=int(rng.randint(2**30)))
+        return sampler(self.params, n_particles, rng)
 
     def free_energy_given_v(self, v):
         """Return expressions for F.E. of visible configuration `v`"""