view pylearn/algorithms/mcRBM.py @ 977:9cac1ecaeef7

mcRBM - changed init of U to match M'A.R's code
author James Bergstra <bergstrj@iro.umontreal.ca>
date Mon, 23 Aug 2010 16:04:10 -0400
parents 4cbd65cf902d
children ab4bc97ca060
line wrap: on
line source

"""
This file implements the Mean & Covariance RBM discussed in 

    Ranzato, M. and Hinton, G. E. (2010)
    Modeling pixel means and covariances using factored third-order Boltzmann machines.
    IEEE Conference on Computer Vision and Pattern Recognition.

and performs one of the experiments on CIFAR-10 discussed in that paper.


Math
====

Energy of "covariance RBM"

    E = -0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i C_{if} v_i )^2
      = -0.5 \sum_f (\sum_k P_{fk} h_k) ( \sum_i C_{if} v_i )^2
                    "vector element f"           "vector element f"

In some parts of the paper, the P matrix is chosen to be a diagonal matrix with non-positive
diagonal entries, so it is helpful to see this as a simpler equation (written here for P = -I,
i.e. K = F and P_{fk} = -1 if f == k, else 0):

    E =  0.5 \sum_f h_f ( \sum_i C_{if} v_i )^2



Full Energy of mean and Covariance RBM, with 
:math:`h_k = h_k^{(c)}`,
:math:`g_j = h_j^{(m)}`,
:math:`b_k = b_k^{(c)}`,
:math:`c_j = b_j^{(m)}`,
:math:`U_{if} = C_{if}`,

:

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / |U_{*f}|^2 |v|^2
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j
        - \sum_i a_i v_i

For the energy function to correspond to a probability distribution, P must be non-positive.


Conventions in this file
========================

This file contains some global functions, as well as a class (MeanCovRBM) that makes using them a little
more convenient.


Global functions like `free_energy` work on an mcRBM as parametrized in a particular way.
Suppose we have 
I input dimensions, 
F squared filters, 
J mean variables, and 
K covariance variables.
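The mcRBM is parametrized by the following variables (the module-level functions below take
the tuple (U, W, a, b, c) and implicitly use P = -I, which requires K = F):

 - `P`, a matrix (probably sparse) of pooling (F x K)
 - `U`, a matrix whose columns are visible covariance directions (I x F)
 - `W`, a matrix whose columns are visible mean directions (I x J)
 - `a`, a vector of visible biases (I)
 - `b`, a vector of hidden covariance biases (K)
 - `c`, a vector of hidden mean biases  (J)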

Matrices are generally laid out according to a C-order convention.

"""

# Free energy is the marginal energy of visible units
# Recall: 
#   Q(x) = exp(-E(x))/Z ==> -log(Q(x)) - log(Z) = E(x)
#
#
#   E (v, h, g) =
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / |U_{*f}|^2 |v|^2
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j
#       - \sum_i a_i v_i
#
#
# Derivation, in which partition functions are ignored.
#
# E(v) = -\log(Q(v))
#  = -\log( \sum_{h,g} Q(v,h,g))
#  = -\log( \sum_{h,g} exp(-E(v,h,g)))
#  = -\log( \sum_{h,g} exp(-
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 * |v|^2)
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j 
#       - \sum_i a_i v_i ))
#
# Get rid of double negs  in exp
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 * |v|^2)
#       + \sum_k b_k h_k
#       - 0.5 \sum_i v_i^2
#       ) * \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j))
#    - \sum_i a_i v_i 
#
# Break up log
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 * |v|^2)
#       + \sum_k b_k h_k
#       ))
#    -\log( \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j )))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# Use the fact that g is binary (so the sum over g factorizes over j) to turn log(sum(exp(sum...))) into sum(log(..
#  = -\log(\sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 * |v|^2)
#       + \sum_k b_k h_k
#       ))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
#  = - \sum_{k} \log(1 + exp(b_k + 0.5 \sum_f P_{fk}( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 * |v|^2)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# For negative-one-diagonal P this gives:
#
#  = - \sum_{k} \log(1 + exp(b_k - 0.5 ( \sum_i U_{ik} v_i )^2 / (|U_{*k}|^2 * |v|^2)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 

import sys
import logging
import numpy as np
import numpy
from theano import function, shared, dot
from theano import tensor as TT
import theano.sparse #installs the sparse shared var handler
floatX = theano.config.floatX

from pylearn.sampling.hmc import HMC_sampler
from pylearn.io import image_tiling

from sparse_coding import numpy_project_onto_ball

# Printed once at import time: warns users that this module is work in progress.
print >> sys.stderr, "mcRBM IS NOT READY YET"


#TODO: This should be in the nnet part of the library
def sgd_updates(params, grads, lr):
    """Return a list of ``(param, param - lr * grad)`` SGD update pairs.

    :param params: sequence of parameters (e.g. Theano shared variables).
    :param grads: sequence of gradient expressions, one per parameter, in the
        same order as `params`.
    :param lr: either a single learning rate applied to every parameter
        (anything accepted by ``float()``), or a sequence with one learning
        rate per parameter.
    :returns: list of ``(p, p - lr_p * g_p)`` tuples, suitable for the
        ``updates`` argument of ``theano.function``.
    :raises ValueError: if `grads`, or a per-parameter `lr` sequence, does not
        have exactly one entry per parameter.  (``zip`` would otherwise drop
        the surplus parameters silently, leaving them un-updated.)
    """
    params = list(params)
    grads = list(grads)
    try:
        float(lr)
    except TypeError:
        # Not a scalar: interpret as one learning rate per parameter.
        lrs = list(lr)
    else:
        lrs = [lr] * len(params)

    if len(grads) != len(params):
        raise ValueError('sgd_updates: got %i grads for %i params'
                % (len(grads), len(params)))
    if len(lrs) != len(params):
        raise ValueError('sgd_updates: got %i learning rates for %i params'
                % (len(lrs), len(params)))

    return [(p, p - plr * gp) for (plr, p, gp) in zip(lrs, params, grads)]

def as_shared(x, name=None, dtype=floatX):
    """Wrap `x` in a Theano shared variable, unless it already is a Theano variable.

    :param x: a Theano variable (anything with a ``type`` attribute), which is
        returned unchanged; or array-like data.
    :param name: optional name for the new shared variable.
    :param dtype: dtype that floating-point data is cast to (default: floatX).
        Non-float data (e.g. integer arrays) keeps its own dtype.
    :returns: `x` itself, or a new shared variable holding (a cast of) `x`.
    """
    if hasattr(x, 'type'):
        return x
    # Objects that already carry a dtype (numpy arrays, scipy.sparse matrices
    # handled by theano.sparse's shared constructor) are used as-is; plain
    # Python scalars / lists are converted to numpy arrays first.
    if not hasattr(x, 'dtype'):
        x = np.asarray(x)
    if 'float' in str(x.dtype):
        return shared(x.astype(dtype), name=name)
    return shared(x, name=name)

def hidden_cov_units_preactivation_given_v(rbm, v, small=1e-8):
    """Return the pre-sigmoid input of the covariance hidden units given `v`.

    :param rbm: tuple ``(U, W, a, b, c)`` of mcRBM parameters; only `U` and
        `b` are used here.
    :param v: symbolic matrix holding one visible vector per row.
    :param small: added to each row's L2 norm to avoid division by zero.
    :returns: ``b - 0.5 * dot(v_hat, U) ** 2``, where ``v_hat`` is `v` with
        every row divided by (its L2 norm + `small`).

    `U` is used as-is: its columns are assumed to already have unit L2 norm
    (they are not renormalized here).
    """
    U, _, _, b, _ = rbm
    row_norms = TT.sqrt(TT.sum(v ** 2, axis=1)) + small
    v_hat = v / row_norms.dimshuffle(0, 'x')
    # With unit-norm columns of U these products are cosines.
    projections = dot(v_hat, U)
    return b - 0.5 * projections ** 2

def free_energy_given_v(rbm, v):
    """Return the mcRBM free energy of every row of `v`, together with its terms.

    :param rbm: tuple ``(U, W, a, b, c)`` of mcRBM parameters; see the module
        documentation for their meaning.
    :param v: symbolic matrix holding one visible vector per row.
    :returns: ``(F, (t0, t1, t2, t3))`` where ``F = t0 + t1 + t2 + t3`` has one
        entry per row of `v` and

        - t0 = -sum_k softplus(covariance pre-activation, see
          `hidden_cov_units_preactivation_given_v`)
        - t1 = -sum_j softplus(c_j + (v W)_j)
        - t2 = 0.5 * |v|^2
        - t3 = -v . a

    The free energy of v is what we need for learning and for hybrid
    Monte-Carlo negative-phase sampling.
    """
    _, W, a, _, c = rbm

    cov_term = -TT.sum(
            TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),
            axis=1)
    mean_term = -TT.sum(TT.nnet.softplus(c + dot(v, W)), axis=1)
    quadratic_term = 0.5 * TT.sum(v ** 2, axis=1)
    visible_bias_term = -TT.dot(v, a)

    terms = (cov_term, mean_term, quadratic_term, visible_bias_term)
    total = cov_term + mean_term + quadratic_term + visible_bias_term
    return total, terms

def expected_h_g_given_v(P, U, W, b, c, v):
    """Return theano expressions for the conditional expectations (`h`, `g`) in an mcRBM.

    NOT IMPLEMENTED: this function always raises NotImplementedError.

    See the module-level documentation for the meaning of the `P`, `U`, `W`,
    `b` and `c` parameters.  The conditional E[h, g | v] is what we need to
    classify images.

    :raises NotImplementedError: always.
    """
    # TODO: implement.  An earlier sketch (unreachable code after this raise,
    # which referenced the undefined names `nnet` and `cosines`) was:
    #     h = sigmoid(b + 0.5 * cosines(v, U))            # if P is None
    #     h = sigmoid(b + 0.5 * dot(cosines(v, U), P))    # otherwise
    #     g = sigmoid(c + dot(v, W))
    # with an open question of whether the cosine terms should be negated.
    # NOTE(review): free_energy_given_v uses b - 0.5 * cos**2 for the
    # covariance units (negative-identity P), which suggests the negated form.
    # Confirm before implementing.
    raise NotImplementedError()

class MeanCovRBM(object):
    """Container for mcRBM parameters that gives more convenient access to mcRBM methods.

    The parameters `U`, `W`, `a`, `b`, `c` are stored as Theano shared
    variables (via `as_shared`).  `params` returns them in that order, which is
    the tuple layout unpacked by the module-level functions such as
    `free_energy_given_v`.
    """

    # [U, W, a, b, c]: same order as the `rbm` tuple of the module-level functions
    params = property(lambda s: [s.U, s.W, s.a, s.b, s.c])

    # number of visible units = number of rows of W (I x J)
    n_visible = property(lambda s: s.W.value.shape[0])

    def __init__(self, U, W, a, b, c):
        """Wrap every parameter with `as_shared`; existing Theano variables are kept as-is."""
        self.U = as_shared(U, 'U')
        self.W = as_shared(W, 'W')
        self.a = as_shared(a, 'a')
        self.b = as_shared(b, 'b')
        self.c = as_shared(c, 'c')

        # NOTE(review): hard-coded 'float32' rather than floatX.  This fails
        # whenever Theano's floatX is float64; confirm float32 is really required.
        assert self.b.type.dtype == 'float32'

    @classmethod
    def new_from_dims(cls,
            n_I,  # input dimensionality
            n_K,  # number of covariance hidden units
            n_F,  # number of covariance filters (squared)
            n_J,  # number of mean filters (linear)
            seed = 8923402190,
            ):
        """
        Return a MeanCovRBM instance with randomly-initialized parameters.

        :param n_I: input dimensionality (rows of `U` and `W`, length of `a`).
        :param n_K: number of covariance hidden units (length of `b`).
        :param n_F: number of covariance filters (columns of `U`).
        :param n_J: number of mean hidden units (columns of `W`, length of `c`).
        :param seed: seed for the numpy RandomState.  It is reduced modulo
            2**32 because RandomState only accepts seeds in [0, 2**32).  The
            default value is larger than that and would otherwise raise
            ValueError.
        """
        # TODO: initialization of the pooling matrix P (e.g. P = -I, which
        # needs n_K == n_F) is not implemented.  The module-level functions
        # currently use a negative-identity P implicitly, so they also expect
        # n_K == n_F.
        rng = np.random.RandomState(seed % (2 ** 32))

        # initialization taken from Marc'Aurelio
        # NOTE(review): train_mcRBM in this file initializes bias_vis to 0 and
        # bias_mean to -2, whereas here a (visible bias) is -2 and c (mean
        # hidden bias) is 0.  Confirm which is intended.
        return cls(
                #U = numpy_project_onto_ball(rng.randn(n_I, n_F).T).T,
                U = 0.2 * rng.randn(n_I, n_F),
                W = rng.randn(n_I, n_J)/np.sqrt((n_I+n_J)/2),
                a = np.ones(n_I)*(-2),
                b = np.ones(n_K)*2,
                c = np.zeros(n_J),)

    def __getstate__(self):
        # unpack shared containers, which may have references to Theano stuff
        # and are not a long-term stable data type.
        # All five constructor arguments are saved because __setstate__ passes
        # this dict straight to __init__, which requires every one of them.
        return dict(
                U = self.U.value,
                W = self.W.value,
                a = self.a.value,
                b = self.b.value,
                c = self.c.value)

    def __setstate__(self, dct):
        self.__init__(**dct) # calls as_shared on pickled arrays

    def hmc_sampler(self, n_particles=100, seed=7823748):
        """Return an HMC_sampler that draws visible vectors using this model's free energy.

        The `n_particles` x `n_visible` particle positions start out uniform in
        [0, 1), drawn from a RandomState seeded with ``seed ^ 20893``.
        """
        return HMC_sampler(
            positions = [as_shared(
                np.random.RandomState(seed^20893).rand(
                    n_particles,
                    self.n_visible ))],
            energy_fn = lambda p : self.free_energy_given_v(p[0]),
            seed=seed)

    def free_energy_given_v(self, v, extra=False):
        """Return the free energy of each row of `v`.

        If `extra` is true, return the full ``(F, (t0, t1, t2, t3))`` tuple of
        the module-level `free_energy_given_v` instead of just ``F``.
        """
        rval = free_energy_given_v(self.params, v)
        if extra:
            return rval
        else:
            return rval[0]

    def contrastive_gradient(self, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0):
        """Return a list of gradient expressions for self.params

        The expressions are the gradients, with respect to [U, W, a, b, c], of
        ``sum(F(pos_v)) - sum(F(neg_v)) + U_l1_penalty*|U|_1 + W_l1_penalty*|W|_1``.
        A descent step along them (e.g. `sgd_updates`) lowers the free energy
        of the data, raises it at the negative samples, and shrinks U and W
        towards zero.

        :param pos_v: positive-phase sample of visible units
        :param neg_v: negative-phase sample of visible units
        :param U_l1_penalty: L1 weight-decay coefficient for U
        :param W_l1_penalty: L1 weight-decay coefficient for W
        """
        pos_FE = self.free_energy_given_v(pos_v)
        neg_FE = self.free_energy_given_v(neg_v)

        gpos_FE = theano.tensor.grad(pos_FE.sum(), self.params)
        gneg_FE = theano.tensor.grad(neg_FE.sum(), self.params)
        rval = [ gp - gn for (gp,gn) in zip(gpos_FE, gneg_FE)]
        # d|x|_1 / dx = sign(x); the penalty gradient is ADDED so that a
        # descent step moves the weights towards zero.
        rval[0] = rval[0] + TT.sign(self.U)*U_l1_penalty
        rval[1] = rval[1] + TT.sign(self.W)*W_l1_penalty
        return rval

from pylearn.dataset_ops.protocol import TensorFnDataset
from pylearn.dataset_ops.memo import memo
import scipy.io
@memo
def load_mcRBM_demo_patches():
    """Return a copy of the ``"whitendata"`` array from Ranzato's demo patch file.

    The ``.mat`` path is hard-coded to a location on the author's filesystem.
    Judging by the file name, the data are whitened 16x16 colour patches.
    The `memo` decorator presumably caches the result across calls (TODO
    confirm against pylearn.dataset_ops.memo).
    """
    d = scipy.io.loadmat('/u/bergstrj/cvs/articles/2010/spike_slab_RBM/src/marcaurelio/training_colorpatches_16x16_demo.mat')
    # Return an independent array rather than a view into the loadmat result.
    return d["whitendata"].copy()


if __name__ == '__main__':

    print >> sys.stderr, "TODO: use P matrix (aka FH matrix)"

    R,C= 8,8 # the size of image patches
    l1_penalty=1e-3
    no_l1_epochs = 10

    epoch_size=50000
    batchsize = 128
    lr = 0.075 / batchsize
    s_lr = TT.scalar()
    n_K=256
    n_F=256
    n_J=100

    rbm = MeanCovRBM.new_from_dims(n_I=R*C,
            n_K=n_K,
            n_J=n_J, 
            n_F=n_F,
            ) 

    sampler = rbm.hmc_sampler(n_particles=100)

    from pylearn.dataset_ops import image_patches

    batch_idx = TT.iscalar()
    train_batch = image_patches.image_patches(
            s_idx = (batch_idx * batchsize + np.arange(batchsize)),
            dims = (1000,R,C),
            dtype=floatX,
            rasterized=True)

    grads = rbm.contrastive_gradient(pos_v=train_batch, neg_v=sampler.positions[0])

    learn_fn = function([batch_idx, s_lr], 
            outputs=[ 
                grads[0].norm(2),
                rbm.U.norm(2)
                ],
            updates = sgd_updates(
                rbm.params,
                grads,
                lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ]))

    for jj in xrange(10000):
        sampler.simulate()
        l2_of_Ugrad = learn_fn(jj, lr/max(1, jj/(20*epoch_size/batchsize)))

        if jj > no_l1_epochs * epoch_size/batchsize:
            rbm.U.value -= l1_penalty * np.sign(rbm.U.value)
            rbm.W.value -= l1_penalty * np.sign(rbm.W.value)

        if jj % 5 == 0:
            rbm.U.value = numpy_project_onto_ball(rbm.U.value.T).T

        if ((jj < 10) 
                or (jj < 100 and 0==jj%10) 
                or (jj < 1000 and 0==jj%100)
                or (jj < 10000 and 0==jj%1000)):
            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize), l2_of_Ugrad
            print 'neg particles', sampler.positions[0].value.min(), sampler.positions[0].value.max()
            image_tiling.save_tiled_raster_images(
                image_tiling.tile_raster_images(sampler.positions[0].value, (R,C)),
                "sample_%06i.png"%jj)
            image_tiling.save_tiled_raster_images(
                image_tiling.tile_raster_images(rbm.U.value.T, (R,C)),
                "U_%06i.png"%jj)
            image_tiling.save_tiled_raster_images(
                image_tiling.tile_raster_images(rbm.W.value.T, (R,C)),
                "W_%06i.png"%jj)



#
#
# Marc'Aurelio Ranzato's code
#
######################################################################
# compute the value of the free energy at a given input
# F = - sum log(1+exp(- .5 FH (VF data/norm(data))^2 + bias_cov)) +...
#     - sum log(1+exp(w_mean data + bias_mean)) + ...
#     - bias_vis data + 0.5 data^2
# NOTE: FH is constrained to be positive 
# (in the paper the sign is negative but the sign in front of it is also flipped)
def compute_energy_mcRBM(data,normdata,vel,energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis):
    """Write the HMC energy of every sample (column of `data`) into `energy`.

    For each column d of `data` (and matching column of `vel`) this computes

        0.5 |d|^2
        - sum_k log(1 + exp(-0.5 (FH^T (VF^T nd)^2)_k + bias_cov_k))
        - sum_j log(1 + exp((w_mean^T d)_j + bias_mean_j))
        - bias_vis . d
        + 0.5 |vel|^2                       (kinetic term)

    where nd = d / sqrt(|d|^2 / num_vis + small).

    All matrices are manipulated in place through a cudamat-style API (`cmt`
    is not imported in the visible part of this module; NOTE(review)).
    `energy` (1xP) receives the result.  `normdata`, `t1`, `t2`, `t6`, `feat`,
    `featsq`, `feat_mean`, `length`, `lengthsq` and `normcoeff` are
    overwritten as scratch space.  Nothing is returned.
    """
    # normalize input data vectors
    data.mult(data, target = t6) # DxP (nr input dims x nr samples)
    t6.sum(axis = 0, target = lengthsq) # 1xP
    lengthsq.mult(0.5, target = energy) # energy of quadratic regularization term
    lengthsq.mult(1./num_vis) # normalize by number of components (like std)

    lengthsq.add(small) # small prevents division by 0
    # energy_j = \sum_i 0.5 data_ij ^2
    # lengthsq_j = (\sum_i data_ij ^2) / num_vis + small
    cmt.sqrt(lengthsq, target = length)
    # length_j = sqrt(lengthsq_j)
    length.reciprocal(target = normcoeff) # 1xP
    # normcoef_j = 1/sqrt(lengthsq_j)
    data.mult_by_row(normcoeff, target = normdata) # normalized data
    # normdata is data with column j multiplied by normcoeff_j, i.e. divided by
    # sqrt(mean squared value + small): roughly unit RMS, NOT unit L2 norm

    ## potential
    # covariance contribution
    cmt.dot(VF.T, normdata, target = feat) # HxP (nr factors x nr samples)
    feat.mult(feat, target = featsq)   # HxP

    # featsq holds the squared projections of normdata onto the columns of VF
    cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
    t1.mult(-0.5)
    t1.add_col_vec(bias_cov) # OxP
    cmt.exp(t1) # OxP  (in place: t1 now holds exp of the pre-activation)
    t1.add(1, target = t2) # OxP
    cmt.log(t2)
    t2.mult(-1)
    energy.add_sums(t2, axis=0)  # energy -= sum_k softplus(cov pre-activation)
    # mean contribution
    cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
    feat_mean.add_col_vec(bias_mean) # HxP
    cmt.exp(feat_mean)
    feat_mean.add(1)
    cmt.log(feat_mean)
    feat_mean.mult(-1)
    energy.add_sums(feat_mean,  axis=0)  # energy -= sum_j softplus(mean pre-activation)
    # visible bias term
    data.mult_by_col(bias_vis, target = t6)
    t6.mult(-1) # DxP
    energy.add_sums(t6,  axis=0) # 1xP
    # kinetic
    vel.mult(vel, target = t6)
    energy.add_sums(t6, axis = 0, mult = .5)

######################################################
# mcRBM trainer: sweeps over the training set.
# For each batch of samples compute derivatives to update the parameters
# at the training samples and at the negative samples drawn calling HMC sampler.
def train_mcRBM():

    config = ConfigParser()
    config.read('input_configuration')

    verbose = config.getint('VERBOSITY','verbose')

    num_epochs = config.getint('MAIN_PARAMETER_SETTING','num_epochs')
    batch_size = config.getint('MAIN_PARAMETER_SETTING','batch_size')
    startFH = config.getint('MAIN_PARAMETER_SETTING','startFH')
    startwd = config.getint('MAIN_PARAMETER_SETTING','startwd')
    doPCD = config.getint('MAIN_PARAMETER_SETTING','doPCD')

    # model parameters
    num_fac = config.getint('MODEL_PARAMETER_SETTING','num_fac')
    num_hid_cov =  config.getint('MODEL_PARAMETER_SETTING','num_hid_cov')
    num_hid_mean =  config.getint('MODEL_PARAMETER_SETTING','num_hid_mean')
    apply_mask =  config.getint('MODEL_PARAMETER_SETTING','apply_mask')

    # load data
    data_file_name =  config.get('DATA','data_file_name')
    d = loadmat(data_file_name) # input in the format PxD (P vectorized samples with D dimensions)
    totnumcases = d["whitendata"].shape[0]
    d = d["whitendata"][0:floor(totnumcases/batch_size)*batch_size,:].copy() 
    totnumcases = d.shape[0]
    num_vis =  d.shape[1]
    num_batches = int(totnumcases/batch_size)
    dev_dat = cmt.CUDAMatrix(d.T) # VxP 

    # training parameters
    epsilon = config.getfloat('OPTIMIZER_PARAMETERS','epsilon')
    epsilonVF = 2*epsilon
    epsilonFH = 0.02*epsilon
    epsilonb = 0.02*epsilon
    epsilonw_mean = 0.2*epsilon
    epsilonb_mean = 0.1*epsilon
    weightcost_final =  config.getfloat('OPTIMIZER_PARAMETERS','weightcost_final')

    # HMC setting
    hmc_step_nr = config.getint('HMC_PARAMETERS','hmc_step_nr')
    hmc_step =  0.01
    hmc_target_ave_rej =  config.getfloat('HMC_PARAMETERS','hmc_target_ave_rej')
    hmc_ave_rej =  hmc_target_ave_rej

    # initialize weights
    VF = cmt.CUDAMatrix(np.array(0.02 * np.random.randn(num_vis, num_fac), dtype=np.float32, order='F')) # VxH
    if apply_mask == 0:
        FH = cmt.CUDAMatrix( np.array( np.eye(num_fac,num_hid_cov), dtype=np.float32, order='F')  ) # HxO
    else:
        dd = loadmat('your_FHinit_mask_file.mat') # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example
        FH = cmt.CUDAMatrix( np.array( dd["FH"], dtype=np.float32, order='F')  )
    bias_cov = cmt.CUDAMatrix( np.array(2.0*np.ones((num_hid_cov, 1)), dtype=np.float32, order='F') )
    bias_vis = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F') )
    w_mean = cmt.CUDAMatrix( np.array( 0.05 * np.random.randn(num_vis, num_hid_mean), dtype=np.float32, order='F') ) # VxH
    bias_mean = cmt.CUDAMatrix( np.array( -2.0*np.ones((num_hid_mean,1)), dtype=np.float32, order='F') )

    # initialize variables to store derivatives 
    VFinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_fac)), dtype=np.float32, order='F'))
    FHinc = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
    bias_covinc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_cov, 1)), dtype=np.float32, order='F'))
    bias_visinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F'))
    w_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_hid_mean)), dtype=np.float32, order='F'))
    bias_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_mean, 1)), dtype=np.float32, order='F'))

    # initialize temporary storage
    data = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    normdata = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    negdataini = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    feat = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
    featsq = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
    negdata = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
    old_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
    new_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
    gradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    normgradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
    thresh = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
    feat_mean = cmt.CUDAMatrix( np.array(np.empty((num_hid_mean, batch_size)), dtype=np.float32, order='F'))
    vel = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
    length = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
    lengthsq = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
    normcoeff = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
    if apply_mask==1: # this used to constrain very large FH matrices only allowing to change values in a neighborhood
        dd = loadmat('your_FHinit_mask_file.mat') 
        mask = cmt.CUDAMatrix( np.array(dd["mask"], dtype=np.float32, order='F'))
    normVF = 1    
    small = 0.5
    
    # other temporary vars
    t1 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
    t2 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
    t3 = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
    t4 = cmt.CUDAMatrix( np.array(np.empty((1,batch_size)), dtype=np.float32, order='F'))
    t5 = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F'))
    t6 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
    t7 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
    t8 = cmt.CUDAMatrix( np.array(np.empty((num_vis, num_fac)), dtype=np.float32, order='F'))
    t9 = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
    t10 = cmt.CUDAMatrix( np.array(np.empty((1,num_fac)), dtype=np.float32, order='F'))
    t11 = cmt.CUDAMatrix( np.array(np.empty((1,num_hid_cov)), dtype=np.float32, order='F'))

    # start training
    for epoch in range(num_epochs):

        print "Epoch " + str(epoch + 1)
    
        # anneal learning rates
        epsilonVFc    = epsilonVF/max(1,epoch/20)
        epsilonFHc    = epsilonFH/max(1,epoch/20)
        epsilonbc    = epsilonb/max(1,epoch/20)
        epsilonw_meanc = epsilonw_mean/max(1,epoch/20)
        epsilonb_meanc = epsilonb_mean/max(1,epoch/20)
        weightcost = weightcost_final

        if epoch <= startFH:
            epsilonFHc = 0 
        if epoch <= startwd:	
            weightcost = 0

        for batch in range(num_batches):

            # get current minibatch
            data = dev_dat.slice(batch*batch_size,(batch + 1)*batch_size) # DxP (nr dims x nr samples)

            # normalize input data
            data.mult(data, target = t6) # DxP
            t6.sum(axis = 0, target = lengthsq) # 1xP
            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
            lengthsq.add(small) # small avoids division by 0
            cmt.sqrt(lengthsq, target = length)
            length.reciprocal(target = normcoeff) # 1xP
            data.mult_by_row(normcoeff, target = normdata) # normalized data 
            ## compute positive sample derivatives
            # covariance part
            cmt.dot(VF.T, normdata, target = feat) # HxP (nr facs x nr samples)
            feat.mult(feat, target = featsq)   # HxP
            cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
            t1.mult(-0.5)
            t1.add_col_vec(bias_cov) # OxP
            t1.apply_sigmoid(target = t2) # OxP
            cmt.dot(featsq, t2.T, target = FHinc) # HxO
            cmt.dot(FH,t2, target = t3) # HxP
            t3.mult(feat)
            cmt.dot(normdata, t3.T, target = VFinc) # VxH
            t2.sum(axis = 1, target = bias_covinc)
            bias_covinc.mult(-1)  
            # visible bias
            data.sum(axis = 1, target = bias_visinc)
            bias_visinc.mult(-1)
            # mean part
            cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
            feat_mean.add_col_vec(bias_mean) # HxP
            feat_mean.apply_sigmoid() # HxP
            feat_mean.mult(-1)
            cmt.dot(data, feat_mean.T, target = w_meaninc)
            feat_mean.sum(axis = 1, target = bias_meaninc)
            
            # HMC sampling: draw an approximate sample from the model
            if doPCD == 0: # CD-1 (set negative data to current training samples)
                hmc_step, hmc_ave_rej = draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
            else: # PCD-1 (use previous negative data as starting point for chain)
                negdataini.assign(negdata)
                hmc_step, hmc_ave_rej = draw_HMC_samples(negdataini,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
                
            # compute derivatives at the negative samples
            # normalize input data
            negdata.mult(negdata, target = t6) # DxP
            t6.sum(axis = 0, target = lengthsq) # 1xP
            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
            lengthsq.add(small)
            cmt.sqrt(lengthsq, target = length)
            length.reciprocal(target = normcoeff) # 1xP
            negdata.mult_by_row(normcoeff, target = normdata) # normalized data 
            # covariance part
            cmt.dot(VF.T, normdata, target = feat) # HxP 
            feat.mult(feat, target = featsq)   # HxP
            cmt.dot(FH.T,featsq, target = t1) # OxP
            t1.mult(-0.5)
            t1.add_col_vec(bias_cov) # OxP
            t1.apply_sigmoid(target = t2) # OxP
            FHinc.subtract_dot(featsq, t2.T) # HxO
            FHinc.mult(0.5)
            cmt.dot(FH,t2, target = t3) # HxP
            t3.mult(feat)
            VFinc.subtract_dot(normdata, t3.T) # VxH
            bias_covinc.add_sums(t2, axis = 1)
            # visible bias
            bias_visinc.add_sums(negdata, axis = 1)
            # mean part
            cmt.dot(w_mean.T, negdata, target = feat_mean) # HxP 
            feat_mean.add_col_vec(bias_mean) # HxP
            feat_mean.apply_sigmoid() # HxP
            w_meaninc.add_dot(negdata, feat_mean.T)
            bias_meaninc.add_sums(feat_mean, axis = 1)

            # update parameters
            VFinc.add_mult(VF.sign(), weightcost) # L1 regularization
            VF.add_mult(VFinc, -epsilonVFc/batch_size)
            # normalize columns of VF: normalize by running average of their norm 
            VF.mult(VF, target = t8)
            t8.sum(axis = 0, target = t10)
            cmt.sqrt(t10)
            t10.sum(axis=1,target = t5)
            t5.copy_to_host()
            normVF = .95*normVF + (.05/num_fac) * t5.numpy_array[0,0] # estimate norm
            t10.reciprocal()
            VF.mult_by_row(t10) 
            VF.mult(normVF) 
            bias_cov.add_mult(bias_covinc, -epsilonbc/batch_size)
            bias_vis.add_mult(bias_visinc, -epsilonbc/batch_size)

            if epoch > startFH:
                FHinc.add_mult(FH.sign(), weightcost) # L1 regularization
       		FH.add_mult(FHinc, -epsilonFHc/batch_size) # update
	        # set to 0 negative entries in FH
        	FH.greater_than(0, target = t9)
	        FH.mult(t9)
                if apply_mask==1:
                    FH.mult(mask)
		# normalize columns of FH: L1 norm set to 1 in each column
		FH.sum(axis = 0, target = t11)               
		t11.reciprocal()
		FH.mult_by_row(t11) 
            w_meaninc.add_mult(w_mean.sign(),weightcost)
            w_mean.add_mult(w_meaninc, -epsilonw_meanc/batch_size)
            bias_mean.add_mult(bias_meaninc, -epsilonb_meanc/batch_size)

        if verbose == 1:
            print "VF: " + '%3.2e' % VF.euclid_norm() + ", DVF: " + '%3.2e' % (VFinc.euclid_norm()*(epsilonVFc/batch_size)) + ", FH: " + '%3.2e' % FH.euclid_norm() + ", DFH: " + '%3.2e' % (FHinc.euclid_norm()*(epsilonFHc/batch_size)) + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm() + ", Dbias_cov: " + '%3.2e' % (bias_covinc.euclid_norm()*(epsilonbc/batch_size)) + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm() + ", Dbias_vis: " + '%3.2e' % (bias_visinc.euclid_norm()*(epsilonbc/batch_size)) + ", wm: " + '%3.2e' % w_mean.euclid_norm() + ", Dwm: " + '%3.2e' % (w_meaninc.euclid_norm()*(epsilonw_meanc/batch_size)) + ", bm: " + '%3.2e' % bias_mean.euclid_norm() + ", Dbm: " + '%3.2e' % (bias_meaninc.euclid_norm()*(epsilonb_meanc/batch_size)) + ", step: " + '%3.2e' % hmc_step  +  ", rej: " + '%3.2e' % hmc_ave_rej 
            sys.stdout.flush()
        # back-up every once in a while 
        if np.mod(epoch,10) == 0:
            VF.copy_to_host()
            FH.copy_to_host()
            bias_cov.copy_to_host()
            w_mean.copy_to_host()
            bias_mean.copy_to_host()
            bias_vis.copy_to_host()
            savemat("ws_temp", {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array,'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})    
    # final back-up
    VF.copy_to_host()
    FH.copy_to_host()
    bias_cov.copy_to_host()
    bias_vis.copy_to_host()
    w_mean.copy_to_host()
    bias_mean.copy_to_host()
    savemat("ws_fac" + str(num_fac) + "_cov" + str(num_hid_cov) + "_mean" + str(num_hid_mean), {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array, 'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})