# pylearn/algorithms/tests/test_mcRBM.py  (changeset 1000:d4a14c6c36e0)
# mcRBM - post code-review #1 with Guillaume
# author: James Bergstra <bergstrj@iro.umontreal.ca>
# date:   Tue, 24 Aug 2010 19:24:54 -0400



import numpy
import numpy as np
import theano.tensor as TT
from theano import function

import pylearn.dataset_ops.image_patches
# NOTE: the location of image_tiling is assumed here (it lives under
# pylearn.io in the pylearn tree); several of these names may also arrive
# via the star import below.
from pylearn.io import image_tiling

# expected to provide MeanCovRBM, sampler, contrastive_grad, sgd_updates,
# l1, free_energy_given_v, and floatX
from pylearn.algorithms.mcRBM import *


def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True):
    if dataset == 'MAR':
        n_vis = 105
        n_patches = 10240
    else:
        R, C = 16, 16  # the size of image patches
        n_vis = R * C
        n_patches = 100000

    n_train_iters = 5000

    n_burnin_steps = 10000  # (not used in this test)

    l1_penalty = 1e-3
    no_l1_epochs = 10
    effective_l1_penalty = 0.0

    epoch_size = n_patches
    batchsize = 128
    lr = 0.075 / batchsize
    s_lr = TT.scalar()          # symbolic learning rate
    s_l1_penalty = TT.scalar()  # symbolic L1 penalty coefficient
    n_K = 256  # number of covariance hidden units
    n_J = 100  # number of mean hidden units

    rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J)

    smplr = sampler(rbm, n_particles=batchsize)
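    # `smplr` maintains `batchsize` persistent HMC particles;
    # smplr.positions[0] supplies the negative-phase samples for the
    # contrastive gradient below.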

    def l2(X):
        # Frobenius norm, used to monitor weight magnitudes
        return numpy.sqrt((X ** 2).sum())
    if dataset == 'MAR':
        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
    else:
        def tile(X, fname):
            _img = image_tiling.tile_raster_images(X,
                    img_shape=(R,C),
                    min_dynamic_range=1e-2)
            image_tiling.save_tiled_raster_images(_img, fname)

    batch_idx = TT.iscalar()  # symbolic minibatch index

    if dataset == 'MAR':
        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
    else:
        train_batch = pylearn.dataset_ops.image_patches.image_patches(
                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
                dims = (n_patches,R,C),
                center=True,
                unitvar=True,
                dtype=floatX,
                rasterized=True)

    if not as_unittest:
        # compiled only for interactive runs, to save tiles of the input patches
        imgs_fn = function([batch_idx], outputs=train_batch)

    grads = contrastive_grad(
            free_energy_fn=lambda v: free_energy_given_v(rbm, v),
            pos_v=train_batch,
            neg_v=smplr.positions[0],
            params=list(rbm),
            other_cost=(l1(rbm.U) + l1(rbm.W)) * s_l1_penalty)
    # one stepsize multiplier per parameter; rbm.params is assumed to list the
    # parameters in the same order as list(rbm) above (U and W first -- the
    # monitoring code below relies on that)
    sgd_ups = sgd_updates(
                rbm.params,
                grads,
                stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr])
    learn_fn = function([batch_idx, s_lr, s_l1_penalty],
            outputs=[
                grads[0].norm(2),                         # norm of the U gradient
                (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),  # norm of the U increment
                (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),  # norm of the W increment
                ],
            updates=sgd_ups)
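    # usage: learn_fn(i, step, penalty) applies one SGD update using
    # minibatch i and returns the three monitoring norms listed above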

    print "Learning..."
    normVF=1
    last_epoch = -1
    for jj in xrange(n_train_iters):
        epoch = jj*batchsize / epoch_size

        print_jj = epoch != last_epoch
        last_epoch = epoch

        if epoch > 10:
            break

        if as_unittest and epoch == 5:
            U = rbm.U.value
            W = rbm.W.value
            def allclose(a, b):
                # note the loose tolerances: the test only checks that
                # training lands in the right ballpark
                return numpy.allclose(a, b, rtol=1.01, atol=1e-3)
            # print current values in copy-pasteable form, so the assertions
            # below are easy to regenerate
            print ""
            print "--------------"
            print "assert allclose(l2(U), %f)" % l2(U)
            print "assert allclose(l2(W), %f)" % l2(W)
            print "assert allclose(U.min(), %f)" % U.min()
            print "assert allclose(U.max(), %f)" % U.max()
            print "assert allclose(W.min(), %f)" % W.min()
            print "assert allclose(W.max(), %f)" % W.max()
            print "--------------"

            assert allclose(l2(U), 21.351664)
            assert allclose(l2(W), 6.275828)
            assert allclose(U.min(), -1.176703)
            assert allclose(U.max(), 0.859802)
            assert allclose(W.min(), -0.223128)
            assert allclose(W.max(), 0.227558)

            break

        if print_jj:
            if not as_unittest:
                tile(imgs_fn(jj), "imgs_%06i.png"%jj)
                tile(smplr.positions[0].value, "sample_%06i.png"%jj)
                tile(rbm.U.value.T, "U_%06i.png"%jj)
                tile(rbm.W.value.T, "W_%06i.png"%jj)

            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)

            print 'l2(U)', l2(rbm.U.value),
            print 'l2(W)', l2(rbm.W.value)

            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
            print 'c min max', rbm.c.value.min(), rbm.c.value.max()

            print 'parts min', smplr.positions[0].value.min(), 
            print 'max',smplr.positions[0].value.max(),
            print 'HMC step', smplr.stepsize,
            print 'arate', smplr.avg_acceptance_rate

        # Continue HMC chain
        smplr.simulate()
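        # (the negative-phase particles evolve continuously across
        # minibatches, in the style of persistent contrastive divergence)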

        # Do CD update; the step size is held at `lr` for the first 20 epochs,
        # then decayed roughly in proportion to 1/epoch
        l2_of_Ugrad = learn_fn(jj,
                lr / max(1, jj / (20 * epoch_size / batchsize)),
                effective_l1_penalty)

        if print_jj:
            print 'l2(U_grad)', float(l2_of_Ugrad[0]),
            print 'l2(U_inc)', float(l2_of_Ugrad[1]),
            print 'l2(W_inc)', float(l2_of_Ugrad[2]),
            #print 'FE+', float(l2_of_Ugrad[2]),
            #print 'FE+[0]', float(l2_of_Ugrad[3]),
            #print 'FE+[1]', float(l2_of_Ugrad[4]),
            #print 'FE+[2]', float(l2_of_Ugrad[5]),
            #print 'FE+[3]', float(l2_of_Ugrad[6])

        if jj == no_l1_epochs * epoch_size / batchsize:
            print "Activating L1 weight decay"
            effective_l1_penalty = l1_penalty

        # Unusual normalization technique: constrain all columns of U to have
        # the same length, but let that common length (normVF, a running
        # average of the observed column norms) drift to an arbitrary
        # absolute scale.
        U = rbm.U.value
        U_norms = np.sqrt((U * U).sum(axis=0))
        assert len(U_norms) == n_K
        normVF = .95 * normVF + .05 * np.mean(U_norms)
        rbm.U.value = rbm.U.value * normVF / U_norms
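

def _column_norm_demo():
    # Illustration only (a hypothetical helper, not part of the original
    # test): a minimal numpy sketch of the column-normalization trick above.
    # After one pass every column of U has length normVF, so repeated passes
    # leave both U and normVF fixed.
    rng = numpy.random.RandomState(0)
    U = rng.randn(16, 8)
    normVF = 1.0
    for _ in range(3):
        U_norms = numpy.sqrt((U * U).sum(axis=0))
        normVF = .95 * normVF + .05 * numpy.mean(U_norms)
        U = U * normVF / U_norms
    assert numpy.allclose(numpy.sqrt((U * U).sum(axis=0)), normVF)


if __name__ == '__main__':
    # Run the reproduction test directly; assumes the Ranzato & Hinton (2010)
    # patch data is available via pylearn.dataset_ops.image_patches.
    test_reproduce_ranzato_hinton_2010()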