Mercurial > pylearn
changeset 996:60f279ec0f7f
mcRBM - made sampler a global function
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 24 Aug 2010 15:50:14 -0400 |
parents | 68ca3ea34e72 |
children | 71b0132b694a |
files | pylearn/algorithms/mcRBM.py |
diffstat | 1 files changed, 34 insertions(+), 9 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/algorithms/mcRBM.py Tue Aug 24 15:28:13 2010 -0400 +++ b/pylearn/algorithms/mcRBM.py Tue Aug 24 15:50:14 2010 -0400 @@ -307,6 +307,39 @@ g = nnet.sigmoid(c + dot(v,W)) return (h, g) +def n_visible_units(rbm): + """Return the number of visible units of this RBM + + For an RBM made from shared variables, this will return an integer, + for a purely symbolic RBM this will return a theano expression. + + """ + W = rbm[1] + try: + return W.value.shape[0] + except AttributeError: + return W.shape[0] + + +def sampler(rbm, n_particles, rng=7823748): + """Return an `HMC_sampler` that will draw samples from the distribution over visible + units specified by this RBM. + + :param n_particles: this many parallel chains will be simulated. + :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation. + """ + if not hasattr(rng, 'randn'): + rng = np.random.RandomState(rng) + + rval = HMC_sampler( + positions = [as_shared( + rng.randn( + n_particles, + n_visible_units(rbm)))], + energy_fn = lambda p : free_energy_given_v(rbm, p[0]), + seed=int(rng.randint(2**30))) + return rval + class MeanCovRBM(object): """Container for mcRBM parameters that gives more convenient access to mcRBM methods. @@ -374,15 +407,7 @@ :param n_particles: this many parallel chains will be simulated. :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation. """ - if not hasattr(rng, 'randn'): - rng = np.random.RandomState(rng) - return HMC_sampler( - positions = [as_shared( - rng.randn( - n_particles, - self.n_visible ))], - energy_fn = lambda p : self.free_energy_given_v(p[0]), - seed=int(rng.randint(2**30))) + return sampler(self.params, n_particles, rng) def free_energy_given_v(self, v): """Return expressions for F.E. of visible configuration `v`"""