changeset 1280:1eaa015e88c1

extended hmc to allow for use as a cd1 sampler
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 15 Sep 2010 17:44:26 -0400
parents e198515bd4d4
children 5ae77ac21609
files pylearn/sampling/hmc.py
diffstat 1 files changed, 11 insertions(+), 5 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/sampling/hmc.py	Wed Sep 15 17:43:42 2010 -0400
+++ b/pylearn/sampling/hmc.py	Wed Sep 15 17:44:26 2010 -0400
@@ -165,7 +165,9 @@
             stepsize_max = 0.25,
             stepsize_inc = 1.02,
             avg_acceptance_slowness = 0.9, # used in geometric avg. 1.0 would be not moving at all
-            seed=12345, dtype=theano.config.floatX):
+            seed=12345, dtype=theano.config.floatX,
+            shared_positions_shape=None, 
+            compile_simulate=True):
         """
         :param shared_positions: theano ndarray shared var with many particle [initial] positions
         :param energy_fn:
@@ -178,8 +180,9 @@
         """
         # allocate shared vars
 
-        positions_shape = shared_positions.value.shape
-        batchsize = shared_positions.value.shape[0]
+        if shared_positions_shape==None:
+            shared_positions_shape = shared_positions.value.shape
+        batchsize = shared_positions_shape[0]
 
         stepsize = shared(numpy.asarray(initial_stepsize).astype(theano.config.floatX), 'hmc_stepsize')
         avg_acceptance_rate = shared(target_acceptance_rate, 'avg_acceptance_rate')
@@ -191,7 +194,7 @@
                 energy_fn,
                 stepsize, 
                 n_steps,
-                positions_shape)
+                shared_positions_shape)
         simulate_updates = mcmc_updates(
                 shared_positions,
                 stepsize,
@@ -204,7 +207,10 @@
                 stepsize_dec=stepsize_dec,
                 target_acceptance_rate=target_acceptance_rate,
                 avg_acceptance_slowness=avg_acceptance_slowness)
-        simulate = function([], [], updates=simulate_updates)
+        if compile_simulate:
+            simulate = function([], [], updates=simulate_updates)
+        else:
+            simulate = None
         return cls(
                 positions=shared_positions,
                 stepsize=stepsize,