changeset 1000:d4a14c6c36e0

mcRBM - post code-review #1 with Guillaume
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 19:24:54 -0400
parents c6d08a760960
children 660d784d14c7
files pylearn/algorithms/mcRBM.py pylearn/algorithms/tests/__init__.py pylearn/algorithms/tests/test_mcRBM.py pylearn/sandbox/train_mcRBM.py
diffstat 3 files changed, 695 insertions(+), 163 deletions(-)
--- a/pylearn/algorithms/mcRBM.py	Tue Aug 24 17:01:09 2010 -0400
+++ b/pylearn/algorithms/mcRBM.py	Tue Aug 24 19:24:54 2010 -0400
@@ -190,8 +190,7 @@
 #    + 0.5 \sum_i v_i^2
 #    - \sum_i a_i v_i 
 
-import sys
-import logging
+import sys, os, logging
 import numpy as np
 import numpy
 
@@ -201,21 +200,54 @@
 floatX = theano.config.floatX
 
 import pylearn
+#TODO: clean up the HMC_sampler code
+#TODO: think of naming convention for acronyms + suffix?
 from pylearn.sampling.hmc import HMC_sampler
 from pylearn.io import image_tiling
 from pylearn.gd.sgd import sgd_updates
+import pylearn.dataset_ops.image_patches
 
-#TODO: This should be in the datasets folder
-import pylearn.datasets.config
-import pylearn.dataset_ops.image_patches
-from pylearn.dataset_ops.protocol import TensorFnDataset
-from pylearn.dataset_ops.memo import memo
-import pylearn
-import scipy.io
-import os
+###########################################
+#
+# Candidates for factoring
+#
+###########################################
+
+#TODO: Document, move to pylearn's math lib
+def l1(X):
+    return abs(X).sum()
+
+#TODO: Document, move to pylearn's math lib
+def l2(X):
+    return TT.sqrt((X**2).sum())
+
+#TODO: Document, move to pylearn's math lib
+def contrastive_cost(free_energy_fn, pos_v, neg_v):
+    return (free_energy_fn(pos_v) - free_energy_fn(neg_v)).sum()
 
+#TODO: The typical use of contrastive_cost is to later call tensor.grad, but in that case we want
+#      to block the gradient from going through neg_v
+def contrastive_grad(free_energy_fn, pos_v, neg_v, params, other_cost=0):
+    """
+    :param free_energy_fn: function mapping a visible-unit sample to its free-energy expression
+    :param pos_v: positive-phase sample of visible units
+    :param neg_v: negative-phase sample of visible units
+    :param params: parameters with respect to which to differentiate the contrastive cost
+    :param other_cost: optional extra cost (e.g. an L1 penalty) added before differentiation
+    """
+    #block the grad through neg_v
+    cost=contrastive_cost(free_energy_fn, pos_v, neg_v)
+    if other_cost:
+        cost = cost + other_cost
+    return theano.tensor.grad(cost,
+            wrt=params,
+            consider_constant=[neg_v])
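For example, a minimal sketch of how contrastive_cost / contrastive_grad are meant to be wired up (the toy quadratic free energy and the variable names here are illustrative only, not part of pylearn):

    import numpy, theano
    import theano.tensor as TT
    from pylearn.algorithms.mcRBM import contrastive_grad

    W = theano.shared(numpy.zeros((5, 3), dtype=theano.config.floatX), name='W')

    def toy_free_energy(v):
        # any per-sample scalar expression that is differentiable in W will do here
        return 0.5 * (v ** 2).sum(axis=1) - TT.log(1 + TT.exp(TT.dot(v, W))).sum(axis=1)

    pos_v = TT.matrix('pos_v')   # positive-phase (data) batch
    neg_v = TT.matrix('neg_v')   # negative-phase (model sample) batch
    grads = contrastive_grad(toy_free_energy, pos_v, neg_v, params=[W])
    # grads[0] is the gradient of the contrastive cost w.r.t. W, with neg_v held constant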
 
-#TODO: This should be in the nnet part of the library
+###########################################
+#
+# Expressions that are mcRBM-specific
+#
+###########################################
+
+#TODO: make a global function to initialize parameters
+
 def hidden_cov_units_preactivation_given_v(rbm, v, small=0.5):
     """Return argument to the sigmoid that would give mean of covariance hid units
 
@@ -248,24 +280,6 @@
     """
     return sum(free_energy_terms_given_v(rbm,v))
 
-def contrastive_gradient(rbm, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0):
-    """Return a list of gradient expressions for the rbm parameters
-
-    :param pos_v: positive-phase sample of visible units
-    :param neg_v: negative-phase sample of visible units
-    :param U_l1_penalty: a scalar-valued multiplier on the L1 penalty on U
-    :param W_l1_penalty: a scalar-valued multiplier on the L1 penalty on W
-    """
-    U, W, a, b, c = rbm
-    pos_FE = free_energy_given_v(rbm, pos_v)
-    neg_FE = free_energy_given_v(rbm, neg_v)
-    c0 = (pos_FE - neg_FE).sum()
-    c1 = abs(U).sum()*U_l1_penalty 
-    c2 = abs(W).sum()*W_l1_penalty 
-    cost = c0 + c1 + c2
-    rval = theano.tensor.grad(cost, list(rbm))
-    return rval
-
 def expected_h_g_given_v(rbm, v):
     """Returns tuple (`h`, `g`) of theano expression conditional expectations in an mcRBM.
 
@@ -291,7 +305,6 @@
     except AttributeError:
         return W.shape[0]
 
-
 def sampler(rbm, n_particles, n_visible=None, rng=7823748):
     """Return an `HMC_sampler` that will draw samples from the distribution over visible
     units specified by this RBM.
@@ -313,6 +326,12 @@
         seed=int(rng.randint(2**30)))
     return rval
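A sketch of the intended round trip with this sampler, using only attributes that also appear in the test file below (the sizes are illustrative):

    rbm = MeanCovRBM.new_from_dims(n_I=105, n_K=256, n_J=100)
    smplr = sampler(rbm, n_particles=128)
    for t in xrange(10):
        smplr.simulate()                       # advance the HMC chain by one round
    particles = smplr.positions[0].value       # current negative-phase samples
    print 'HMC step', smplr.stepsize, 'arate', smplr.avg_acceptance_rate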
 
+#############################
+#
+# Convenient data container
+#
+#############################
+
 class MeanCovRBM(object):
     """Container for mcRBM parameters
 
@@ -380,140 +399,12 @@
             d[key] = shared(d[key], name=key)
         self.__init__(**d)
 
-if __name__ == '__main__':
-
-    dataset='MAR'
-    if dataset == 'MAR':
-        n_vis=105
-        n_patches=10240
-    else:
-        R,C= 16,16 # the size of image patches
-        n_vis=R*C
-        n_patches=100000
-
-    n_train_iters=5000
-
-    n_burnin_steps=10000
-
-    l1_penalty=1e-3
-    no_l1_epochs = 10
-    effective_l1_penalty=0.0
-
-    epoch_size=n_patches
-    batchsize = 128
-    lr = 0.075 / batchsize
-    s_lr = TT.scalar()
-    s_l1_penalty=TT.scalar()
-    n_K=256
-    n_J=100
 
-    rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J) 
-
-    smplr = sampler(rbm, n_particles=batchsize)
-
-    def l2(X):
-        return numpy.sqrt((X**2).sum())
-    if dataset == 'MAR':
-        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
-    else:
-        def tile(X, fname):
-            _img = image_tiling.tile_raster_images(X,
-                    img_shape=(R,C),
-                    min_dynamic_range=1e-2)
-            image_tiling.save_tiled_raster_images(_img, fname)
+#TODO: make the normalization of U a global function
 
-    batch_idx = TT.iscalar()
-
-    if dataset == 'MAR':
-        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
-    else:
-        train_batch = pylearn.dataset_ops.image_patches.image_patches(
-                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
-                dims = (n_patches,R,C),
-                center=True,
-                unitvar=True,
-                dtype=floatX,
-                rasterized=True)
-
-    imgs_fn = function([batch_idx], outputs=train_batch)
 
-    grads = contrastive_gradient(rbm,
-            pos_v=train_batch, 
-            neg_v=smplr.positions[0],
-            U_l1_penalty=s_l1_penalty,
-            W_l1_penalty=s_l1_penalty)
-    sgd_ups = sgd_updates(
-                rbm.params,
-                grads,
-                stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])
-    learn_fn = function([batch_idx, s_lr, s_l1_penalty], 
-            outputs=[ 
-                grads[0].norm(2),
-                (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),
-                (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),
-                ],
-            updates = sgd_ups)
-
-    print "Learning..."
-    normVF=1
-    last_epoch = -1
-    for jj in xrange(n_train_iters):
-        epoch = jj*batchsize / epoch_size
-
-        print_jj = epoch != last_epoch
-        last_epoch = epoch
-
-        if epoch > 10:
-            break
-
-        if print_jj:
-            tile(imgs_fn(jj), "imgs_%06i.png"%jj)
-            tile(smplr.positions[0].value, "sample_%06i.png"%jj)
-            tile(rbm.U.value.T, "U_%06i.png"%jj)
-            tile(rbm.W.value.T, "W_%06i.png"%jj)
-
-            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)
-
-            print 'l2(U)', l2(rbm.U.value),
-            print 'l2(W)', l2(rbm.W.value)
+#TODO: make the learning loop a global function or class, so that someone could load and *TRAIN* an mcRBM!
 
-            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
-            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
-            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
-            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
-            print 'c min max', rbm.c.value.min(), rbm.c.value.max()
-
-            print 'parts min', smplr.positions[0].value.min(), 
-            print 'max',smplr.positions[0].value.max(),
-            print 'HMC step', smplr.stepsize,
-            print 'arate', smplr.avg_acceptance_rate
-
-        smplr.simulate()
-
-        l2_of_Ugrad = learn_fn(jj, 
-                lr/max(1, jj/(20*epoch_size/batchsize)),
-                effective_l1_penalty)
-
-        if print_jj:
-            print 'l2(U_grad)', float(l2_of_Ugrad[0]),
-            print 'l2(U_inc)', float(l2_of_Ugrad[1]),
-            print 'l2(W_inc)', float(l2_of_Ugrad[2]),
-            #print 'FE+', float(l2_of_Ugrad[2]),
-            #print 'FE+[0]', float(l2_of_Ugrad[3]),
-            #print 'FE+[1]', float(l2_of_Ugrad[4]),
-            #print 'FE+[2]', float(l2_of_Ugrad[5]),
-            #print 'FE+[3]', float(l2_of_Ugrad[6])
-
-        if jj == no_l1_epochs * epoch_size/batchsize:
-            print "Activating L1 weight decay"
-            effective_l1_penalty = 1e-3
-
-        # weird normalization technique...
-        # It constrains all the columns of the matrix to have the same length
-        # But the matrix itself is re-scaled to have an arbitrary abslute size.
-        U = rbm.U.value
-        U_norms = np.sqrt((U*U).sum(axis=0))
-        assert len(U_norms) == n_K
-        normVF = .95 * normVF + .05 * np.mean(U_norms)
-        rbm.U.value = rbm.U.value * normVF/U_norms
-
+if __name__ == '__main__':
+    import pylearn.algorithms.tests.test_mcRBM
+    pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(as_unittest=True)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Tue Aug 24 19:24:54 2010 -0400
@@ -0,0 +1,169 @@
+
+
+from pylearn.algorithms.mcRBM import *
+
+
+def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True):
+    dataset='MAR'
+    if dataset == 'MAR':
+        n_vis=105
+        n_patches=10240
+    else:
+        R,C= 16,16 # the size of image patches
+        n_vis=R*C
+        n_patches=100000
+
+    n_train_iters=5000
+
+    n_burnin_steps=10000
+
+    l1_penalty=1e-3
+    no_l1_epochs = 10
+    effective_l1_penalty=0.0
+
+    epoch_size=n_patches
+    batchsize = 128
+    lr = 0.075 / batchsize
+    s_lr = TT.scalar()
+    s_l1_penalty=TT.scalar()
+    n_K=256
+    n_J=100
+
+    rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J) 
+
+    smplr = sampler(rbm, n_particles=batchsize)
+
+    def l2(X):
+        return numpy.sqrt((X**2).sum())
+    if dataset == 'MAR':
+        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
+    else:
+        def tile(X, fname):
+            _img = image_tiling.tile_raster_images(X,
+                    img_shape=(R,C),
+                    min_dynamic_range=1e-2)
+            image_tiling.save_tiled_raster_images(_img, fname)
+
+    batch_idx = TT.iscalar()
+
+    if dataset == 'MAR':
+        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
+    else:
+        train_batch = pylearn.dataset_ops.image_patches.image_patches(
+                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
+                dims = (n_patches,R,C),
+                center=True,
+                unitvar=True,
+                dtype=floatX,
+                rasterized=True)
+
+    if not as_unittest:
+        imgs_fn = function([batch_idx], outputs=train_batch)
+
+    grads = contrastive_grad(
+            free_energy_fn=lambda v: free_energy_given_v(rbm, v),
+            pos_v=train_batch, 
+            neg_v=smplr.positions[0],
+            params=list(rbm),
+            other_cost=(l1(rbm.U)+l1(rbm.W)) * s_l1_penalty)
+    sgd_ups = sgd_updates(
+                rbm.params,
+                grads,
+                stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])
+    learn_fn = function([batch_idx, s_lr, s_l1_penalty], 
+            outputs=[ 
+                grads[0].norm(2),
+                (sgd_ups[0][1] - sgd_ups[0][0]).norm(2),
+                (sgd_ups[1][1] - sgd_ups[1][0]).norm(2),
+                ],
+            updates = sgd_ups)
+
+    print "Learning..."
+    normVF=1
+    last_epoch = -1
+    for jj in xrange(n_train_iters):
+        epoch = jj*batchsize / epoch_size
+
+        print_jj = epoch != last_epoch
+        last_epoch = epoch
+
+        if epoch > 10:
+            break
+
+        if as_unittest and epoch == 5:
+            U = rbm.U.value
+            W = rbm.W.value
+            def allclose(a,b):
+                return numpy.allclose(a,b,rtol=1.01,atol=1e-3)
+            print ""
+            print "--------------"
+            print "assert allclose(l2(U), %f)"%l2(U)
+            print "assert allclose(l2(W), %f)"%l2(W)
+            print "assert allclose(U.min(), %f)"%U.min()
+            print "assert allclose(U.max(), %f)"%U.max()
+            print "assert allclose(W.min(),%f)"%W.min()
+            print "assert allclose(W.max(), %f)"%W.max()
+            print "--------------"
+
+            assert allclose(l2(U), 21.351664)
+            assert allclose(l2(W), 6.275828)
+            assert allclose(U.min(), -1.176703)
+            assert allclose(U.max(), 0.859802)
+            assert allclose(W.min(),-0.223128)
+            assert allclose(W.max(), 0.227558 )
+
+            break
+
+        if print_jj:
+            if not as_unittest:
+                tile(imgs_fn(jj), "imgs_%06i.png"%jj)
+                tile(smplr.positions[0].value, "sample_%06i.png"%jj)
+                tile(rbm.U.value.T, "U_%06i.png"%jj)
+                tile(rbm.W.value.T, "W_%06i.png"%jj)
+
+            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)
+
+            print 'l2(U)', l2(rbm.U.value),
+            print 'l2(W)', l2(rbm.W.value)
+
+            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
+            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
+            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
+            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
+            print 'c min max', rbm.c.value.min(), rbm.c.value.max()
+
+            print 'parts min', smplr.positions[0].value.min(), 
+            print 'max',smplr.positions[0].value.max(),
+            print 'HMC step', smplr.stepsize,
+            print 'arate', smplr.avg_acceptance_rate
+
+        # Continue HMC chain
+        smplr.simulate()
+
+        # Do CD update
+        l2_of_Ugrad = learn_fn(jj, 
+                lr/max(1, jj/(20*epoch_size/batchsize)),
+                effective_l1_penalty)
+
+        if print_jj:
+            print 'l2(U_grad)', float(l2_of_Ugrad[0]),
+            print 'l2(U_inc)', float(l2_of_Ugrad[1]),
+            print 'l2(W_inc)', float(l2_of_Ugrad[2]),
+            #print 'FE+', float(l2_of_Ugrad[2]),
+            #print 'FE+[0]', float(l2_of_Ugrad[3]),
+            #print 'FE+[1]', float(l2_of_Ugrad[4]),
+            #print 'FE+[2]', float(l2_of_Ugrad[5]),
+            #print 'FE+[3]', float(l2_of_Ugrad[6])
+
+        if jj == no_l1_epochs * epoch_size/batchsize:
+            print "Activating L1 weight decay"
+            effective_l1_penalty = 1e-3
+
+        # Weird normalization technique:
+        # it constrains all the columns of the matrix to have the same length,
+        # but the matrix itself is re-scaled to have an arbitrary absolute size.
+        U = rbm.U.value
+        U_norms = np.sqrt((U*U).sum(axis=0))
+        assert len(U_norms) == n_K
+        normVF = .95 * normVF + .05 * np.mean(U_norms)
+        rbm.U.value = rbm.U.value * normVF/U_norms
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/sandbox/train_mcRBM.py	Tue Aug 24 19:24:54 2010 -0400
@@ -0,0 +1,472 @@
+"""
+This is a copy of mcRBM training that James modified to print out more information, visualize
+filters, etc.  Once mcRBM is stable, it can be deleted.
+"""
+# mcRBM training
+# Refer to Ranzato and Hinton CVPR 2010 "Modeling Pixel Means and Covariances Using Factorized Third-Order BMs"
+#
+# Marc'Aurelio Ranzato
+# 28 July 2010
+
+import sys
+import numpy as np
+import cudamat as cmt
+from scipy.io import loadmat, savemat
+#import gpu_lock # put here your locking system package, if any
+from ConfigParser import *
+
+demodata = None
+
+from pylearn.io import image_tiling
+def tile(X, fname):
+    X = np.dot(X, demodata['invpcatransf'].T)
+    R=16
+    C=16
+    X = (X[:,:256], X[:,256:512], X[:,512:], None)
+    #X = (X[:,0::3], X[:,1::3], X[:,2::3], None)
+    _img = image_tiling.tile_raster_images(X,
+            img_shape=(R,C),
+            min_dynamic_range=1e-2)
+    image_tiling.save_tiled_raster_images(_img, fname)
+
+def save_imshow(X, fname):
+    image_tiling.Image.fromarray(
+            (image_tiling.scale_to_unit_interval(X)*255).astype('uint8'),
+            'L').save(fname)
+
+######################################################################
+# compute the value of the free energy at a given input
+# F = - sum log(1+exp(- .5 FH (VF data/norm(data))^2 + bias_cov)) +...
+#     - sum log(1+exp(w_mean data + bias_mean)) + ...
+#     - bias_vis data + 0.5 data^2
+# NOTE: FH is constrained to be positive 
+# (in the paper the sign is negative but the sign in front of it is also flipped)
+def compute_energy_mcRBM(data,normdata,vel,energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis):
+    # normalize input data vectors
+    data.mult(data, target = t6) # DxP (nr input dims x nr samples)
+    t6.sum(axis = 0, target = lengthsq) # 1xP
+    lengthsq.mult(0.5, target = energy) # energy of quadratic regularization term   
+    lengthsq.mult(1./num_vis) # normalize by number of components (like std)    
+    lengthsq.add(small) # small prevents division by 0
+    cmt.sqrt(lengthsq, target = length) 
+    length.reciprocal(target = normcoeff) # 1xP
+    data.mult_by_row(normcoeff, target = normdata) # normalized data    
+    ## potential
+    # covariance contribution
+    cmt.dot(VF.T, normdata, target = feat) # HxP (nr factors x nr samples)
+    feat.mult(feat, target = featsq)   # HxP
+    cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
+    t1.mult(-0.5)
+    t1.add_col_vec(bias_cov) # OxP
+    cmt.exp(t1) # OxP
+    t1.add(1, target = t2) # OxP
+    cmt.log(t2)
+    t2.mult(-1)
+    energy.add_sums(t2, axis=0)
+    # mean contribution
+    cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
+    feat_mean.add_col_vec(bias_mean) # HxP
+    cmt.exp(feat_mean) 
+    feat_mean.add(1)
+    cmt.log(feat_mean)
+    feat_mean.mult(-1)
+    energy.add_sums(feat_mean,  axis=0)
+    # visible bias term
+    data.mult_by_col(bias_vis, target = t6)
+    t6.mult(-1) # DxP
+    energy.add_sums(t6,  axis=0) # 1xP
+    # kinetic
+    vel.mult(vel, target = t6)
+    energy.add_sums(t6, axis = 0, mult = .5)
+
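Equivalently, in plain NumPy (a sketch for reference only; it assumes ordinary float arrays with the shapes noted in the comments, rather than cudamat matrices):

    import numpy as np

    def compute_energy_mcRBM_numpy(data, vel, VF, FH, bias_cov, bias_vis,
                                   w_mean, bias_mean, small, num_vis):
        # data, vel: DxP; VF: DxH; FH: HxO; bias_cov: Ox1; bias_vis: Dx1;
        # w_mean: DxM; bias_mean: Mx1 -- returns a length-P vector of energies
        sq = (data ** 2).sum(axis=0)                        # squared norm of each sample
        energy = 0.5 * sq                                   # quadratic regularization term
        normdata = data / np.sqrt(sq / num_vis + small)     # small prevents division by 0
        featsq = np.dot(VF.T, normdata) ** 2                # HxP
        t1 = -0.5 * np.dot(FH.T, featsq) + bias_cov         # OxP
        energy -= np.log(1 + np.exp(t1)).sum(axis=0)        # covariance contribution
        fm = np.dot(w_mean.T, data) + bias_mean             # MxP
        energy -= np.log(1 + np.exp(fm)).sum(axis=0)        # mean contribution
        energy -= (bias_vis * data).sum(axis=0)             # visible bias term
        energy += 0.5 * (vel ** 2).sum(axis=0)              # kinetic term of the HMC state
        return energy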
+#################################################################
+# compute the derivative of the free energy at a given input
+def compute_gradient_mcRBM(data,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis):
+    # normalize input data
+    data.mult(data, target = t6) # DxP
+    t6.sum(axis = 0, target = lengthsq) # 1xP
+    lengthsq.mult(1./num_vis) # normalize by number of components (like std)
+    lengthsq.add(small)
+    cmt.sqrt(lengthsq, target = length)
+    length.reciprocal(target = normcoeff) # 1xP
+    data.mult_by_row(normcoeff, target = normdata) # normalized data    
+    cmt.dot(VF.T, normdata, target = feat) # HxP 
+    feat.mult(feat, target = featsq)   # HxP
+    cmt.dot(FH.T,featsq, target = t1) # OxP
+    t1.mult(-.5)
+    t1.add_col_vec(bias_cov) # OxP
+    t1.apply_sigmoid(target = t2) # OxP
+    cmt.dot(FH,t2, target = t3) # HxP
+    t3.mult(feat)
+    cmt.dot(VF, t3, target = normgradient) # VxP
+    # final bprop through normalization
+    length.mult(lengthsq, target = normcoeff)
+    normcoeff.reciprocal() # 1xP
+    normgradient.mult(data, target = gradient) # VxP
+    gradient.sum(axis = 0, target = t4) # 1xP
+    t4.mult(-1./num_vis)
+    data.mult_by_row(t4, target = gradient)
+    normgradient.mult_by_row(lengthsq, target = t6)
+    gradient.add(t6)
+    gradient.mult_by_row(normcoeff)
+    # add quadratic term gradient
+    gradient.add(data)
+    # add visible bias term
+    gradient.add_col_mult(bias_vis, -1)
+    # add MEAN contribution to gradient
+    cmt.dot(w_mean.T, data, target = feat_mean) # HxP 
+    feat_mean.add_col_vec(bias_mean) # HxP
+    feat_mean.apply_sigmoid() # HxP
+    gradient.subtract_dot(w_mean,feat_mean) # VxP 
+
+#############################################################
+# Hybrid Monte Carlo sampler
+def draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis):
+    vel.fill_with_randn()
+    negdata.assign(data)
+    compute_energy_mcRBM(negdata,normdata,vel,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis)
+    compute_gradient_mcRBM(negdata,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis)
+    # half step
+    vel.add_mult(gradient, -0.5*hmc_step)
+    negdata.add_mult(vel,hmc_step)
+    # full leap-frog steps
+    for ss in range(hmc_step_nr - 1):
+        ## re-evaluate the gradient
+        compute_gradient_mcRBM(negdata,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis)
+        # update variables
+        vel.add_mult(gradient, -hmc_step)
+        negdata.add_mult(vel,hmc_step)
+    # final half-step
+    compute_gradient_mcRBM(negdata,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis)
+    vel.add_mult(gradient, -0.5*hmc_step)
+    # compute new energy
+    compute_energy_mcRBM(negdata,normdata,vel,new_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis)
+    # rejection
+    old_energy.subtract(new_energy, target = thresh)
+    cmt.exp(thresh)
+    t4.fill_with_rand()
+    t4.less_than(thresh)
+    #    update negdata and rejection rate
+    t4.mult(-1)
+    t4.add(1) # now 1's detect rejections
+    t4.sum(axis = 1, target = t5)
+    t5.copy_to_host()
+    rej = t5.numpy_array[0,0]/batch_size
+    data.mult_by_row(t4, target = t6)
+    negdata.mult_by_row(t4, target = t7)
+    negdata.subtract(t7)
+    negdata.add(t6)
+    hmc_ave_rej = 0.9*hmc_ave_rej + 0.1*rej
+    if hmc_ave_rej < hmc_target_ave_rej:
+        hmc_step = min(hmc_step*1.01,0.25)
+    else:
+        hmc_step = max(hmc_step*0.99,.001)
+    return hmc_step, hmc_ave_rej
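Stripped of the cudamat temporaries, the leap-frog/Metropolis structure above amounts to the following NumPy sketch (energy_fn and grad_fn stand in for the two routines above; this is illustrative, not a drop-in replacement):

    import numpy as np

    def hmc_move_numpy(pos, energy_fn, grad_fn, step, n_steps, rng):
        # pos: DxP particles; energy_fn(pos, vel) -> P; grad_fn(pos) -> DxP
        vel = rng.randn(*pos.shape)
        old_energy = energy_fn(pos, vel)
        new_pos = pos.copy()
        vel = vel - 0.5 * step * grad_fn(new_pos)       # initial half step on the velocity
        new_pos = new_pos + step * vel
        for _ in range(n_steps - 1):                    # full leap-frog steps
            vel = vel - step * grad_fn(new_pos)
            new_pos = new_pos + step * vel
        vel = vel - 0.5 * step * grad_fn(new_pos)       # final half step
        new_energy = energy_fn(new_pos, vel)
        accept = rng.rand(pos.shape[1]) < np.exp(old_energy - new_energy)
        rej = 1.0 - accept.mean()                       # rejection rate for step-size control
        return np.where(accept, new_pos, pos), rej

The step size is then nudged up (capped at 0.25) while the running rejection rate stays below hmc_target_ave_rej, and down (floored at 0.001) otherwise, as at the end of draw_HMC_samples above.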
+
+
+######################################################
+# mcRBM trainer: sweeps over the training set.
+# For each batch of samples, compute the derivatives used to update the parameters
+# at the training samples and at the negative samples drawn by calling the HMC sampler.
+def train_mcRBM():
+
+    config = ConfigParser()
+    config.read('input_configuration')
+
+    verbose = config.getint('VERBOSITY','verbose')
+
+    num_epochs = config.getint('MAIN_PARAMETER_SETTING','num_epochs')
+    batch_size = config.getint('MAIN_PARAMETER_SETTING','batch_size')
+    startFH = config.getint('MAIN_PARAMETER_SETTING','startFH')
+    startwd = config.getint('MAIN_PARAMETER_SETTING','startwd')
+    doPCD = config.getint('MAIN_PARAMETER_SETTING','doPCD')
+
+    # model parameters
+    num_fac = config.getint('MODEL_PARAMETER_SETTING','num_fac')
+    num_hid_cov =  config.getint('MODEL_PARAMETER_SETTING','num_hid_cov')
+    num_hid_mean =  config.getint('MODEL_PARAMETER_SETTING','num_hid_mean')
+    apply_mask =  config.getint('MODEL_PARAMETER_SETTING','apply_mask')
+
+    # load data
+    data_file_name =  config.get('DATA','data_file_name')
+    d = loadmat(data_file_name) # input in the format PxD (P vectorized samples with D dimensions)
+    global demodata
+    demodata = d
+    totnumcases = d["whitendata"].shape[0]
+    d = d["whitendata"][0:np.floor(totnumcases/batch_size)*batch_size,:].copy() 
+    totnumcases = d.shape[0]
+    num_vis =  d.shape[1]
+    num_batches = int(totnumcases/batch_size)
+    dev_dat = cmt.CUDAMatrix(d.T) # VxP 
+
+    tile(d[:100], "100_whitened_data.png")
+
+    # training parameters
+    epsilon = config.getfloat('OPTIMIZER_PARAMETERS','epsilon')
+    epsilonVF = 2*epsilon
+    epsilonFH = 0.02*epsilon
+    epsilonb = 0.02*epsilon
+    epsilonw_mean = 0.2*epsilon
+    epsilonb_mean = 0.1*epsilon
+    weightcost_final =  config.getfloat('OPTIMIZER_PARAMETERS','weightcost_final')
+
+    # HMC setting
+    hmc_step_nr = config.getint('HMC_PARAMETERS','hmc_step_nr')
+    hmc_step =  0.01
+    hmc_target_ave_rej =  config.getfloat('HMC_PARAMETERS','hmc_target_ave_rej')
+    hmc_ave_rej =  hmc_target_ave_rej
+
+    # initialize weights
+    VF = cmt.CUDAMatrix(np.array(0.02 * np.random.randn(num_vis, num_fac), dtype=np.float32, order='F')) # VxH
+    if apply_mask == 0:
+        FH = cmt.CUDAMatrix( np.array( np.eye(num_fac,num_hid_cov), dtype=np.float32, order='F')  ) # HxO
+    else:
+        dd = loadmat('your_FHinit_mask_file.mat') # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example
+        FH = cmt.CUDAMatrix( np.array( dd["FH"], dtype=np.float32, order='F')  )
+    bias_cov = cmt.CUDAMatrix( np.array(2.0*np.ones((num_hid_cov, 1)), dtype=np.float32, order='F') )
+    bias_vis = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F') )
+    w_mean = cmt.CUDAMatrix( np.array( 0.05 * np.random.randn(num_vis, num_hid_mean), dtype=np.float32, order='F') ) # VxH
+    bias_mean = cmt.CUDAMatrix( np.array( -2.0*np.ones((num_hid_mean,1)), dtype=np.float32, order='F') )
+
+    # initialize variables to store derivatives 
+    VFinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_fac)), dtype=np.float32, order='F'))
+    FHinc = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
+    bias_covinc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_cov, 1)), dtype=np.float32, order='F'))
+    bias_visinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F'))
+    w_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_hid_mean)), dtype=np.float32, order='F'))
+    bias_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_mean, 1)), dtype=np.float32, order='F'))
+
+    # initialize temporary storage
+    data = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    normdata = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    negdataini = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    feat = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
+    featsq = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
+    negdata = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
+    old_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
+    new_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
+    gradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    normgradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    thresh = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
+    feat_mean = cmt.CUDAMatrix( np.array(np.empty((num_hid_mean, batch_size)), dtype=np.float32, order='F'))
+    vel = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
+    length = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
+    lengthsq = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
+    normcoeff = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
+    if apply_mask==1: # this mask is used to constrain very large FH matrices, allowing only values in a neighborhood to change
+        dd = loadmat('your_FHinit_mask_file.mat') 
+        mask = cmt.CUDAMatrix( np.array(dd["mask"], dtype=np.float32, order='F'))
+    normVF = 1    
+    small = 0.5
+    
+    # other temporary vars
+    t1 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
+    t2 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
+    t3 = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
+    t4 = cmt.CUDAMatrix( np.array(np.empty((1,batch_size)), dtype=np.float32, order='F'))
+    t5 = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F'))
+    t6 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
+    t7 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
+    t8 = cmt.CUDAMatrix( np.array(np.empty((num_vis, num_fac)), dtype=np.float32, order='F'))
+    t9 = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
+    t10 = cmt.CUDAMatrix( np.array(np.empty((1,num_fac)), dtype=np.float32, order='F'))
+    t11 = cmt.CUDAMatrix( np.array(np.empty((1,num_hid_cov)), dtype=np.float32, order='F'))
+    # start training
+    for epoch in range(num_epochs):
+
+        def print_stuff():
+            print "VF: " + '%3.2e' % VF.euclid_norm() \
+                    + ", DVF: " + '%3.2e' % (VFinc.euclid_norm()*(epsilonVFc/batch_size))\
+                    + ", VF_inc: " + '%3.2e' % (VFinc.euclid_norm())\
+                    + ", FH: " + '%3.2e' % FH.euclid_norm() \
+                    + ", DFH: " + '%3.2e' % (FHinc.euclid_norm()*(epsilonFHc/batch_size)) \
+                    + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm() \
+                    + ", Dbias_cov: " + '%3.2e' % (bias_covinc.euclid_norm()*(epsilonbc/batch_size)) \
+                    + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm() \
+                    + ", Dbias_vis: " + '%3.2e' % (bias_visinc.euclid_norm()*(epsilonbc/batch_size)) \
+                    + ", wm: " + '%3.2e' % w_mean.euclid_norm() \
+                    + ", Dwm: " + '%3.2e' % (w_meaninc.euclid_norm()*(epsilonw_meanc/batch_size)) \
+                    + ", bm: " + '%3.2e' % bias_mean.euclid_norm() \
+                    + ", Dbm: " + '%3.2e' % (bias_meaninc.euclid_norm()*(epsilonb_meanc/batch_size)) \
+                    + ", step: " + '%3.2e' % hmc_step  \
+                    + ", rej: " + '%3.2e' % hmc_ave_rej 
+            sys.stdout.flush()
+
+        def save_stuff():
+            VF.copy_to_host()
+            FH.copy_to_host()
+            bias_cov.copy_to_host()
+            w_mean.copy_to_host()
+            bias_mean.copy_to_host()
+            bias_vis.copy_to_host()
+            savemat("ws_temp", {
+                'VF':VF.numpy_array,
+                'FH':FH.numpy_array,
+                'bias_cov': bias_cov.numpy_array, 
+                'bias_vis': bias_vis.numpy_array,
+                'w_mean': w_mean.numpy_array, 
+                'bias_mean': bias_mean.numpy_array,
+                'epoch':epoch})    
+
+            tile(VF.numpy_array.T, 'VF_%000i.png'%epoch)
+            tile(w_mean.numpy_array.T, 'w_mean_%000i.png'%epoch)
+            save_imshow(FH.numpy_array, 'FH_%000i.png'%epoch)
+
+        # anneal learning rates
+        epsilonVFc    = epsilonVF/max(1,epoch/20)
+        epsilonFHc    = epsilonFH/max(1,epoch/20)
+        epsilonbc    = epsilonb/max(1,epoch/20)
+        epsilonw_meanc = epsilonw_mean/max(1,epoch/20)
+        epsilonb_meanc = epsilonb_mean/max(1,epoch/20)
+        weightcost = weightcost_final
+
+        if epoch <= startFH:
+            epsilonFHc = 0 
+        if epoch <= startwd:	
+            weightcost = 0
+
+        print "Epoch " + str(epoch + 1), 'num_batches', num_batches
+        if epoch == 0:
+            print_stuff()
+
+        for batch in range(num_batches):
+
+            # get current minibatch
+            data = dev_dat.slice(batch*batch_size,(batch + 1)*batch_size) # DxP (nr dims x nr samples)
+
+            # normalize input data
+            data.mult(data, target = t6) # DxP
+            t6.sum(axis = 0, target = lengthsq) # 1xP
+            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
+            lengthsq.add(small) # small avoids division by 0
+            cmt.sqrt(lengthsq, target = length)
+            length.reciprocal(target = normcoeff) # 1xP
+            data.mult_by_row(normcoeff, target = normdata) # normalized data 
+            ## compute positive sample derivatives
+            # covariance part
+            cmt.dot(VF.T, normdata, target = feat) # HxP (nr facs x nr samples)
+            feat.mult(feat, target = featsq)   # HxP
+            cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
+            t1.mult(-0.5)
+            t1.add_col_vec(bias_cov) # OxP
+            t1.apply_sigmoid(target = t2) # OxP
+            cmt.dot(featsq, t2.T, target = FHinc) # HxO
+            cmt.dot(FH,t2, target = t3) # HxP
+            t3.mult(feat)
+            cmt.dot(normdata, t3.T, target = VFinc) # VxH
+            t2.sum(axis = 1, target = bias_covinc)
+            bias_covinc.mult(-1)  
+            # visible bias
+            data.sum(axis = 1, target = bias_visinc)
+            bias_visinc.mult(-1)
+            # mean part
+            cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
+            feat_mean.add_col_vec(bias_mean) # HxP
+            feat_mean.apply_sigmoid() # HxP
+            feat_mean.mult(-1)
+            cmt.dot(data, feat_mean.T, target = w_meaninc)
+            feat_mean.sum(axis = 1, target = bias_meaninc)
+            
+            # HMC sampling: draw an approximate sample from the model
+            if doPCD == 0: # CD-1 (set negative data to current training samples)
+                hmc_step, hmc_ave_rej = draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
+            else: # PCD-1 (use previous negative data as starting point for chain)
+                negdataini.assign(negdata)
+                hmc_step, hmc_ave_rej = draw_HMC_samples(negdataini,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
+                
+            # compute derivatives at the negative samples
+            # normalize input data
+            negdata.mult(negdata, target = t6) # DxP
+            t6.sum(axis = 0, target = lengthsq) # 1xP
+            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
+            lengthsq.add(small)
+            cmt.sqrt(lengthsq, target = length)
+            length.reciprocal(target = normcoeff) # 1xP
+            negdata.mult_by_row(normcoeff, target = normdata) # normalized data 
+            # covariance part
+            cmt.dot(VF.T, normdata, target = feat) # HxP 
+            feat.mult(feat, target = featsq)   # HxP
+            cmt.dot(FH.T,featsq, target = t1) # OxP
+            t1.mult(-0.5)
+            t1.add_col_vec(bias_cov) # OxP
+            t1.apply_sigmoid(target = t2) # OxP
+            FHinc.subtract_dot(featsq, t2.T) # HxO
+            FHinc.mult(0.5)
+            cmt.dot(FH,t2, target = t3) # HxP
+            t3.mult(feat)
+            VFinc.subtract_dot(normdata, t3.T) # VxH
+            bias_covinc.add_sums(t2, axis = 1)
+            # visible bias
+            bias_visinc.add_sums(negdata, axis = 1)
+            # mean part
+            cmt.dot(w_mean.T, negdata, target = feat_mean) # HxP 
+            feat_mean.add_col_vec(bias_mean) # HxP
+            feat_mean.apply_sigmoid() # HxP
+            w_meaninc.add_dot(negdata, feat_mean.T)
+            bias_meaninc.add_sums(feat_mean, axis = 1)
+
+            # update parameters
+            VFinc.add_mult(VF.sign(), weightcost) # L1 regularization
+            VF.add_mult(VFinc, -epsilonVFc/batch_size)
+            # normalize columns of VF: normalize by running average of their norm 
+            VF.mult(VF, target = t8)
+            t8.sum(axis = 0, target = t10)
+            cmt.sqrt(t10)
+            t10.sum(axis=1,target = t5)
+            t5.copy_to_host()
+            normVF = .95*normVF + (.05/num_fac) * t5.numpy_array[0,0] # estimate norm
+            t10.reciprocal()
+            VF.mult_by_row(t10) 
+            VF.mult(normVF) 
+            bias_cov.add_mult(bias_covinc, -epsilonbc/batch_size)
+            bias_vis.add_mult(bias_visinc, -epsilonbc/batch_size)
+
+            if epoch > startFH:
+                FHinc.add_mult(FH.sign(), weightcost) # L1 regularization
+                FH.add_mult(FHinc, -epsilonFHc/batch_size) # update
+                # set to 0 negative entries in FH
+                FH.greater_than(0, target = t9)
+                FH.mult(t9)
+                if apply_mask==1:
+                    FH.mult(mask)
+                # normalize columns of FH: L1 norm set to 1 in each column
+                FH.sum(axis = 0, target = t11)
+                t11.reciprocal()
+                FH.mult_by_row(t11)
+            w_meaninc.add_mult(w_mean.sign(),weightcost)
+            w_mean.add_mult(w_meaninc, -epsilonw_meanc/batch_size)
+            bias_mean.add_mult(bias_meaninc, -epsilonb_meanc/batch_size)
+
+        if verbose == 1:
+            print_stuff()
+        # back-up every once in a while 
+        if np.mod(epoch,1) == 0:
+            save_stuff()
+    # final back-up
+    VF.copy_to_host()
+    FH.copy_to_host()
+    bias_cov.copy_to_host()
+    bias_vis.copy_to_host()
+    w_mean.copy_to_host()
+    bias_mean.copy_to_host()
+    savemat("ws_fac" + str(num_fac) + "_cov" + str(num_hid_cov) + "_mean" + str(num_hid_mean), {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array, 'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})
+
+
+
+####################################
+# main
+if __name__ == "__main__":
+  # initialize CUDA
+  #cmt.cuda_set_device(gpu_lock.obtain_lock_id()) # uncomment if you have a locking system or desire to choose the GPU board number
+  cmt.cublas_init()
+  cmt.CUDAMatrix.init_random(1)
+  train_mcRBM()
+  cmt.cublas_shutdown()
+
+
+
+
+
+
+
+