diff pylearn/algorithms/tests/test_mcRBM.py @ 1000:d4a14c6c36e0

mcRBM - post code-review #1 with Guillaume
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 19:24:54 -0400
parents
children 075c193afd1b
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Tue Aug 24 19:24:54 2010 -0400
@@ -0,0 +1,169 @@
+
+
+from pylearn.algorithms.mcRBM import *
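+# (this star import is relied on below for names such as numpy, np, TT,
+# function, floatX, l1, sgd_updates, and the pylearn dataset_ops and
+# image_tiling modules)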
+
+
+def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True):
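+    # dataset='MAR' loads the 105-dimensional image-patch data used by
+    # Ranzato & Hinton (CVPR 2010); any other value trains on raw 16x16
+    # pixel patches instead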
+    if dataset == 'MAR':
+        n_vis=105
+        n_patches=10240
+    else:
+        R, C = 16, 16  # the size of image patches
+        n_vis=R*C
+        n_patches=100000
+
+    n_train_iters=5000
+
+    n_burnin_steps=10000
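+    # (n_burnin_steps is defined but never used in this test)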
+
+    l1_penalty=1e-3
+    no_l1_epochs = 10
+    effective_l1_penalty=0.0
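+    # the L1 penalty stays at 0.0 for the first no_l1_epochs epochs and is
+    # switched on to l1_penalty inside the training loop below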
+
+    epoch_size=n_patches
+    batchsize = 128
+    lr = 0.075 / batchsize
+    s_lr = TT.scalar()
+    s_l1_penalty=TT.scalar()
+    n_K=256
+    n_J=100
+
+    rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J) 
+
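+    # persistent HMC sampler: its batchsize particles (smplr.positions[0])
+    # supply the negative-phase samples for the contrastive gradient below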
+    smplr = sampler(rbm, n_particles=batchsize)
+
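+    # l2: Frobenius norm, used for monitoring and for the regression checks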
+    def l2(X):
+        return numpy.sqrt((X**2).sum())
+    if dataset == 'MAR':
+        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
+    else:
+        def tile(X, fname):
+            _img = image_tiling.tile_raster_images(X,
+                    img_shape=(R,C),
+                    min_dynamic_range=1e-2)
+            image_tiling.save_tiled_raster_images(_img, fname)
+
+    batch_idx = TT.iscalar()
+
+    if dataset == 'MAR':
+        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
+    else:
+        train_batch = pylearn.dataset_ops.image_patches.image_patches(
+                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
+                dims = (n_patches,R,C),
+                center=True,
+                unitvar=True,
+                dtype=floatX,
+                rasterized=True)
+
+    if not as_unittest:
+        imgs_fn = function([batch_idx], outputs=train_batch)
+
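+    # assumption about contrastive_grad's convention: it returns gradients of
+    # mean FE(pos_v) - mean FE(neg_v) + other_cost w.r.t. params, i.e.
+    # PCD-style learning plus an L1 weight penalty on U and W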
+    grads = contrastive_grad(
+            free_energy_fn=lambda v: free_energy_given_v(rbm, v),
+            pos_v=train_batch, 
+            neg_v=smplr.positions[0],
+            params=list(rbm),
+            other_cost=(l1(rbm.U)+l1(rbm.W)) * s_l1_penalty)
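+    # assumption: rbm.params orders the parameters as (U, W, a, b, c), so the
+    # five stepsizes below give each parameter its own multiple of s_lr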
+    sgd_ups = sgd_updates(
+                rbm.params,
+                grads,
+                stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])
+    learn_fn = function([batch_idx, s_lr, s_l1_penalty], 
+            outputs=[ 
+                grads[0].norm(2),
+                (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),
+                (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),
+                ],
+            updates = sgd_ups)
+
+    print "Learning..."
+    normVF=1
+    last_epoch = -1
+    for jj in xrange(n_train_iters):
+        epoch = jj*batchsize / epoch_size
+
+        print_jj = epoch != last_epoch
+        last_epoch = epoch
+
+        if epoch > 10:
+            break
+
+        if as_unittest and epoch == 5:
+            U = rbm.U.value
+            W = rbm.W.value
+            def allclose(a,b):
+                return numpy.allclose(a,b,rtol=1.01,atol=1e-3)
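+            # NB: rtol=1.01 makes these very coarse checks (each value only
+            # has to land within roughly a factor of two), not an exact
+            # numerical reproduction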
+            print ""
+            print "--------------"
+            print "assert allclose(l2(U), %f)"%l2(U)
+            print "assert allclose(l2(W), %f)"%l2(W)
+            print "assert allclose(U.min(), %f)"%U.min()
+            print "assert allclose(U.max(), %f)"%U.max()
+            print "assert allclose(W.min(),%f)"%W.min()
+            print "assert allclose(W.max(), %f)"%W.max()
+            print "--------------"
+
+            assert allclose(l2(U), 21.351664)
+            assert allclose(l2(W), 6.275828)
+            assert allclose(U.min(), -1.176703)
+            assert allclose(U.max(), 0.859802)
+            assert allclose(W.min(), -0.223128)
+            assert allclose(W.max(), 0.227558)
+
+            break
+
+        if print_jj:
+            if not as_unittest:
+                tile(imgs_fn(jj), "imgs_%06i.png"%jj)
+                tile(smplr.positions[0].value, "sample_%06i.png"%jj)
+                tile(rbm.U.value.T, "U_%06i.png"%jj)
+                tile(rbm.W.value.T, "W_%06i.png"%jj)
+
+            print 'saving samples', jj, 'epoch', epoch
+
+            print 'l2(U)', l2(rbm.U.value),
+            print 'l2(W)', l2(rbm.W.value)
+
+            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
+            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
+            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
+            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
+            print 'c min max', rbm.c.value.min(), rbm.c.value.max()
+
+            print 'parts min', smplr.positions[0].value.min(), 
+            print 'max',smplr.positions[0].value.max(),
+            print 'HMC step', smplr.stepsize,
+            print 'arate', smplr.avg_acceptance_rate
+
+        # Continue HMC chain
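+        # (assumption: the sampler adapts smplr.stepsize to keep the average
+        # acceptance rate near a target; both stats are printed above)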
+        smplr.simulate()
+
+        # Do CD update
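+        # the stepsize anneals: constant for the first 20 epochs, then
+        # decaying as 1/t via the integer division inside max(); with the
+        # epoch > 10 cap above, the decay never engages in this test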
+        l2_of_Ugrad = learn_fn(jj, 
+                lr/max(1, jj/(20*epoch_size/batchsize)),
+                effective_l1_penalty)
+
+        if print_jj:
+            print 'l2(U_grad)', float(l2_of_Ugrad[0]),
+            print 'l2(U_inc)', float(l2_of_Ugrad[1]),
+            print 'l2(W_inc)', float(l2_of_Ugrad[2]),
+            #print 'FE+', float(l2_of_Ugrad[2]),
+            #print 'FE+[0]', float(l2_of_Ugrad[3]),
+            #print 'FE+[1]', float(l2_of_Ugrad[4]),
+            #print 'FE+[2]', float(l2_of_Ugrad[5]),
+            #print 'FE+[3]', float(l2_of_Ugrad[6])
+
+        if jj == no_l1_epochs * epoch_size/batchsize:
+            print "Activating L1 weight decay"
+            effective_l1_penalty = l1_penalty
+
+        # Weird normalization technique: it constrains all the columns of U
+        # to have the same length, but the matrix as a whole is re-scaled to
+        # an arbitrary absolute size tracked by the running average normVF.
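+        # e.g. if two columns had norms (2.0, 4.0) and normVF were 3.0, both
+        # columns would be rescaled to norm 3.0 exactly; only the slowly
+        # moving average normVF carries the overall scale between steps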
+        U = rbm.U.value
+        U_norms = np.sqrt((U*U).sum(axis=0))
+        assert len(U_norms) == n_K
+        normVF = .95 * normVF + .05 * np.mean(U_norms)
+        rbm.U.value = rbm.U.value * normVF/U_norms
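+
+
+if __name__ == '__main__':
+    # convenience entry point so the regression check can be run directly as
+    # a script; pytest/nosetests will pick up the test_ function either way
+    test_reproduce_ranzato_hinton_2010(as_unittest=True)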