pylearn/algorithms/mcRBM.py @ 984:5badf36a6daf
mcRBM - added notes to leading comment
author: James Bergstra <bergstrj@iro.umontreal.ca>
date:   Tue, 24 Aug 2010 13:50:26 -0400

"""
This file implements the Mean & Covariance RBM discussed in 

    Ranzato, M. and Hinton, G. E. (2010)
    Modeling pixel means and covariances using factored third-order Boltzmann machines.
    IEEE Conference on Computer Vision and Pattern Recognition.

and performs one of the experiments on CIFAR-10 discussed in that paper.  There are some minor
discrepancies between the paper and the accompanying code (train_mcRBM.py); in those cases the
accompanying code has been taken to be correct, because I couldn't get things to work
otherwise.


Math
====

Energy of "covariance RBM"

    E = -0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i C_{if} v_i )^2
      = -0.5 \sum_f (\sum_k P_{fk} h_k) ( \sum_i C_{if} v_i )^2
                    "vector element f"           "vector element f"

In some parts of the paper, the P matrix is chosen to be a diagonal matrix with non-positive
diagonal entries; with P equal to the negative identity (as in this file) this reduces to the
simpler equation:

    E = 0.5 \sum_f h_f ( \sum_i C_{if} v_i )^2
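
For concreteness, with numpy arrays `C` (I x F), `P` (F x K), `h` (length K) and `v`
(length I), the general covariance energy above is (an illustrative sketch only, not part
of this module's API)::

    E_cov = -0.5 * numpy.sum(numpy.dot(P, h) * numpy.dot(C.T, v) ** 2)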



Version in paper
----------------

Full Energy of the Mean and Covariance RBM, with 
:math:`h_k = h_k^{(c)}`,
:math:`g_j = h_j^{(m)}`,
:math:`b_k = b_k^{(c)}`,
:math:`c_j = b_j^{(m)}`,
:math:`U_{if} = C_{if}`,

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k ( (\sum_i U_{if} v_i) / (|U_{.f}| |v|) )^2
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j

For the energy function to correspond to a probability distribution, P must be non-positive.  P
is initialized to be diagonal, and in our experience it can be left that way because even in
the paper it has a very low learning rate and is (in effect) only allowed to be updated after
the filters in U have been learned.

Version in published train_mcRBM code
-------------------------------------

The train_mcRBM file implements learning with a similar but technically different energy function:

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k (\sum_i U_{if} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j

There are two differences with respect to the paper:

    - 'v' is not normalized by its length, but rather it is normalized to have length close to
      the square root of the number of its components.  The variable called 'small' that
      "avoids division by zero" is orders of magnitude larger than machine precision (it is 0.5
      in train_mcRBM), and is on the order of the normalized sum-of-squares, so I've included
      it in the energy function.

    - 'U' is also not normalized by its length.  U is initialized to have columns that are
      shorter than unit length (approximately 0.2 with the 105 principal components in the
      train_mcRBM data).  During training, the columns of U are constrained manually to have
      equal lengths (see the use of normVF), but that common Euclidean norm is allowed to
      change.  During learning it quickly grows towards 1 and then exceeds 1.  This column-wise
      normalization of U does not seem to be justified by maximum likelihood, and I have no
      intuition for why it is used.  (Both normalizations are sketched in numpy just below.)
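
As a rough numpy sketch of both normalizations (illustrative only; variable names loosely
follow the training loop later in this file)::

    I = v.shape[0]
    # difference 1: v is scaled to length close to sqrt(I); the 'small' constant (0.5) appears here
    v_normed = v / numpy.sqrt(numpy.sum(v ** 2) / I + 0.5)

    # difference 2: the columns of U are rescaled to a common length normVF, which is itself
    # a running average (initialized to 1) of the observed column lengths
    col_norms = numpy.sqrt((U ** 2).sum(axis=0))
    normVF = 0.95 * normVF + 0.05 * col_norms.mean()
    U = U * normVF / col_norms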


Version in this code
--------------------

This file implements the same algorithm as the train_mcRBM code, except that the P matrix is
omitted for clarity, and replaced analytically with a negative identity matrix.

    E (v, h, g) =
        + 0.5 \sum_k h_k (\sum_i U_{ik} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j
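
A direct numpy transcription of this energy, for a single visible vector `v`, would be (a
sketch for clarity only; the code additionally includes a visible bias term
``- numpy.dot(a, v)``)::

    vn = v / numpy.sqrt(numpy.sum(v ** 2) / len(v) + 0.5)
    E = (0.5 * numpy.sum(h * numpy.dot(U.T, vn) ** 2)
         - numpy.dot(b, h)
         + 0.5 * numpy.sum(v ** 2)
         - numpy.dot(v, numpy.dot(W, g))
         - numpy.dot(c, g))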

      

Conventions in this file
========================

This file contains some global functions, as well as a class (MeanCovRBM) that makes using them a little
more convenient.


Global functions like `free_energy_given_v` work on an mcRBM as parametrized in a particular way.
Suppose we have 
I input dimensions, 
F squared filters, 
J mean variables, and 
K covariance variables.
The mcRBM is parametrized by the following variables:

 - `P`, a matrix (probably sparse) of pooling (F x K); in this file it is replaced
   analytically by a negative identity matrix
 - `U`, a matrix whose columns are visible covariance directions (I x F)
 - `W`, a matrix whose columns are visible mean directions (I x J)
 - `a`, a vector of visible biases (I)
 - `b`, a vector of hidden covariance biases (K)
 - `c`, a vector of hidden mean biases  (J)

Matrices are generally laid out and accessed according to a C-order convention.
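
For example, parameters with the shapes used by the demo in `__main__` below could be
allocated like this (illustrative only)::

    I, F, J, K = 105, 256, 100, 256
    U = numpy.zeros((I, F))   # covariance filters, one per column
    W = numpy.zeros((I, J))   # mean filters, one per column
    a = numpy.zeros(I)        # visible biases
    b = numpy.zeros(K)        # hidden covariance biases
    c = numpy.zeros(J)        # hidden mean biases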

"""

#
# WORKING NOTES
# THIS DERIVATION IS BASED ON THE ** PAPER ** ENERGY FUNCTION
# NOT THE ENERGY FUNCTION IN THE CODE!!!
#
# Free energy is the marginal energy of visible units
# Recall: 
#   Q(x) = exp(-E(x))/Z ==> -log(Q(x)) - log(Z) = E(x)
#
#
#   E (v, h, g) =
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / |U_{*f}|^2 |v|^2
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j
#       - \sum_i a_i v_i
#
#
# Derivation, in which partition functions are ignored.
#
# E(v) = -\log(Q(v))
#  = -\log( \sum_{h,g} Q(v,h,g))
#  = -\log( \sum_{h,g} exp(-E(v,h,g)))
#  = -\log( \sum_{h,g} exp(-
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}| * |v|)
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j 
#       - \sum_i a_i v_i ))
#
# Get rid of double negs  in exp
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}| * |v|)
#       + \sum_k b_k h_k
#       - 0.5 \sum_i v_i^2
#       ) * \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j))
#    - \sum_i a_i v_i 
#
# Break up log
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|*|v|)
#       + \sum_k b_k h_k
#       ))
#    -\log( \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j )))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# Use the fact that h is binary to turn log(sum(exp(sum...))) into sum(log(..
#  = -\log(\sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|* |v|)
#       + \sum_k b_k h_k
#       ))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
#  = - \sum_{k} \log(1 + exp(b_k + 0.5 \sum_f P_{fk}( \sum_i U_{if} v_i )^2 / (|U_{*f}|*|v|)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# For negative-one-diagonal P this gives:
#
#  = - \sum_{k} \log(1 + exp(b_k - 0.5 \sum_i (U_{ik} v_i )^2 / (|U_{*k}|*|v|)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
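#
# A small numpy sanity check of the last formula (not part of the original file; it uses the
# normalization exactly as written in the final expression above, and it enumerates the
# 2**(K+J) binary configurations, so the dimensions are kept tiny):
#
#   import itertools
#   import numpy as np
#   rng = np.random.RandomState(0)
#   I, K, J = 4, 3, 2
#   U = rng.randn(I, K); W = rng.randn(I, J)
#   a = rng.randn(I); b = rng.randn(K); c = rng.randn(J)
#   v = rng.randn(I)
#   Unorm = np.sqrt((U ** 2).sum(axis=0)); vnorm = np.sqrt((v ** 2).sum())
#
#   def energy(h, g):   # E(v, h, g) with P = -identity
#       return (0.5 * np.sum(h * np.dot(U.T, v) ** 2 / (Unorm * vnorm))
#               - np.dot(b, h) + 0.5 * np.dot(v, v)
#               - np.dot(v, np.dot(W, g)) - np.dot(c, g) - np.dot(a, v))
#
#   brute_F = -np.log(sum(np.exp(-energy(np.array(h), np.array(g)))
#                         for h in itertools.product([0, 1], repeat=K)
#                         for g in itertools.product([0, 1], repeat=J)))
#   closed_F = (-np.sum(np.log1p(np.exp(b - 0.5 * np.dot(U.T, v) ** 2 / (Unorm * vnorm))))
#               - np.sum(np.log1p(np.exp(np.dot(v, W) + c)))
#               + 0.5 * np.dot(v, v) - np.dot(a, v))
#   assert np.allclose(brute_F, closed_F)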

import sys
import logging
import numpy as np
import numpy
from theano import function, shared, dot
from theano import tensor as TT
import theano.sparse #installs the sparse shared var handler
floatX = theano.config.floatX

from pylearn.sampling.hmc import HMC_sampler
from pylearn.io import image_tiling

from sparse_coding import numpy_project_onto_ball

print >> sys.stderr, "mcRBM IS NOT READY YET"


#TODO: This should be in the nnet part of the library
def sgd_updates(params, grads, lr):
    try:
        float(lr)
        lr = [lr for p in params]
    except TypeError:
        pass
    updates = [(p, p - plr * gp) for (plr, p, gp) in zip(lr, params, grads)]
    return updates
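
# Illustrative usage (not part of the original file): `lr` may be a single scalar, which is
# broadcast to every parameter, or a list with one entry per parameter, as in the call near
# the end of this file:
#
#   updates = sgd_updates([U, W], [gU, gW], lr=0.01)            # same rate for both params
#   updates = sgd_updates([U, W], [gU, gW], lr=[0.02, 0.002])   # per-parameter rates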

def as_shared(x, name=None, dtype=floatX):
    if hasattr(x, 'type'):
        return x
    else:
        if 'float' in str(x.dtype):
            return shared(x.astype(floatX), name=name)
        else:
            return shared(x, name=name)

def hidden_cov_units_preactivation_given_v(rbm, v, small=1e-8):
    """Return the pre-sigmoid activation of the covariance hidden units, given visible rows `v`.

    The rows of `v` are rescaled to unit length; the columns of `U` are assumed to already
    have (approximately) unit length.
    """
    (U,W,a,b,c) = rbm
    unit_v = v / (TT.sqrt(TT.sum(v**2, axis=1))+small).dimshuffle(0,'x') # unit rows
    unit_U = U  # assuming unit cols!
    #unit_U = U / (TT.sqrt(TT.sum(U**2, axis=0))+small)  #unit cols
    return b - 0.5 * dot(unit_v, unit_U)**2

def free_energy_given_v(rbm, v):
    """Returns theano expression for free energy of visible vector `v` in an mcRBM
    
    An mcRBM is parametrized
    by `U`, `W`, `b`, `c`.
    See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters.


    The free energy of v is what we need for learning and hybrid Monte-carlo negative-phase
    sampling.

    """
    U, W, a, b, c = rbm

    t0 = -TT.sum(TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),axis=1)
    t1 = -TT.sum(TT.nnet.softplus(c + dot(v,W)), axis=1)
    t2 =  0.5 * TT.sum(v**2, axis=1)
    t3 = -TT.dot(v, a)
    return t0 + t1 + t2 + t3, (t0, t1, t2, t3)
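
# Illustrative usage (assumes `rbm` is a (U, W, a, b, c) sequence of shared variables,
# e.g. MeanCovRBM.params below); not part of the original API:
#
#   v = TT.matrix('v')
#   fe, (t0, t1, t2, t3) = free_energy_given_v(rbm, v)
#   fe_fn = function([v], fe)    # one free-energy value per row of v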

def expected_h_g_given_v(P, U, W, b, c, v):
    """Returns theano expression conditional expectations (`h`, `g`) in an mcRBM.
    
    An mcRBM is parametrized
    by `U`, `W`, `b`, `c`.
    See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters.


    The conditional E[h, g | v] is what we need to classify images.
    """
    raise NotImplementedError()

    #TODO: check to see if these args should be negated?

    if P is None:
        h = nnet.sigmoid(b + 0.5 * cosines(v,U))
    else:
        h = nnet.sigmoid(b + 0.5 * dot(cosines(v,U), P))
    g = nnet.sigmoid(c + dot(v,W))
    return (h, g)

class MeanCovRBM(object):
    """Container for mcRBM parameters that gives more convenient access to mcRBM methods.
    """

    params = property(lambda s: [s.U, s.W, s.a, s.b, s.c])

    n_visible = property(lambda s: s.W.value.shape[0])

    def __init__(self, U, W, a, b, c):
        self.U = as_shared(U, 'U')
        self.W = as_shared(W, 'W')
        self.a = as_shared(a, 'a')
        self.b = as_shared(b, 'b')
        self.c = as_shared(c, 'c')

        assert self.b.type.dtype == 'float32'

    @classmethod
    def new_from_dims(cls, 
            n_I,  # input dimensionality
            n_K,  # number of covariance hidden units
            n_F,  # number of covariance filters (squared)
            n_J,  # number of mean filters (linear)
            seed = 8923402190,
            ):
        """
        Return a MeanCovRBM instance with randomly-initialized parameters.
        """

        
        if 0: # disabled sketch: diagonal initialization of P ('P_init' is not an argument of this method)
            if P_init == 'diag':
                if n_K != n_F:
                    raise ValueError('cannot use diagonal initialization of non-square P matrix')
                import scipy.sparse
                P =  -scipy.sparse.identity(n_K).tocsr()
            else:
                raise NotImplementedError()

        rng = np.random.RandomState(seed)

        # initialization taken from Marc'Aurelio

        return cls(
                #U = numpy_project_onto_ball(rng.randn(n_I, n_F).T).T,
                U = 0.2 * rng.randn(n_I, n_F),
                W = rng.randn(n_I, n_J)/np.sqrt((n_I+n_J)/2),
                a = np.ones(n_I)*(-2),
                b = np.ones(n_K)*2,
                c = np.zeros(n_J),)

    def __getstate__(self):
        # unpack shared containers, which may have references to Theano stuff
        # and are not a long-term stable data type.
        return dict(
                U = self.U.value,
                W = self.W.value,
                a = self.a.value,
                b = self.b.value,
                c = self.c.value)

    def __setstate__(self, dct):
        self.__init__(**dct) # calls as_shared on pickled arrays

    def hmc_sampler(self, n_particles=100, seed=7823748):
        return HMC_sampler(
            positions = [as_shared(
                np.random.RandomState(seed^20893).randn(
                    n_particles,
                    self.n_visible ))],
            energy_fn = lambda p : self.free_energy_given_v(p[0]),
            seed=seed)

    def free_energy_given_v(self, v, extra=False):
        rval = free_energy_given_v(self.params, v)
        if extra:
            return rval
        else:
            return rval[0]

    def contrastive_gradient(self, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0):
        """Return a list of gradient expressions for self.params

        :param pos_v: positive-phase sample of visible units
        :param neg_v: negative-phase sample of visible units
        """
        pos_FE = self.free_energy_given_v(pos_v)
        neg_FE = self.free_energy_given_v(neg_v)

        gpos_FE = theano.tensor.grad(pos_FE.sum(), self.params)
        gneg_FE = theano.tensor.grad(neg_FE.sum(), self.params)
        rval = [ gp - gn for (gp,gn) in zip(gpos_FE, gneg_FE)]
        rval[0] = rval[0] - TT.sign(self.U)*U_l1_penalty
        rval[1] = rval[1] - TT.sign(self.W)*W_l1_penalty
        return rval

from pylearn.dataset_ops.protocol import TensorFnDataset
from pylearn.dataset_ops.memo import memo
import scipy.io
@memo
def load_mcRBM_demo_patches():
    d = scipy.io.loadmat('/u/bergstrj/cvs/articles/2010/spike_slab_RBM/src/marcaurelio/training_colorpatches_16x16_demo.mat')
    totnumcases = d["whitendata"].shape[0]
    #d = d["whitendata"][0:np.floor(totnumcases/batch_size)*batch_size,:].copy() 
    d = d["whitendata"].copy()
    return d


if __name__ == '__main__':

    print >> sys.stderr, "TODO: use P matrix (aka FH matrix)"

    dataset='MAR'
    if dataset == 'MAR':
        R,C= 21,5
        n_patches=10240
        demodata = scipy.io.loadmat('/u/bergstrj/cvs/articles/2010/spike_slab_RBM/src/marcaurelio/training_colorpatches_16x16_demo.mat')
    else:
        R,C= 16,16 # the size of image patches
        n_patches=100000

    n_train_iters=30000

    n_burnin_steps=10000

    l1_penalty=1e-3
    no_l1_epochs = 10
    effective_l1_penalty=0.0

    epoch_size=50000
    batchsize = 128
    lr = 0.075 / batchsize
    s_lr = TT.scalar()
    s_l1_penalty=TT.scalar()
    n_K=256
    n_F=256
    n_J=100

    rbm = MeanCovRBM.new_from_dims(n_I=R*C,
            n_K=n_K,
            n_J=n_J, 
            n_F=n_F,
            ) 

    sampler = rbm.hmc_sampler(n_particles=batchsize)

    def l2(X):
        return (X**2).sum()
    def tile(X, fname):
        if dataset == 'MAR':
            X = np.dot(X, demodata['invpcatransf'].T)
            R=16
            C=16
            #X = X.reshape((X.shape[0], 3, 16, 16)).transpose([0,2,3,1]).copy()
            X = (X[:,:256], X[:,256:512], X[:,512:], None)
        _img = image_tiling.tile_raster_images(X,
                img_shape=(R,C),
                min_dynamic_range=1e-2)
        image_tiling.save_tiled_raster_images(_img, fname)
    #print "Burning in..."
    #for burnin in xrange(n_burnin_steps):
        #sampler.simulate()

    if 0:
        print "Just SAMPLING..."
        for jj in xrange(n_burnin_steps):
            if 0 == jj % 100:
                tile(sampler.positions[0].value, "sampler_%06i.png"%jj)
                tile(numpy.random.randn(100, 105), "random_%06i.png"%jj)
                print "burning in... ", jj
                sys.stdout.flush()
            sampler.simulate()

        sys.exit()

    batch_idx = TT.iscalar()

    if 0:
        from pylearn.dataset_ops import image_patches
        train_batch = image_patches.image_patches(
                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
                dims = (n_patches,R,C),
                center=True,
                unitvar=True,
                dtype=floatX,
                rasterized=True)
    else:
        op = TensorFnDataset(floatX,
                bcast=(False,),
                fn=load_mcRBM_demo_patches,
                single_shape=(105,))
        train_batch = op((batch_idx * batchsize + np.arange(batchsize))%n_patches)

    imgs_fn = function([batch_idx], outputs=train_batch)

    grads = rbm.contrastive_gradient(
            pos_v=train_batch, 
            neg_v=sampler.positions[0],
            U_l1_penalty=s_l1_penalty,
            W_l1_penalty=s_l1_penalty)

    learn_fn = function([batch_idx, s_lr, s_l1_penalty], 
            outputs=[ 
                grads[0].norm(2),
                rbm.free_energy_given_v(train_batch).sum(),
                rbm.free_energy_given_v(train_batch,extra=1)[1][0].sum(),
                rbm.free_energy_given_v(train_batch,extra=1)[1][1].sum(),
                rbm.free_energy_given_v(train_batch,extra=1)[1][2].sum(),
                rbm.free_energy_given_v(train_batch,extra=1)[1][3].sum(),
                ],
            updates = sgd_updates(
                rbm.params,
                grads,
                lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ]))
    theano.printing.pydotprint(learn_fn, 'learn_fn.png')

    print "Learning..."
    normVF=1
    for jj in xrange(n_train_iters):

        print_jj = ((1 and jj < 100) 
                or (0 and jj < 100 and 0==jj%10) 
                or (jj < 1000 and 0==jj%100)
                or (1 and jj < 10000 and 0==jj%1000))


        if print_jj:
            tile(imgs_fn(jj), "imgs_%06i.png"%jj)
            tile(sampler.positions[0].value, "sample_%06i.png"%jj)
            tile(rbm.U.value.T, "U_%06i.png"%jj)
            tile(rbm.W.value.T, "W_%06i.png"%jj)

            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize), 
            print 'l2(U)', l2(rbm.U.value),
            print 'l2(W)', l2(rbm.W.value),
            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
            print 'c min max', rbm.c.value.min(), rbm.c.value.max(),

            print 'parts min', sampler.positions[0].value.min(), 
            print 'max',sampler.positions[0].value.max(),
            print 'HMC step', sampler.stepsize,
            print 'arate', sampler.avg_acceptance_rate

        sampler.simulate()

        l2_of_Ugrad = learn_fn(jj, 
                lr/max(1, jj/(20*epoch_size/batchsize)),
                effective_l1_penalty)

        if print_jj:
            print 'l2(gU)', float(l2_of_Ugrad[0]),
            print 'FE+', float(l2_of_Ugrad[1]),
            print 'FE+[0]', float(l2_of_Ugrad[2]),
            print 'FE+[1]', float(l2_of_Ugrad[3]),
            print 'FE+[2]', float(l2_of_Ugrad[4]),
            print 'FE+[3]', float(l2_of_Ugrad[5]),

        if jj == no_l1_epochs * epoch_size/batchsize:
            print "Activating L1 weight decay"
            effective_l1_penalty = 1e-3

        if 0:
            rbm.U.value = numpy_project_onto_ball(rbm.U.value.T).T
        else:
            # weird normalization technique...
            # It constrains all the columns of the matrix to have the same length,
            # but the matrix itself is re-scaled to have an arbitrary absolute size.
            U = rbm.U.value
            U_norms = np.sqrt((U*U).sum(axis=0))
            assert len(U_norms) == n_F
            normVF = .95 * normVF + .05 * np.mean(U_norms)
            rbm.U.value = rbm.U.value * normVF/U_norms


#
#
# Marc'Aurelio Ranzato's code (kept below for reference; it depends on cudamat (`cmt`),
# ConfigParser, scipy.io's loadmat/savemat, and a draw_HMC_samples helper, none of which
# are imported or defined in this file)
#
######################################################################
# compute the value of the free energy at a given input
# F = - sum log(1+exp(- .5 FH (VF data/norm(data))^2 + bias_cov)) +...
#     - sum log(1+exp(w_mean data + bias_mean)) + ...
#     - bias_vis data + 0.5 data^2
# NOTE: FH is constrained to be positive 
# (in the paper the sign is negative but the sign in front of it is also flipped)
def compute_energy_mcRBM(data,normdata,vel,energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis):
    # normalize input data vectors
    data.mult(data, target = t6) # DxP (nr input dims x nr samples)
    t6.sum(axis = 0, target = lengthsq) # 1xP
    lengthsq.mult(0.5, target = energy) # energy of quadratic regularization term   
    lengthsq.mult(1./num_vis) # normalize by number of components (like std)    

    lengthsq.add(small) # small prevents division by 0
    # energy_j = \sum_i 0.5 data_ij ^2
    # lengthsq_j = (\sum_i data_ij^2) / num_vis + small
    cmt.sqrt(lengthsq, target = length) 
    # length_j = sqrt(lengthsq_j)
    length.reciprocal(target = normcoeff) # 1xP
    # normcoef_j = 1/sqrt(lengthsq_j)
    data.mult_by_row(normcoeff, target = normdata) # normalized data    
    # normdata is like data, but each column is rescaled to length ~ sqrt(num_vis)

    ## potential
    # covariance contribution
    cmt.dot(VF.T, normdata, target = feat) # HxP (nr factors x nr samples)
    feat.mult(feat, target = featsq)   # HxP

    # featsq is the squared cosines (VF with data)
    cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
    t1.mult(-0.5)
    t1.add_col_vec(bias_cov) # OxP
    cmt.exp(t1) # OxP
    t1.add(1, target = t2) # OxP
    cmt.log(t2)
    t2.mult(-1)
    energy.add_sums(t2, axis=0)
    # mean contribution
    cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
    feat_mean.add_col_vec(bias_mean) # HxP
    cmt.exp(feat_mean) 
    feat_mean.add(1)
    cmt.log(feat_mean)
    feat_mean.mult(-1)
    energy.add_sums(feat_mean,  axis=0)
    # visible bias term
    data.mult_by_col(bias_vis, target = t6)
    t6.mult(-1) # DxP
    energy.add_sums(t6,  axis=0) # 1xP
    # kinetic
    vel.mult(vel, target = t6)
    energy.add_sums(t6, axis = 0, mult = .5)

######################################################
# mcRBM trainer: sweeps over the training set.
# For each batch of samples compute derivatives to update the parameters
# at the training samples and at the negative samples drawn calling HMC sampler.
def train_mcRBM():

    config = ConfigParser()
    config.read('input_configuration')

    verbose = config.getint('VERBOSITY','verbose')

    num_epochs = config.getint('MAIN_PARAMETER_SETTING','num_epochs')
    batch_size = config.getint('MAIN_PARAMETER_SETTING','batch_size')
    startFH = config.getint('MAIN_PARAMETER_SETTING','startFH')
    startwd = config.getint('MAIN_PARAMETER_SETTING','startwd')
    doPCD = config.getint('MAIN_PARAMETER_SETTING','doPCD')

    # model parameters
    num_fac = config.getint('MODEL_PARAMETER_SETTING','num_fac')
    num_hid_cov =  config.getint('MODEL_PARAMETER_SETTING','num_hid_cov')
    num_hid_mean =  config.getint('MODEL_PARAMETER_SETTING','num_hid_mean')
    apply_mask =  config.getint('MODEL_PARAMETER_SETTING','apply_mask')

    # load data
    data_file_name =  config.get('DATA','data_file_name')
    d = loadmat(data_file_name) # input in the format PxD (P vectorized samples with D dimensions)
    totnumcases = d["whitendata"].shape[0]
    d = d["whitendata"][0:floor(totnumcases/batch_size)*batch_size,:].copy() 
    totnumcases = d.shape[0]
    num_vis =  d.shape[1]
    num_batches = int(totnumcases/batch_size)
    dev_dat = cmt.CUDAMatrix(d.T) # VxP 

    # training parameters
    epsilon = config.getfloat('OPTIMIZER_PARAMETERS','epsilon')
    epsilonVF = 2*epsilon
    epsilonFH = 0.02*epsilon
    epsilonb = 0.02*epsilon
    epsilonw_mean = 0.2*epsilon
    epsilonb_mean = 0.1*epsilon
    weightcost_final =  config.getfloat('OPTIMIZER_PARAMETERS','weightcost_final')

    # HMC setting
    hmc_step_nr = config.getint('HMC_PARAMETERS','hmc_step_nr')
    hmc_step =  0.01
    hmc_target_ave_rej =  config.getfloat('HMC_PARAMETERS','hmc_target_ave_rej')
    hmc_ave_rej =  hmc_target_ave_rej

    # initialize weights
    VF = cmt.CUDAMatrix(np.array(0.02 * np.random.randn(num_vis, num_fac), dtype=np.float32, order='F')) # VxH
    if apply_mask == 0:
        FH = cmt.CUDAMatrix( np.array( np.eye(num_fac,num_hid_cov), dtype=np.float32, order='F')  ) # HxO
    else:
        dd = loadmat('your_FHinit_mask_file.mat') # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example
        FH = cmt.CUDAMatrix( np.array( dd["FH"], dtype=np.float32, order='F')  )
    bias_cov = cmt.CUDAMatrix( np.array(2.0*np.ones((num_hid_cov, 1)), dtype=np.float32, order='F') )
    bias_vis = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F') )
    w_mean = cmt.CUDAMatrix( np.array( 0.05 * np.random.randn(num_vis, num_hid_mean), dtype=np.float32, order='F') ) # VxH
    bias_mean = cmt.CUDAMatrix( np.array( -2.0*np.ones((num_hid_mean,1)), dtype=np.float32, order='F') )

    # initialize variables to store derivatives 
    VFinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_fac)), dtype=np.float32, order='F'))
    FHinc = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
    bias_covinc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_cov, 1)), dtype=np.float32, order='F'))
    bias_visinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F'))
    w_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_hid_mean)), dtype=np.float32, order='F'))
    bias_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_mean, 1)), dtype=np.float32, order='F'))

    # initialize temporary storage
    data = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    normdata = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    negdataini = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    feat = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
    featsq = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
    negdata = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
    old_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
    new_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
    gradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    normgradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    thresh = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
    feat_mean = cmt.CUDAMatrix( np.array(np.empty((num_hid_mean, batch_size)), dtype=np.float32, order='F'))
    vel = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
    length = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
    lengthsq = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
    normcoeff = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
    if apply_mask==1: # this is used to constrain very large FH matrices, only allowing values in a neighborhood to change
        dd = loadmat('your_FHinit_mask_file.mat') 
        mask = cmt.CUDAMatrix( np.array(dd["mask"], dtype=np.float32, order='F'))
    normVF = 1    
    small = 0.5
    
    # other temporary vars
    t1 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
    t2 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
    t3 = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
    t4 = cmt.CUDAMatrix( np.array(np.empty((1,batch_size)), dtype=np.float32, order='F'))
    t5 = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F'))
    t6 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
    t7 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
    t8 = cmt.CUDAMatrix( np.array(np.empty((num_vis, num_fac)), dtype=np.float32, order='F'))
    t9 = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
    t10 = cmt.CUDAMatrix( np.array(np.empty((1,num_fac)), dtype=np.float32, order='F'))
    t11 = cmt.CUDAMatrix( np.array(np.empty((1,num_hid_cov)), dtype=np.float32, order='F'))

    # start training
    for epoch in range(num_epochs):

        print "Epoch " + str(epoch + 1)
    
        # anneal learning rates
        epsilonVFc    = epsilonVF/max(1,epoch/20)
        epsilonFHc    = epsilonFH/max(1,epoch/20)
        epsilonbc    = epsilonb/max(1,epoch/20)
        epsilonw_meanc = epsilonw_mean/max(1,epoch/20)
        epsilonb_meanc = epsilonb_mean/max(1,epoch/20)
        weightcost = weightcost_final

        if epoch <= startFH:
            epsilonFHc = 0 
        if epoch <= startwd:	
            weightcost = 0

        for batch in range(num_batches):

            # get current minibatch
            data = dev_dat.slice(batch*batch_size,(batch + 1)*batch_size) # DxP (nr dims x nr samples)

            # normalize input data
            data.mult(data, target = t6) # DxP
            t6.sum(axis = 0, target = lengthsq) # 1xP
            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
            lengthsq.add(small) # small avoids division by 0
            cmt.sqrt(lengthsq, target = length)
            length.reciprocal(target = normcoeff) # 1xP
            data.mult_by_row(normcoeff, target = normdata) # normalized data 
            ## compute positive sample derivatives
            # covariance part
            cmt.dot(VF.T, normdata, target = feat) # HxP (nr facs x nr samples)
            feat.mult(feat, target = featsq)   # HxP
            cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
            t1.mult(-0.5)
            t1.add_col_vec(bias_cov) # OxP
            t1.apply_sigmoid(target = t2) # OxP
            cmt.dot(featsq, t2.T, target = FHinc) # HxO
            cmt.dot(FH,t2, target = t3) # HxP
            t3.mult(feat)
            cmt.dot(normdata, t3.T, target = VFinc) # VxH
            t2.sum(axis = 1, target = bias_covinc)
            bias_covinc.mult(-1)  
            # visible bias
            data.sum(axis = 1, target = bias_visinc)
            bias_visinc.mult(-1)
            # mean part
            cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
            feat_mean.add_col_vec(bias_mean) # HxP
            feat_mean.apply_sigmoid() # HxP
            feat_mean.mult(-1)
            cmt.dot(data, feat_mean.T, target = w_meaninc)
            feat_mean.sum(axis = 1, target = bias_meaninc)
            
            # HMC sampling: draw an approximate sample from the model
            if doPCD == 0: # CD-1 (set negative data to current training samples)
                hmc_step, hmc_ave_rej = draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
            else: # PCD-1 (use previous negative data as starting point for chain)
                negdataini.assign(negdata)
                hmc_step, hmc_ave_rej = draw_HMC_samples(negdataini,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
                
            # compute derivatives at the negative samples
            # normalize input data
            negdata.mult(negdata, target = t6) # DxP
            t6.sum(axis = 0, target = lengthsq) # 1xP
            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
            lengthsq.add(small)
            cmt.sqrt(lengthsq, target = length)
            length.reciprocal(target = normcoeff) # 1xP
            negdata.mult_by_row(normcoeff, target = normdata) # normalized data 
            # covariance part
            cmt.dot(VF.T, normdata, target = feat) # HxP 
            feat.mult(feat, target = featsq)   # HxP
            cmt.dot(FH.T,featsq, target = t1) # OxP
            t1.mult(-0.5)
            t1.add_col_vec(bias_cov) # OxP
            t1.apply_sigmoid(target = t2) # OxP
            FHinc.subtract_dot(featsq, t2.T) # HxO
            FHinc.mult(0.5)
            cmt.dot(FH,t2, target = t3) # HxP
            t3.mult(feat)
            VFinc.subtract_dot(normdata, t3.T) # VxH
            bias_covinc.add_sums(t2, axis = 1)
            # visible bias
            bias_visinc.add_sums(negdata, axis = 1)
            # mean part
            cmt.dot(w_mean.T, negdata, target = feat_mean) # HxP 
            feat_mean.add_col_vec(bias_mean) # HxP
            feat_mean.apply_sigmoid() # HxP
            w_meaninc.add_dot(negdata, feat_mean.T)
            bias_meaninc.add_sums(feat_mean, axis = 1)

            # update parameters
            VFinc.add_mult(VF.sign(), weightcost) # L1 regularization
            VF.add_mult(VFinc, -epsilonVFc/batch_size)
            # normalize columns of VF: normalize by running average of their norm 
            VF.mult(VF, target = t8)
            t8.sum(axis = 0, target = t10)
            cmt.sqrt(t10)
            t10.sum(axis=1,target = t5)
            t5.copy_to_host()
            normVF = .95*normVF + (.05/num_fac) * t5.numpy_array[0,0] # estimate norm
            t10.reciprocal()
            VF.mult_by_row(t10) 
            VF.mult(normVF) 
            bias_cov.add_mult(bias_covinc, -epsilonbc/batch_size)
            bias_vis.add_mult(bias_visinc, -epsilonbc/batch_size)

            if epoch > startFH:
                FHinc.add_mult(FH.sign(), weightcost) # L1 regularization
                FH.add_mult(FHinc, -epsilonFHc/batch_size) # update
                # set to 0 negative entries in FH
                FH.greater_than(0, target = t9)
                FH.mult(t9)
                if apply_mask==1:
                    FH.mult(mask)
                # normalize columns of FH: L1 norm set to 1 in each column
                FH.sum(axis = 0, target = t11)
                t11.reciprocal()
                FH.mult_by_row(t11)
            w_meaninc.add_mult(w_mean.sign(),weightcost)
            w_mean.add_mult(w_meaninc, -epsilonw_meanc/batch_size)
            bias_mean.add_mult(bias_meaninc, -epsilonb_meanc/batch_size)

        if verbose == 1:
            print "VF: " + '%3.2e' % VF.euclid_norm() + ", DVF: " + '%3.2e' % (VFinc.euclid_norm()*(epsilonVFc/batch_size)) + ", FH: " + '%3.2e' % FH.euclid_norm() + ", DFH: " + '%3.2e' % (FHinc.euclid_norm()*(epsilonFHc/batch_size)) + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm() + ", Dbias_cov: " + '%3.2e' % (bias_covinc.euclid_norm()*(epsilonbc/batch_size)) + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm() + ", Dbias_vis: " + '%3.2e' % (bias_visinc.euclid_norm()*(epsilonbc/batch_size)) + ", wm: " + '%3.2e' % w_mean.euclid_norm() + ", Dwm: " + '%3.2e' % (w_meaninc.euclid_norm()*(epsilonw_meanc/batch_size)) + ", bm: " + '%3.2e' % bias_mean.euclid_norm() + ", Dbm: " + '%3.2e' % (bias_meaninc.euclid_norm()*(epsilonb_meanc/batch_size)) + ", step: " + '%3.2e' % hmc_step  +  ", rej: " + '%3.2e' % hmc_ave_rej 
            sys.stdout.flush()
        # back-up every once in a while 
        if np.mod(epoch,10) == 0:
            VF.copy_to_host()
            FH.copy_to_host()
            bias_cov.copy_to_host()
            w_mean.copy_to_host()
            bias_mean.copy_to_host()
            bias_vis.copy_to_host()
            savemat("ws_temp", {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array,'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})    
    # final back-up
    VF.copy_to_host()
    FH.copy_to_host()
    bias_cov.copy_to_host()
    bias_vis.copy_to_host()
    w_mean.copy_to_host()
    bias_mean.copy_to_host()
    savemat("ws_fac" + str(num_fac) + "_cov" + str(num_hid_cov) + "_mean" + str(num_hid_mean), {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array, 'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})