# pylearn/algorithms/tests/test_mcRBM.py @ 1267:075c193afd1b
# "refactoring mcRBM"
# author: James Bergstra <bergstrj@iro.umontreal.ca>
# date:   Fri, 03 Sep 2010 12:35:10 -0400

# The star import is relied on to supply numpy, np (numpy), TT (theano.tensor),
# theano's function, floatX, pylearn, and image_tiling at module scope.
from pylearn.algorithms.mcRBM import *

def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True):
    dataset='MAR'  # NOTE: overrides the argument; only the 'MAR' branch is exercised
    if dataset == 'MAR':
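        # 'MAR' is the Ranzato & Hinton (2010) patch data; 105 visible units
        # matches their PCA-whitened image patches (an inference, not stated here)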
        n_vis=105
        n_patches=10240
    else:
        R,C= 16,16 # the size of image patches
        n_vis=R*C
        n_patches=100000

    n_train_iters=5000

    n_burnin_steps=10000


    l1_penalty=1e-3
    no_l1_epochs = 10
    effective_l1_penalty=0.0

    epoch_size=n_patches
    batchsize = 128
    lr = 0.075 / batchsize
    s_lr = TT.scalar()  # symbolic learning rate (unused in the current update path)
    n_K=256
    n_J=100

    def l2(X):
        return numpy.sqrt((X**2).sum())
    if dataset == 'MAR':
        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
    else:
        def tile(X, fname):
            _img = image_tiling.tile_raster_images(X,
                    img_shape=(R,C),
                    min_dynamic_range=1e-2)
            image_tiling.save_tiled_raster_images(_img, fname)

    batch_idx = TT.iscalar()

    if dataset == 'MAR':
        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
    else:
        train_batch = pylearn.dataset_ops.image_patches.image_patches(
                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
                dims = (n_patches,R,C),
                center=True,
                unitvar=True,
                dtype=floatX,
                rasterized=True)

    if not as_unittest:
        imgs_fn = function([batch_idx], outputs=train_batch)

    # mcRBM with n_I visible units, n_K covariance units (filters U) and
    # n_J mean units (weights W)
    trainer = mcRBMTrainer.alloc(
            mcRBM.alloc(n_I=n_vis, n_K=n_K, n_J=n_J),
            train_batch,
            batchsize, l1_penalty=TT.scalar())
    rbm = trainer.rbm
    smplr = trainer.sampler  # HMC sampler for the negative-phase particles

    grads = trainer.contrastive_grads(train_batch)
    # outputs feed the l2(U_grad)/l2(U_inc)/l2(W_inc) prints below; note that
    # the first two outputs are the same expression (apparently a leftover)
    learn_fn = function([batch_idx, trainer.l1_penalty],
            outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
            updates=trainer.cd_updates(train_batch))

    print "Learning..."
    last_epoch = -1
    for jj in xrange(n_train_iters):
        epoch = jj*batchsize / epoch_size

        print_jj = epoch != last_epoch
        last_epoch = epoch

        if as_unittest and epoch == 5:
            U = rbm.U.value
            W = rbm.W.value
            def allclose(a, b):
                # NB: rtol=1.01 is extremely loose (differences up to ~100% of
                # the reference pass); perhaps 1e-2 was intended
                return numpy.allclose(a, b, rtol=1.01, atol=1e-3)
            print ""
            print "--------------"
            print "assert allclose(l2(U), %f)"%l2(U)
            print "assert allclose(l2(W), %f)"%l2(W)
            print "assert allclose(U.min(), %f)"%U.min()
            print "assert allclose(U.max(), %f)"%U.max()
            print "assert allclose(W.min(), %f)"%W.min()
            print "assert allclose(W.max(), %f)"%W.max()
            print "--------------"

            assert allclose(l2(U), 21.351664)
            assert allclose(l2(W), 6.275828)
            assert allclose(U.min(), -1.176703)
            assert allclose(U.max(), 0.859802)
            assert allclose(W.min(), -0.223128)
            assert allclose(W.max(), 0.227558)

            break

        if print_jj:
            if not as_unittest:
                tile(imgs_fn(jj), "imgs_%06i.png"%jj)
                tile(smplr.positions.value, "sample_%06i.png"%jj)
                tile(rbm.U.value.T, "U_%06i.png"%jj)
                tile(rbm.W.value.T, "W_%06i.png"%jj)

            print 'saving samples', jj, 'epoch', epoch

            print 'l2(U)', l2(rbm.U.value),
            print 'l2(W)', l2(rbm.W.value)

            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
            print 'c min max', rbm.c.value.min(), rbm.c.value.max()

            print 'parts min', smplr.positions.value.min(), 
            print 'max',smplr.positions.value.max(),
            print 'HMC step', smplr.stepsize,
            print 'arate', smplr.avg_acceptance_rate
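            # smplr.stepsize and smplr.avg_acceptance_rate come from the HMC
            # sampler, which adapts its step size toward a target acceptance
            # rate; generically (constants illustrative, not pylearn's own):
            #   if avg_acceptance_rate > target_rate: stepsize *= 1.02
            #   else:                                 stepsize *= 0.98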


        if 0:
            # Older update path: step the HMC chain explicitly, then do a CD
            # update with a decaying learning rate (this version of learn_fn
            # took a learning-rate input as well).
            smplr.simulate()
            l2_of_Ugrad = learn_fn(jj,
                    lr/max(1, jj/(20*epoch_size/batchsize)),
                    effective_l1_penalty)

        # CD update; the sampler updates are presumably folded into
        # trainer.cd_updates().  Assign the result so the print_jj block
        # below can report the gradient norms.
        l2_of_Ugrad = learn_fn(jj, effective_l1_penalty)

        if print_jj:
            print 'l2(U_grad)', float(l2_of_Ugrad[0]),
            print 'l2(U_inc)', float(l2_of_Ugrad[1]),
            print 'l2(W_inc)', float(l2_of_Ugrad[2]),
            #print 'FE+', float(l2_of_Ugrad[2]),
            #print 'FE+[0]', float(l2_of_Ugrad[3]),
            #print 'FE+[1]', float(l2_of_Ugrad[4]),
            #print 'FE+[2]', float(l2_of_Ugrad[5]),
            #print 'FE+[3]', float(l2_of_Ugrad[6])

        if jj == no_l1_epochs * epoch_size/batchsize:
            print "Activating L1 weight decay"
            effective_l1_penalty = l1_penalty  # == 1e-3, set at the top

        # Soft norm constraint (disabled): rescale every column of U to a
        # common length normVF, a running average of the mean column norm.
        # NB: normVF is never initialized in this function; it would need to
        # be set (e.g. to 1.0) before the loop for this block to run.
        if 0:
            U = rbm.U.value
            U_norms = np.sqrt((U*U).sum(axis=0))
            assert len(U_norms) == n_K
            normVF = .95 * normVF + .05 * np.mean(U_norms)
            rbm.U.value = rbm.U.value * normVF/U_norms
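
# A minimal, self-contained sketch (not part of the original test) of the
# column-norm trick disabled above: every column of U is rescaled to a shared
# target length normVF, which tracks a running average of the mean column
# norm.  The helper name _rescale_columns_to_common_norm is hypothetical.
def _rescale_columns_to_common_norm(U, normVF, rho=0.95):
    """Return (U_rescaled, new_normVF); all columns of the result share one norm."""
    U_norms = numpy.sqrt((U * U).sum(axis=0))            # per-column l2 norms
    normVF = rho * normVF + (1. - rho) * U_norms.mean()  # smoothed target length
    return U * (normVF / U_norms), normVF                # rescale columns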