comparison pylearn/sampling/tests/test_hmc.py @ 1503:1ee532a6f33b

Fix import.
author Frederic Bastien <nouiz@nouiz.org>
date Mon, 12 Sep 2011 10:24:24 -0400
parents fbe470217937
children 5804e44d7a1b
comparison
equal deleted inserted replaced
1502:4fa5ebe8a7ad 1503:1ee532a6f33b
1 from pylearn.sampling.hmc import * 1 import numpy
2 import theano
3 from theano import tensor
4
5 from pylearn.sampling.hmc import HMC_sampler
2 6
def _sampler_on_2d_gaussian(sampler_cls, burnin, n_samples):
    """Run a sampler on a correlated 2-D Gaussian and report its statistics.

    The target density is defined by a random 2x2 covariance (with an induced
    off-diagonal term) and a fixed mean ``mu``. The sampler is built from a
    shared ``(batchsize, 2)`` position and the energy
    ``0.5 * (x - mu)^T cov_inv (x - mu)``, evaluated row-wise. Then ``burnin``
    draws are discarded and ``n_samples`` draws are kept.

    The sampler object must expose ``draw()``, a shared ``stepsize`` and a
    shared ``avg_acceptance_rate`` (all are used below).

    :param sampler_cls: sampler constructor, called as
        ``sampler_cls(position, energy_fn)``.
    :param burnin: number of initial draws to discard.
    :param n_samples: number of draws kept for the statistics.
    :returns: the sampler instance after sampling.
    :raises AssertionError: if the average acceptance rate is not strictly
        between 0 and 1.
    """
    batchsize = 3  # number of parallel chains (rows of ``position``)

    rng = numpy.random.RandomState(234)

    #
    # Define a covariance and mu for a gaussian
    #
    tmp = rng.randn(2, 2).astype(theano.config.floatX)
    tmp[0] += tmp[1]  # induce some covariance
    cov = numpy.dot(tmp, tmp.T)
    cov_inv = numpy.linalg.inv(cov).astype(theano.config.floatX)
    mu = numpy.asarray([5, 9.5], dtype=theano.config.floatX)

    def gaussian_energy(xlist):
        # Quadratic form per row: summing over axis 1 gives one energy
        # value per chain.
        centered = xlist - mu
        return 0.5 * (tensor.dot(centered, cov_inv) * centered).sum(axis=1)

    position = theano.shared(
        rng.randn(batchsize, 2).astype(theano.config.floatX))
    sampler = sampler_cls(position, gaussian_energy)

    # Single-argument print(...) produces the same output on Python 2 and 3.
    print('initial position %s' % (position.get_value(borrow=True),))
    print('initial stepsize %s' % (sampler.stepsize.get_value(borrow=True),))

    # DRAW SAMPLES

    # Burn-in: the draws only advance the sampler's state; their values
    # are discarded.
    for _ in range(burnin):
        sampler.draw()
    samples = numpy.asarray([sampler.draw() for _ in range(n_samples)])

    assert sampler.avg_acceptance_rate.get_value() > 0
    assert sampler.avg_acceptance_rate.get_value() < 1

    # TEST THAT THEY ARE FROM THE RIGHT DISTRIBUTION

    # expected: samples.shape == (n_samples, batchsize, 2)

    print('target mean: %s' % (mu,))
    # Two spaces reproduce the Python 2 `print 'x: ', v` output exactly.
    print('empirical mean:  %s' % (samples.mean(axis=0),))
    # NOTE(review): this mean check was already disabled; kept for reference.
    #assert numpy.all(abs(mu - samples.mean(axis=0)) < 1)

    print('final stepsize %s' % (sampler.stepsize.get_value(),))
    final_rate = sampler.avg_acceptance_rate.get_value()
    print('final acceptance_rate %s' % (final_rate,))

    print('target cov %s' % (cov,))
    print('')
    # Report the target/empirical covariance ratio for every chain, not a
    # hard-coded subset. numpy.cov treats rows as variables, so transpose
    # the (n_samples, 2) slice.
    for chain in range(batchsize):
        empirical_cov = numpy.cov(samples[:, chain, :].T)
        print('cov/empirical_cov %s' % (cov / empirical_cov,))
    return sampler
58 62
59 def test_hmc(): 63 def test_hmc():
60 print ('HMC') 64 print ('HMC')