# HG changeset patch # User James Bergstra # Date 1282679414 14400 # Node ID 60f279ec0f7ffd38e1364ca12980c593a4e1d991 # Parent 68ca3ea34e72839a048144ab1267f2c7756aa950 mcRBM - made sampler a global function diff -r 68ca3ea34e72 -r 60f279ec0f7f pylearn/algorithms/mcRBM.py --- a/pylearn/algorithms/mcRBM.py Tue Aug 24 15:28:13 2010 -0400 +++ b/pylearn/algorithms/mcRBM.py Tue Aug 24 15:50:14 2010 -0400 @@ -307,6 +307,39 @@ g = nnet.sigmoid(c + dot(v,W)) return (h, g) +def n_visible_units(rbm): + """Return the number of visible units of this RBM + + For an RBM made from shared variables, this will return an integer, + for a purely symbolic RBM this will return a theano expression. + + """ + W = rbm[1] + try: + return W.value.shape[0] + except AttributeError: + return W.shape[0] + + +def sampler(rbm, n_particles, rng=7823748): + """Return an `HMC_sampler` that will draw samples from the distribution over visible + units specified by this RBM. + + :param n_particles: this many parallel chains will be simulated. + :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation. + """ + if not hasattr(rng, 'randn'): + rng = np.random.RandomState(rng) + + rval = HMC_sampler( + positions = [as_shared( + rng.randn( + n_particles, + n_visible_units(rbm)))], + energy_fn = lambda p : free_energy_given_v(rbm, p[0]), + seed=int(rng.randint(2**30))) + return rval + class MeanCovRBM(object): """Container for mcRBM parameters that gives more convenient access to mcRBM methods. @@ -374,15 +407,7 @@ :param n_particles: this many parallel chains will be simulated. :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation. """ - if not hasattr(rng, 'randn'): - rng = np.random.RandomState(rng) - return HMC_sampler( - positions = [as_shared( - rng.randn( - n_particles, - self.n_visible ))], - energy_fn = lambda p : self.free_energy_given_v(p[0]), - seed=int(rng.randint(2**30))) + return sampler(self.params, n_particles, rng) def free_energy_given_v(self, v): """Return expressions for F.E. of visible configuration `v`"""