diff pylearn/algorithms/tests/test_mcRBM.py @ 1267:075c193afd1b

refactoring mcRBM
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 03 Sep 2010 12:35:10 -0400
parents d4a14c6c36e0
children ba25c6e4f55d
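
This changeset collapses the hand-wired training graph in the test (MeanCovRBM.new_from_dims, an explicit sampler, contrastive_grad, sgd_updates) into a single mcRBMTrainer object. As a reading aid, here is a minimal sketch of the trainer interface the new test code relies on. Only the attribute and method names (alloc, rbm, sampler, l1_penalty, contrastive_grads, cd_updates) come from this diff; the bodies reuse the helpers the removed code called and are hypothetical stand-ins, not pylearn's actual implementation.

    # Hedged reconstruction of mcRBMTrainer, inferred from its use below.
    from pylearn.algorithms.mcRBM import (mcRBM, sampler, contrastive_grad,
            free_energy_given_v, l1, sgd_updates)

    class mcRBMTrainer(object):
        @classmethod
        def alloc(cls, rbm, visible_batch, batchsize, l1_penalty=0.0):
            # Bundle the model, a persistent HMC sampler for the negative
            # phase, and the symbolic l1_penalty scalar that learn_fn
            # later takes as an input.
            self = cls()
            self.rbm = rbm
            self.batchsize = batchsize
            self.l1_penalty = l1_penalty
            self.sampler = sampler(rbm, n_particles=batchsize)
            return self

        def contrastive_grads(self, visible_batch):
            # CD gradients: the data batch is the positive sample, the
            # sampler's persistent particles the negative sample, with an
            # L1 penalty on U and W (as in the removed test code).
            return contrastive_grad(
                    free_energy_fn=lambda v: free_energy_given_v(self.rbm, v),
                    pos_v=visible_batch,
                    neg_v=self.sampler.positions,
                    params=list(self.rbm),
                    other_cost=(l1(self.rbm.U) + l1(self.rbm.W)) * self.l1_penalty)

        def cd_updates(self, visible_batch):
            # One SGD step on the contrastive gradients.  The new learn_fn
            # no longer takes a learning-rate input, so a baked-in stepsize
            # is assumed here (0.075/batchsize, the value used in the test).
            grads = self.contrastive_grads(visible_batch)
            return sgd_updates(self.rbm.params, grads,
                    stepsizes=[0.075 / self.batchsize] * len(grads))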
--- a/pylearn/algorithms/tests/test_mcRBM.py	Thu Sep 02 16:48:33 2010 -0400
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Fri Sep 03 12:35:10 2010 -0400
@@ -1,8 +1,5 @@
-
-
 from pylearn.algorithms.mcRBM import *
 
-
 def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True):
     dataset='MAR'
     if dataset == 'MAR':
@@ -17,6 +14,7 @@
 
     n_burnin_steps=10000
 
+
     l1_penalty=1e-3
     no_l1_epochs = 10
     effective_l1_penalty=0.0
@@ -25,14 +23,9 @@
     batchsize = 128
     lr = 0.075 / batchsize
     s_lr = TT.scalar()
-    s_l1_penalty=TT.scalar()
     n_K=256
     n_J=100
 
-    rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J) 
-
-    smplr = sampler(rbm, n_particles=batchsize)
-
     def l2(X):
         return numpy.sqrt((X**2).sum())
     if dataset == 'MAR':
@@ -60,26 +53,19 @@
     if not as_unittest:
         imgs_fn = function([batch_idx], outputs=train_batch)
 
-    grads = contrastive_grad(
-            free_energy_fn=lambda v: free_energy_given_v(rbm, v),
-            pos_v=train_batch, 
-            neg_v=smplr.positions[0],
-            params=list(rbm),
-            other_cost=(l1(rbm.U)+l1(rbm.W)) * s_l1_penalty)
-    sgd_ups = sgd_updates(
-                rbm.params,
-                grads,
-                stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])
-    learn_fn = function([batch_idx, s_lr, s_l1_penalty], 
-            outputs=[ 
-                grads[0].norm(2),
-                (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),
-                (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),
-                ],
-            updates = sgd_ups)
+    trainer = mcRBMTrainer.alloc(
+            mcRBM.alloc(n_I=n_vis, n_K=n_K, n_J=n_J),
+            train_batch,
+            batchsize, l1_penalty=TT.scalar())
+    rbm=trainer.rbm
+    smplr = trainer.sampler
+
+    grads = trainer.contrastive_grads(train_batch)
+    learn_fn = function([batch_idx, trainer.l1_penalty], 
+            outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
+            updates=trainer.cd_updates(train_batch))
 
     print "Learning..."
-    normVF=1
     last_epoch = -1
     for jj in xrange(n_train_iters):
         epoch = jj*batchsize / epoch_size
@@ -87,9 +73,6 @@
         print_jj = epoch != last_epoch
         last_epoch = epoch
 
-        if epoch > 10:
-            break
-
         if as_unittest and epoch == 5:
             U = rbm.U.value
             W = rbm.W.value
@@ -117,7 +100,7 @@
         if print_jj:
             if not as_unittest:
                 tile(imgs_fn(jj), "imgs_%06i.png"%jj)
-                tile(smplr.positions[0].value, "sample_%06i.png"%jj)
+                tile(smplr.positions.value, "sample_%06i.png"%jj)
                 tile(rbm.U.value.T, "U_%06i.png"%jj)
                 tile(rbm.W.value.T, "W_%06i.png"%jj)
 
@@ -132,18 +115,22 @@
             print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
             print 'c min max', rbm.c.value.min(), rbm.c.value.max()
 
-            print 'parts min', smplr.positions[0].value.min(), 
-            print 'max',smplr.positions[0].value.max(),
+            print 'parts min', smplr.positions.value.min(), 
+            print 'max',smplr.positions.value.max(),
             print 'HMC step', smplr.stepsize,
             print 'arate', smplr.avg_acceptance_rate
 
-        # Continue HMC chain
-        smplr.simulate()
+
+        if 0:
+            # Continue HMC chain
+            smplr.simulate()
 
-        # Do CD update
-        l2_of_Ugrad = learn_fn(jj, 
-                lr/max(1, jj/(20*epoch_size/batchsize)),
-                effective_l1_penalty)
+            # Do CD update
+            l2_of_Ugrad = learn_fn(jj, 
+                    lr/max(1, jj/(20*epoch_size/batchsize)),
+                    effective_l1_penalty)
+
+        l2_of_Ugrad = learn_fn(jj, effective_l1_penalty)
 
         if print_jj:
             print 'l2(U_grad)', float(l2_of_Ugrad[0]),
@@ -161,9 +148,10 @@
 
         # weird normalization technique...
         # It constrains all the columns of the matrix to have the same length
-        # But the matrix itself is re-scaled to have an arbitrary absolute size.
-        U = rbm.U.value
-        U_norms = np.sqrt((U*U).sum(axis=0))
-        assert len(U_norms) == n_K
-        normVF = .95 * normVF + .05 * np.mean(U_norms)
-        rbm.U.value = rbm.U.value * normVF/U_norms
+        if 0:
+            U = rbm.U.value
+            U_norms = np.sqrt((U*U).sum(axis=0))
+            assert len(U_norms) == n_K
+            normVF = .95 * normVF + .05 * np.mean(U_norms)
+            rbm.U.value = rbm.U.value * normVF/U_norms
+
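
For reference, the "weird normalization technique" that the last hunk disables (constraining all columns of U to a common, slowly drifting length) can be restated as a standalone helper. A numpy sketch of the deleted block; renormalize_columns is a hypothetical name, while the 0.95/0.05 moving average and the per-column rescaling come directly from the removed lines.

    import numpy as np

    def renormalize_columns(U, normVF, alpha=0.05):
        # Hypothetical helper; logic taken from the block disabled above.
        # Per-column L2 norms of the filter matrix U (one per column).
        U_norms = np.sqrt((U * U).sum(axis=0))
        # Exponential moving average of the mean column norm
        # (alpha=0.05 reproduces the original .95/.05 update) ...
        normVF = (1 - alpha) * normVF + alpha * U_norms.mean()
        # ... then rescale every column to that common length.
        return U * (normVF / U_norms), normVF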