Mercurial > pylearn
diff pylearn/algorithms/tests/test_mcRBM.py @ 1000:d4a14c6c36e0
mcRBM - post code-review #1 with Guillaume
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Tue, 24 Aug 2010 19:24:54 -0400 |
parents | |
children | 075c193afd1b |
line wrap: on
line diff
# Reconstructed from the changeset diff whose header read:
#   --- /dev/null Thu Jan 01 00:00:00 1970 +0000
#   +++ b/pylearn/algorithms/tests/test_mcRBM.py Tue Aug 24 19:24:54 2010 -0400
#   @@ -0,0 +1,169 @@

from pylearn.algorithms.mcRBM import *


def _say(*parts):
    """Print ``parts`` separated by single spaces on one line.

    Produces the same text as the Python 2 statement ``print a, b, c`` while
    remaining valid syntax under Python 3.
    """
    print(" ".join(str(part) for part in parts))


def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True):
    """Train a MeanCovRBM and compare U/W statistics with recorded values.

    :param dataset: ``'MAR'`` selects 105 visible units, 10240 patches and the
        ``ranzato_hinton_2010_op`` batch source; any other value selects 16x16
        patches (100000 of them) read through ``image_patches``.
    :param as_unittest: when true, no images are written and, on reaching
        epoch 5, summary statistics of ``rbm.U`` and ``rbm.W`` are asserted
        against hard-coded reference values before the loop stops.  When
        false, tiled images of the batch, the sampler particles, U and W are
        saved at the start of every epoch and training runs until the epoch
        index exceeds 10 (or ``n_train_iters`` iterations have run).

    NOTE(review): the reference values were presumably recorded with the
    'MAR' configuration -- confirm before relying on ``as_unittest=True`` with
    another dataset.
    """
    if dataset == 'MAR':
        n_vis = 105
        n_patches = 10240
    else:
        R, C = 16, 16  # the size of image patches
        n_vis = R * C
        n_patches = 100000

    n_train_iters = 5000

    l1_penalty = 1e-3
    no_l1_epochs = 10
    effective_l1_penalty = 0.0  # L1 stays off until `l1_start_iter`

    epoch_size = n_patches
    batchsize = 128
    lr = 0.075 / batchsize
    s_lr = TT.scalar()
    s_l1_penalty = TT.scalar()
    n_K = 256
    n_J = 100

    # Loop invariants.  `//` spells out the floor division that the Python 2
    # `/` operator performed on these ints.
    batches_per_epoch = epoch_size // batchsize
    lr_decay_period = 20 * epoch_size // batchsize
    l1_start_iter = no_l1_epochs * epoch_size // batchsize

    rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J)

    smplr = sampler(rbm, n_particles=batchsize)

    def l2(X):
        """Return the square root of the sum of squared entries of X."""
        return np.sqrt((X ** 2).sum())

    def allclose(a, b):
        # NOTE(review): rtol=1.01 accepts any `a` within roughly +/-101% of
        # `b`, which makes the checks below very loose.  Left unchanged;
        # confirm whether 0.01 was intended.
        return np.allclose(a, b, rtol=1.01, atol=1e-3)

    if dataset == 'MAR':
        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
    else:
        def tile(X, fname):
            _img = image_tiling.tile_raster_images(
                X,
                img_shape=(R, C),
                min_dynamic_range=1e-2)
            image_tiling.save_tiled_raster_images(_img, fname)

    batch_idx = TT.iscalar()
    batch_rows = batch_idx * batchsize + np.arange(batchsize)

    if dataset == 'MAR':
        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(
            batch_rows)
    else:
        train_batch = pylearn.dataset_ops.image_patches.image_patches(
            s_idx=batch_rows,
            dims=(n_patches, R, C),
            center=True,
            unitvar=True,
            dtype=floatX,
            rasterized=True)

    if not as_unittest:
        imgs_fn = function([batch_idx], outputs=train_batch)

    grads = contrastive_grad(
        free_energy_fn=lambda v: free_energy_given_v(rbm, v),
        pos_v=train_batch,
        neg_v=smplr.positions[0],
        params=list(rbm),
        other_cost=(l1(rbm.U) + l1(rbm.W)) * s_l1_penalty)
    # One step size per parameter -- presumably in the order of rbm.params
    # (U, W, then the biases); verify against MeanCovRBM.
    sgd_ups = sgd_updates(
        rbm.params,
        grads,
        stepsizes=[2 * s_lr, .2 * s_lr, .02 * s_lr, .1 * s_lr, .02 * s_lr])
    # The three outputs are reported below as l2(U_grad), l2(U_inc), l2(W_inc).
    learn_fn = function(
        [batch_idx, s_lr, s_l1_penalty],
        outputs=[
            grads[0].norm(2),
            (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),
            (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),
        ],
        updates=sgd_ups)

    _say("Learning...")
    normVF = 1
    last_epoch = -1
    for jj in range(n_train_iters):
        epoch = jj * batchsize // epoch_size

        # True only on the first iteration of each epoch.
        print_jj = epoch != last_epoch
        last_epoch = epoch

        if epoch > 10:
            break

        if as_unittest and epoch == 5:
            U = rbm.U.value
            W = rbm.W.value
            # Echo the current statistics as ready-to-paste assertions so the
            # reference values can be refreshed from the test output.
            _say("")
            _say("--------------")
            _say("assert allclose(l2(U), %f)" % l2(U))
            _say("assert allclose(l2(W), %f)" % l2(W))
            _say("assert allclose(U.min(), %f)" % U.min())
            _say("assert allclose(U.max(), %f)" % U.max())
            _say("assert allclose(W.min(),%f)" % W.min())
            _say("assert allclose(W.max(), %f)" % W.max())
            _say("--------------")

            assert allclose(l2(U), 21.351664)
            assert allclose(l2(W), 6.275828)
            assert allclose(U.min(), -1.176703)
            assert allclose(U.max(), 0.859802)
            assert allclose(W.min(), -0.223128)
            assert allclose(W.max(), 0.227558)

            break

        if print_jj:
            if not as_unittest:
                tile(imgs_fn(jj), "imgs_%06i.png" % jj)
                tile(smplr.positions[0].value, "sample_%06i.png" % jj)
                tile(rbm.U.value.T, "U_%06i.png" % jj)
                tile(rbm.W.value.T, "W_%06i.png" % jj)

            _say('saving samples', jj, 'epoch', jj // batches_per_epoch)

            _say('l2(U)', l2(rbm.U.value),
                 'l2(W)', l2(rbm.W.value))

            _say('U min max', rbm.U.value.min(), rbm.U.value.max(),
                 'W min max', rbm.W.value.min(), rbm.W.value.max(),
                 'a min max', rbm.a.value.min(), rbm.a.value.max(),
                 'b min max', rbm.b.value.min(), rbm.b.value.max(),
                 'c min max', rbm.c.value.min(), rbm.c.value.max())

            _say('parts min', smplr.positions[0].value.min(),
                 'max', smplr.positions[0].value.max(),
                 'HMC step', smplr.stepsize,
                 'arate', smplr.avg_acceptance_rate)

        # Continue HMC chain
        smplr.simulate()

        # Do CD update
        learn_stats = learn_fn(
            jj,
            lr / max(1, jj // lr_decay_period),
            effective_l1_penalty)

        if print_jj:
            # Ends the line here so that the next message starts on a fresh
            # line instead of being appended to this one.
            _say('l2(U_grad)', float(learn_stats[0]),
                 'l2(U_inc)', float(learn_stats[1]),
                 'l2(W_inc)', float(learn_stats[2]))

        if jj == l1_start_iter:
            _say("Activating L1 weight decay")
            effective_l1_penalty = l1_penalty

        # weird normalization technique...
        # It constrains all the columns of the matrix to have the same length
        # But the matrix itself is re-scaled to have an arbitrary absolute size.
        U = rbm.U.value
        U_norms = np.sqrt((U * U).sum(axis=0))
        assert len(U_norms) == n_K
        # Running average of the mean column norm of U.
        normVF = .95 * normVF + .05 * np.mean(U_norms)
        rbm.U.value = U * normVF / U_norms