changeset 967:90e11d5d0a41

adding algorithms/mcRBM, but it is not done yet
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 20 Aug 2010 13:58:56 -0400
parents e88d7b7d53ed
children c96dc085b5b7
files pylearn/algorithms/mcRBM.py
diffstat 1 files changed, 657 insertions(+), 0 deletions(-)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/mcRBM.py	Fri Aug 20 13:58:56 2010 -0400
@@ -0,0 +1,657 @@
+"""
+This file implements the Mean & Covariance RBM discussed in 
+
+    Ranzato, M. and Hinton, G. E. (2010)
+    Modeling pixel means and covariances using factored third-order Boltzmann machines.
+    IEEE Conference on Computer Vision and Pattern Recognition.
+
+and performs one of the experiments on CIFAR-10 discussed in that paper.
+
+
+Math
+====
+
+Energy of "covariance RBM"
+
+    E = -0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i C_{if} v_i )^2
+      = -0.5 \sum_f (\sum_k P_{fk} h_k) ( \sum_i C_{if} v_i )^2
+                    "vector element f"           "vector element f"
+
+In some parts of the paper, the P matrix is chosen to be a diagonal matrix with non-positive
+diagonal entries (e.g. minus the identity), in which case it is helpful to see this as the
+simpler equation:
+
+    E = 0.5 \sum_f h_f ( \sum_i C_{if} v_i )^2
+
+
+
+Full energy of the mean and covariance RBM, with
+:math:`h_k = h_k^{(c)}`,
+:math:`g_j = h_j^{(m)}`,
+:math:`b_k = b_k^{(c)}`,
+:math:`c_j = b_j^{(m)}`,
+:math:`U_{if} = C_{if}`:
+
+    E (v, h, g) =
+        - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
+        - \sum_k b_k h_k
+        + 0.5 \sum_i v_i^2
+        - \sum_j \sum_i W_{ij} g_j v_i
+        - \sum_j c_j g_j
+        - \sum_i a_i v_i
+
+(The last term uses the visible bias `a` that appears in the code below.)
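+
+As a hand-check only, a small numpy transcription of this energy with made-up toy shapes
+(not part of the implementation below)::
+
+    import numpy as np
+    rng = np.random.RandomState(0)
+    I, F, K, J = 6, 4, 4, 3
+    v = rng.randn(I)
+    h, g = rng.randint(0, 2, K), rng.randint(0, 2, J)    # binary hidden states
+    U, P, W = rng.randn(I, F), -np.eye(F, K), rng.randn(I, J)
+    a, b, c = rng.randn(I), rng.randn(K), rng.randn(J)
+    cos2 = U.T.dot(v)**2 / ((U**2).sum(axis=0) * (v**2).sum())  # squared cosines, length F
+    E = (- 0.5 * np.dot(np.dot(cos2, P), h)   # covariance term
+         - np.dot(b, h)                       # hidden covariance biases
+         + 0.5 * np.sum(v**2)                 # quadratic visible term
+         - np.dot(np.dot(v, W), g)            # mean term
+         - np.dot(c, g)                       # hidden mean biases
+         - np.dot(a, v))                      # visible biases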
+
+
+Conventions in this file
+========================
+
+This file contains some global functions, as well as a class (MeanCovRBM) that makes using
+them a little more convenient.
+
+
+Global functions like `free_energy_given_v` work on an mcRBM parametrized in a particular
+way.  Suppose we have
+I input dimensions,
+F squared filters,
+J mean variables, and
+K covariance variables.
+The mcRBM is parametrized by the following variables:
+
+ - `P`, a pooling matrix, probably sparse (F x K)
+ - `U`, a matrix whose columns are visible covariance directions (I x F)
+ - `W`, a matrix whose columns are visible mean directions (I x J)
+ - `a`, a vector of visible biases (I)
+ - `b`, a vector of hidden covariance biases (K)
+ - `c`, a vector of hidden mean biases (J)
+
+Matrices are generally laid out according to a C-order convention.
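+
+For example, a toy-sized model with these shapes can be built with the `MeanCovRBM` class
+defined below (a sketch only; the dimensions are made up, and `P` is not held by the class
+yet)::
+
+    I, F, J, K = 64, 256, 100, 256
+    rbm = MeanCovRBM.new_from_dims(n_I=I, n_F=F, n_J=J, n_K=K)
+    # rbm.U: I x F,   rbm.W: I x J,   rbm.a: I,   rbm.b: K,   rbm.c: J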
+
+"""
+
+# Free energy is the marginal energy of the visible units
+# Recall:
+#   Q(v) = exp(-E(v))/Z  ==>  -log(Q(v)) - log(Z) = E(v)
+#
+# Derivation, in which partition functions are ignored.
+#
+# E(v) = -\log(Q(v))
+#  = -\log( \sum_{h,g} Q(v,h,g))
+#  = -\log( \sum_{h,g} exp(-E(v,h,g)))
+#  = -\log( \sum_{h,g} exp(
+#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
+#       + \sum_k b_k h_k
+#       - 0.5 \sum_i v_i^2
+#       + \sum_j \sum_i W_{ij} g_j v_i
+#       + \sum_j c_j g_j
+#       + \sum_i a_i v_i ))
+#  = -\log( \sum_{h} exp(
+#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
+#       + \sum_k b_k h_k )
+#     * \sum_{g} exp(
+#       + \sum_j \sum_i W_{ij} g_j v_i
+#       + \sum_j c_j g_j ))
+#    + 0.5 \sum_i v_i^2
+#    - \sum_i a_i v_i
+#  = -\log( \sum_{h} exp(
+#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
+#       + \sum_k b_k h_k ))
+#    - \log( \sum_{g} exp(
+#       + \sum_j \sum_i W_{ij} g_j v_i
+#       + \sum_j c_j g_j ))
+#    + 0.5 \sum_i v_i^2
+#    - \sum_i a_i v_i
+#  = - \sum_{k} \log(1 + exp(b_k + 0.5 \sum_f P_{fk} ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)))
+#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j))
+#    + 0.5 \sum_i v_i^2
+#    - \sum_i a_i v_i
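+#
+# As a hand-check against `free_energy_given_v` below, a numpy transcription of the last
+# expression, taking P to be minus the identity (which is what
+# hidden_cov_units_preactivation_given_v assumes) and U to have unit-norm columns:
+#
+#   unit_v = v / np.sqrt((v ** 2).sum())
+#   cos2 = np.dot(unit_v, U) ** 2
+#   F = (- np.sum(np.log1p(np.exp(b - 0.5 * cos2)))
+#        - np.sum(np.log1p(np.exp(np.dot(v, W) + c)))
+#        + 0.5 * np.sum(v ** 2)
+#        - np.dot(v, a))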
+
+import sys
+import logging
+import numpy as np
+from theano import function, shared, dot
+from theano import tensor as TT
+import theano.sparse #installs the sparse shared var handler
+floatX = theano.config.floatX
+
+from pylearn.sampling.hmc import HMC_sampler
+from pylearn.io import image_tiling
+
+from sparse_coding import numpy_project_onto_ball
+
+#TODO: This should be in the nnet part of the library
+def sgd_updates(params, grads, lr):
+    """Return a list of (param, param + lr * grad) pairs, suitable as a Theano `updates` argument.
+
+    `lr` may be a single scalar (used for every parameter) or a sequence giving one learning
+    rate per parameter.  Note the `+`: callers are expected to pass gradients that already
+    point in the direction the parameters should move (as `contrastive_gradient` does).
+    """
+    try:
+        float(lr)
+        lr = [lr for p in params]
+    except TypeError:
+        pass
+    updates = [(p, p + plr * gp) for (plr, p, gp) in zip(lr, params, grads)]
+    return updates
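+
+# A usage sketch for sgd_updates (hypothetical variables, not part of the model below):
+#
+#   p = shared(np.zeros(3, dtype=floatX), name='p')
+#   g = -TT.grad(((p - 1) ** 2).sum(), p)            # negative gradient: descent direction
+#   step_fn = function([], [], updates=sgd_updates([p], [g], lr=0.1))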
+
+def as_shared(x, name=None, dtype=floatX):
+    """Return `x` unchanged if it is already a Theano variable, otherwise wrap it in a shared
+    variable (casting float arrays to `dtype`)."""
+    if hasattr(x, 'type'):
+        return x
+    else:
+        if 'float' in str(x.dtype):
+            return shared(x.astype(dtype), name=name)
+        else:
+            return shared(x, name=name)
+
+def hidden_cov_units_preactivation_given_v(rbm, v, small=1e-8):
+    """Return the pre-sigmoid activation of the covariance hidden units, given visible rows `v`.
+
+    Assumes the columns of `U` already have unit norm, and an implicit pooling matrix P equal
+    to minus the identity.
+    """
+    (U,W,a,b,c) = rbm
+    unit_v = v / (TT.sqrt(TT.sum(v**2, axis=1))+small).dimshuffle(0,'x') # unit rows
+    unit_U = U  # assuming unit cols!
+    #unit_U = U / (TT.sqrt(TT.sum(U**2, axis=0))+small)  #unit cols
+    return b - 0.5 * dot(unit_v, unit_U)**2
+
+def free_energy_given_v(rbm, v):
+    """Returns theano expression for free energy of visible vector `v` in an mcRBM
+    
+    An mcRBM is parametrized
+    by `U`, `W`, `b`, `c`.
+    See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters.
+
+
+    The free energy of v is what we need for learning and hybrid Monte-carlo negative-phase
+    sampling.
+
+    """
+    U, W, a, b, c = rbm
+
+    t0 = -TT.sum(TT.log1p(TT.exp(hidden_cov_units_preactivation_given_v(rbm, v))),axis=1)
+    t1 = -TT.sum(TT.log1p(TT.exp(c + dot(v,W))), axis=1)
+    t2 = 0.5 * TT.sum(v**2, axis=1)
+    t3 = -TT.dot(v, a)
+    return t0 + t1 + t2 + t3
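+
+# A usage sketch (assuming `rbm` is a MeanCovRBM instance as constructed further below):
+#
+#   v = TT.matrix('v')
+#   fe_fn = function([v], free_energy_given_v(rbm.params, v))
+#   fe = fe_fn(batch)   # one free-energy value per row of `batch`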
+
+def expected_h_g_given_v(P, U, W, b, c, v):
+    """Returns theano expression conditional expectations (`h`, `g`) in an mcRBM.
+    
+    An mcRBM is parametrized
+    by `U`, `W`, `b`, `c`.
+    See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters.
+
+
+    The conditional E[h, g | v] is what we need to classify images.
+    """
+    raise NotImplementedError()
+
+    #TODO: check to see if these args should be negated?
+
+    if P is None:
+        h = nnet.sigmoid(b + 0.5 * cosines(v,U))
+    else:
+        h = nnet.sigmoid(b + 0.5 * dot(cosines(v,U), P))
+    g = nnet.sigmoid(c + dot(v,W))
+    return (h, g)
+
+class MeanCovRBM(object):
+    """Container for mcRBM parameters that gives more convenient access to mcRBM methods.
+    """
+
+    params = property(lambda s: [s.U, s.W, s.a, s.b, s.c])
+
+    n_visible = property(lambda s: s.W.value.shape[0])
+
+    def __init__(self, U, W, a, b, c):
+        self.U = as_shared(U, 'U')
+        self.W = as_shared(W, 'W')
+        self.a = as_shared(a, 'a')
+        self.b = as_shared(b, 'b')
+        self.c = as_shared(c, 'c')
+
+        assert self.b.type.dtype == 'float32'
+
+    @classmethod
+    def new_from_dims(cls, 
+            n_I,  # input dimensionality
+            n_K,  # number of covariance hidden units
+            n_F,  # number of covariance filters (squared)
+            n_J,  # number of mean filters (linear)
+            seed = 8923402190,
+            ):
+        """
+        Return a MeanCovRBM instance with randomly-initialized parameters.
+        """
+
+        
+        if 0:  # dead code: sketch of P initialization (`P_init` is not a parameter yet)
+            if P_init == 'diag':
+                if n_K != n_F:
+                    raise ValueError('cannot use diagonal initialization of non-square P matrix')
+                import scipy.sparse
+                P =  -scipy.sparse.identity(n_K).tocsr()
+            else:
+                raise NotImplementedError()
+
+        rng = np.random.RandomState(seed)
+
+        # initialization taken from Marc'Aurelio
+
+        return cls(
+                U = numpy_project_onto_ball(rng.randn(n_I, n_F).T).T,
+                W = rng.randn(n_I, n_J)/np.sqrt((n_I+n_J)/2),
+                a = np.ones(n_I)*(-2),
+                b = np.ones(n_K)*2,
+                c = np.zeros(n_J),)
+
+    def __getstate__(self):
+        # unpack shared containers, which may have references to Theano stuff
+        # and are not a long-term stable data type.
+        return dict(
+                U = self.U.value,
+                W = self.W.value,
+                a = self.a.value,
+                b = self.b.value,
+                c = self.c.value)
+
+    def __setstate__(self, dct):
+        self.__init__(**dct) # calls as_shared on pickled arrays
+
+    def hmc_sampler(self, n_particles=100, seed=7823748):
+        """Return an HMC_sampler whose particles are negative-phase samples of the visible units."""
+        return HMC_sampler(
+            positions = [as_shared(
+                np.random.RandomState(seed^20893).rand(
+                    n_particles,
+                    self.n_visible ))],
+            energy_fn = lambda p : self.free_energy_given_v(p[0]),
+            seed=seed)
+
+    def free_energy_given_v(self, v):
+        return free_energy_given_v(self.params, v)
+
+    def contrastive_gradient(self, pos_v, neg_v):
+        """Return a list of gradient expressions for self.params
+
+        :param pos_v: positive-phase sample of visible units
+        :param neg_v: negative-phase sample of visible units
+
+        The returned expressions are grad(neg_FE) - grad(pos_FE), i.e. they already point in
+        the direction the parameters should move, so they can be fed directly to `sgd_updates`.
+        """
+        pos_FE = self.free_energy_given_v(pos_v)
+        neg_FE = self.free_energy_given_v(neg_v)
+
+        gpos_FE = theano.tensor.grad(pos_FE.sum(), self.params)
+        gneg_FE = theano.tensor.grad(neg_FE.sum(), self.params)
+        return [ gn - gp for (gp,gn) in zip(gpos_FE, gneg_FE)]
+
+if __name__ == '__main__':
+
+    print >> sys.stderr, "TODO: use P matrix (aka FH matrix)"
+
+    R,C= 8,8 # the size of image patches
+    l1_penalty=1e-3
+    no_l1_epochs = 10
+
+    epoch_size=50000
+    batchsize = 128
+    lr = 0.075 / batchsize
+    s_lr = TT.scalar()
+    n_K=256
+    n_F=256
+    n_J=100
+
+    rbm = MeanCovRBM.new_from_dims(n_I=R*C,
+            n_K=n_K,
+            n_J=n_J, 
+            n_F=n_F,
+            ) 
+
+    sampler = rbm.hmc_sampler(n_particles=100)
+
+    from pylearn.dataset_ops import image_patches
+
+    batch_idx = TT.iscalar()
+    train_batch = image_patches.image_patches(
+            s_idx = (batch_idx * batchsize + np.arange(batchsize)),
+            dims = (1000,R,C),
+            dtype=floatX,
+            rasterized=True)
+
+    grads = rbm.contrastive_gradient(pos_v=train_batch, neg_v=sampler.positions[0])
+
+    learn_fn = function([batch_idx, s_lr], 
+            outputs=[ 
+                grads[0].norm(2),
+                rbm.U.norm(2)
+                ],
+            updates = sgd_updates(
+                rbm.params,
+                grads,
+                lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ]))
+
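+    # Training loop: each iteration runs one HMC negative-phase step followed by one SGD step
+    # on the contrastive gradient; after `no_l1_epochs` epochs U and W are shrunk towards zero
+    # (L1 penalty), and every 5 steps the columns of U are re-projected with
+    # numpy_project_onto_ball.  Samples and filters are periodically saved as PNG tiles.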
+    for jj in xrange(10000):
+        sampler.simulate()
+        l2_of_Ugrad = learn_fn(jj, lr/max(1, jj/(20*epoch_size/batchsize)))
+
+        if jj > no_l1_epochs * epoch_size/batchsize:
+            rbm.U.value -= l1_penalty * np.sign(rbm.U.value)
+            rbm.W.value -= l1_penalty * np.sign(rbm.W.value)
+
+        if jj % 5 == 0:
+            rbm.U.value = numpy_project_onto_ball(rbm.U.value.T).T
+
+        if ((jj < 10) 
+                or (jj < 100 and 0==jj%10) 
+                or (jj < 1000 and 0==jj%100)
+                or (jj < 10000 and 0==jj%1000)):
+            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize), l2_of_Ugrad
+            print 'neg particles', sampler.positions[0].value.min(), sampler.positions[0].value.max()
+            image_tiling.save_tiled_raster_images(
+                image_tiling.tile_raster_images(sampler.positions[0].value, (R,C)),
+                "sample_%06i.png"%jj)
+            image_tiling.save_tiled_raster_images(
+                image_tiling.tile_raster_images(rbm.U.value.T, (R,C)),
+                "U_%06i.png"%jj)
+            image_tiling.save_tiled_raster_images(
+                image_tiling.tile_raster_images(rbm.W.value.T, (R,C)),
+                "W_%06i.png"%jj)
+
+
+
+#
+#
+# Marc'Aurelio Ranzato's original cudamat-based code, kept below for reference.
+# It relies on names that are not imported in this file (cudamat as `cmt`,
+# loadmat/savemat, ConfigParser, floor, and a draw_HMC_samples helper).
+#
+######################################################################
+# compute the value of the free energy at a given input
+# F = - sum log(1+exp(- .5 FH (VF data/norm(data))^2 + bias_cov)) +...
+#     - sum log(1+exp(w_mean data + bias_mean)) + ...
+#     - bias_vis data + 0.5 data^2
+# NOTE: FH is constrained to be positive 
+# (in the paper the sign is negative but the sign in front of it is also flipped)
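+# Rough correspondence with the Theano code above: VF ~ U (covariance filters),
+# FH ~ the pooling matrix P (opposite sign convention, see the note above), w_mean ~ W,
+# bias_cov ~ b, bias_mean ~ c, bias_vis ~ a.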
+def compute_energy_mcRBM(data,normdata,vel,energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis):
+    # normalize input data vectors
+    data.mult(data, target = t6) # DxP (nr input dims x nr samples)
+    t6.sum(axis = 0, target = lengthsq) # 1xP
+    lengthsq.mult(0.5, target = energy) # energy of quadratic regularization term   
+    lengthsq.mult(1./num_vis) # normalize by number of components (like std)    
+
+    lengthsq.add(small) # small prevents division by 0
+    # energy_j = \sum_i 0.5 data_ij ^2
+    # lengthsq_j = (\sum_i data_ij ^2) / num_vis + small
+    cmt.sqrt(lengthsq, target = length) 
+    # length_j = sqrt(lengthsq_j)
+    length.reciprocal(target = normcoeff) # 1xP
+    # normcoef_j = 1/sqrt(lengthsq_j)
+    data.mult_by_row(normcoeff, target = normdata) # normalized data    
+    # normdata is like data, but cols have unit L2 norm
+
+    ## potential
+    # covariance contribution
+    cmt.dot(VF.T, normdata, target = feat) # HxP (nr factors x nr samples)
+    feat.mult(feat, target = featsq)   # HxP
+
+    # featsq is the squared cosines (VF with data)
+    cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
+    t1.mult(-0.5)
+    t1.add_col_vec(bias_cov) # OxP
+    cmt.exp(t1) # OxP
+    t1.add(1, target = t2) # OxP
+    cmt.log(t2)
+    t2.mult(-1)
+    energy.add_sums(t2, axis=0)
+    # mean contribution
+    cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
+    feat_mean.add_col_vec(bias_mean) # HxP
+    cmt.exp(feat_mean) 
+    feat_mean.add(1)
+    cmt.log(feat_mean)
+    feat_mean.mult(-1)
+    energy.add_sums(feat_mean,  axis=0)
+    # visible bias term
+    data.mult_by_col(bias_vis, target = t6)
+    t6.mult(-1) # DxP
+    energy.add_sums(t6,  axis=0) # 1xP
+    # kinetic
+    vel.mult(vel, target = t6)
+    energy.add_sums(t6, axis = 0, mult = .5)
+
+######################################################
+# mcRBM trainer: sweeps over the training set.
+# For each batch of samples compute derivatives to update the parameters
+# at the training samples and at the negative samples drawn calling HMC sampler.
+def train_mcRBM():
+
+    config = ConfigParser()
+    config.read('input_configuration')
+
+    verbose = config.getint('VERBOSITY','verbose')
+
+    num_epochs = config.getint('MAIN_PARAMETER_SETTING','num_epochs')
+    batch_size = config.getint('MAIN_PARAMETER_SETTING','batch_size')
+    startFH = config.getint('MAIN_PARAMETER_SETTING','startFH')
+    startwd = config.getint('MAIN_PARAMETER_SETTING','startwd')
+    doPCD = config.getint('MAIN_PARAMETER_SETTING','doPCD')
+
+    # model parameters
+    num_fac = config.getint('MODEL_PARAMETER_SETTING','num_fac')
+    num_hid_cov =  config.getint('MODEL_PARAMETER_SETTING','num_hid_cov')
+    num_hid_mean =  config.getint('MODEL_PARAMETER_SETTING','num_hid_mean')
+    apply_mask =  config.getint('MODEL_PARAMETER_SETTING','apply_mask')
+
+    # load data
+    data_file_name =  config.get('DATA','data_file_name')
+    d = loadmat(data_file_name) # input in the format PxD (P vectorized samples with D dimensions)
+    totnumcases = d["whitendata"].shape[0]
+    d = d["whitendata"][0:floor(totnumcases/batch_size)*batch_size,:].copy() 
+    totnumcases = d.shape[0]
+    num_vis =  d.shape[1]
+    num_batches = int(totnumcases/batch_size)
+    dev_dat = cmt.CUDAMatrix(d.T) # VxP 
+
+    # training parameters
+    epsilon = config.getfloat('OPTIMIZER_PARAMETERS','epsilon')
+    epsilonVF = 2*epsilon
+    epsilonFH = 0.02*epsilon
+    epsilonb = 0.02*epsilon
+    epsilonw_mean = 0.2*epsilon
+    epsilonb_mean = 0.1*epsilon
+    weightcost_final =  config.getfloat('OPTIMIZER_PARAMETERS','weightcost_final')
+
+    # HMC setting
+    hmc_step_nr = config.getint('HMC_PARAMETERS','hmc_step_nr')
+    hmc_step =  0.01
+    hmc_target_ave_rej =  config.getfloat('HMC_PARAMETERS','hmc_target_ave_rej')
+    hmc_ave_rej =  hmc_target_ave_rej
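+
+    # For reference, a sketch of the 'input_configuration' file covering the keys read above
+    # (the values shown are illustrative placeholders, not the paper's settings):
+    #
+    #   [VERBOSITY]
+    #   verbose = 1
+    #   [MAIN_PARAMETER_SETTING]
+    #   num_epochs = 100
+    #   batch_size = 128
+    #   startFH = 10
+    #   startwd = 10
+    #   doPCD = 0
+    #   [MODEL_PARAMETER_SETTING]
+    #   num_fac = 256
+    #   num_hid_cov = 256
+    #   num_hid_mean = 100
+    #   apply_mask = 0
+    #   [DATA]
+    #   data_file_name = whitened_patches.mat
+    #   [OPTIMIZER_PARAMETERS]
+    #   epsilon = 0.075
+    #   weightcost_final = 0.001
+    #   [HMC_PARAMETERS]
+    #   hmc_step_nr = 20
+    #   hmc_target_ave_rej = 0.1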
+
+    # initialize weights
+    VF = cmt.CUDAMatrix(np.array(0.02 * np.random.randn(num_vis, num_fac), dtype=np.float32, order='F')) # VxH
+    if apply_mask == 0:
+        FH = cmt.CUDAMatrix( np.array( np.eye(num_fac,num_hid_cov), dtype=np.float32, order='F')  ) # HxO
+    else:
+        dd = loadmat('your_FHinit_mask_file.mat') # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example
+        FH = cmt.CUDAMatrix( np.array( dd["FH"], dtype=np.float32, order='F')  )
+    bias_cov = cmt.CUDAMatrix( np.array(2.0*np.ones((num_hid_cov, 1)), dtype=np.float32, order='F') )
+    bias_vis = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F') )
+    w_mean = cmt.CUDAMatrix( np.array( 0.05 * np.random.randn(num_vis, num_hid_mean), dtype=np.float32, order='F') ) # VxH
+    bias_mean = cmt.CUDAMatrix( np.array( -2.0*np.ones((num_hid_mean,1)), dtype=np.float32, order='F') )
+
+    # initialize variables to store derivatives 
+    VFinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_fac)), dtype=np.float32, order='F'))
+    FHinc = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
+    bias_covinc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_cov, 1)), dtype=np.float32, order='F'))
+    bias_visinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F'))
+    w_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_hid_mean)), dtype=np.float32, order='F'))
+    bias_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_mean, 1)), dtype=np.float32, order='F'))
+
+    # initialize temporary storage
+    data = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    normdata = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    negdataini = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    feat = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
+    featsq = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
+    negdata = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
+    old_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
+    new_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
+    gradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    normgradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP
+    thresh = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F'))
+    feat_mean = cmt.CUDAMatrix( np.array(np.empty((num_hid_mean, batch_size)), dtype=np.float32, order='F'))
+    vel = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F'))
+    length = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
+    lengthsq = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
+    normcoeff = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP
+    if apply_mask==1: # this used to constrain very large FH matrices only allowing to change values in a neighborhood
+        dd = loadmat('your_FHinit_mask_file.mat') 
+        mask = cmt.CUDAMatrix( np.array(dd["mask"], dtype=np.float32, order='F'))
+    normVF = 1    
+    small = 0.5
+    
+    # other temporary vars
+    t1 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
+    t2 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F'))
+    t3 = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F'))
+    t4 = cmt.CUDAMatrix( np.array(np.empty((1,batch_size)), dtype=np.float32, order='F'))
+    t5 = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F'))
+    t6 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
+    t7 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F'))
+    t8 = cmt.CUDAMatrix( np.array(np.empty((num_vis, num_fac)), dtype=np.float32, order='F'))
+    t9 = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F'))
+    t10 = cmt.CUDAMatrix( np.array(np.empty((1,num_fac)), dtype=np.float32, order='F'))
+    t11 = cmt.CUDAMatrix( np.array(np.empty((1,num_hid_cov)), dtype=np.float32, order='F'))
+
+    # start training
+    for epoch in range(num_epochs):
+
+        print "Epoch " + str(epoch + 1)
+    
+        # anneal learning rates
+        epsilonVFc    = epsilonVF/max(1,epoch/20)
+        epsilonFHc    = epsilonFH/max(1,epoch/20)
+        epsilonbc    = epsilonb/max(1,epoch/20)
+        epsilonw_meanc = epsilonw_mean/max(1,epoch/20)
+        epsilonb_meanc = epsilonb_mean/max(1,epoch/20)
+        weightcost = weightcost_final
+
+        if epoch <= startFH:
+            epsilonFHc = 0 
+        if epoch <= startwd:	
+            weightcost = 0
+
+        for batch in range(num_batches):
+
+            # get current minibatch
+            data = dev_dat.slice(batch*batch_size,(batch + 1)*batch_size) # DxP (nr dims x nr samples)
+
+            # normalize input data
+            data.mult(data, target = t6) # DxP
+            t6.sum(axis = 0, target = lengthsq) # 1xP
+            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
+            lengthsq.add(small) # small avoids division by 0
+            cmt.sqrt(lengthsq, target = length)
+            length.reciprocal(target = normcoeff) # 1xP
+            data.mult_by_row(normcoeff, target = normdata) # normalized data 
+            ## compute positive sample derivatives
+            # covariance part
+            cmt.dot(VF.T, normdata, target = feat) # HxP (nr facs x nr samples)
+            feat.mult(feat, target = featsq)   # HxP
+            cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples)
+            t1.mult(-0.5)
+            t1.add_col_vec(bias_cov) # OxP
+            t1.apply_sigmoid(target = t2) # OxP
+            cmt.dot(featsq, t2.T, target = FHinc) # HxO
+            cmt.dot(FH,t2, target = t3) # HxP
+            t3.mult(feat)
+            cmt.dot(normdata, t3.T, target = VFinc) # VxH
+            t2.sum(axis = 1, target = bias_covinc)
+            bias_covinc.mult(-1)  
+            # visible bias
+            data.sum(axis = 1, target = bias_visinc)
+            bias_visinc.mult(-1)
+            # mean part
+            cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples)
+            feat_mean.add_col_vec(bias_mean) # HxP
+            feat_mean.apply_sigmoid() # HxP
+            feat_mean.mult(-1)
+            cmt.dot(data, feat_mean.T, target = w_meaninc)
+            feat_mean.sum(axis = 1, target = bias_meaninc)
+            
+            # HMC sampling: draw an approximate sample from the model
+            if doPCD == 0: # CD-1 (set negative data to current training samples)
+                hmc_step, hmc_ave_rej = draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
+            else: # PCD-1 (use previous negative data as starting point for chain)
+                negdataini.assign(negdata)
+                hmc_step, hmc_ave_rej = draw_HMC_samples(negdataini,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis)
+                
+            # compute derivatives at the negative samples
+            # normalize input data
+            negdata.mult(negdata, target = t6) # DxP
+            t6.sum(axis = 0, target = lengthsq) # 1xP
+            lengthsq.mult(1./num_vis) # normalize by number of components (like std)
+            lengthsq.add(small)
+            cmt.sqrt(lengthsq, target = length)
+            length.reciprocal(target = normcoeff) # 1xP
+            negdata.mult_by_row(normcoeff, target = normdata) # normalized data 
+            # covariance part
+            cmt.dot(VF.T, normdata, target = feat) # HxP 
+            feat.mult(feat, target = featsq)   # HxP
+            cmt.dot(FH.T,featsq, target = t1) # OxP
+            t1.mult(-0.5)
+            t1.add_col_vec(bias_cov) # OxP
+            t1.apply_sigmoid(target = t2) # OxP
+            FHinc.subtract_dot(featsq, t2.T) # HxO
+            FHinc.mult(0.5)
+            cmt.dot(FH,t2, target = t3) # HxP
+            t3.mult(feat)
+            VFinc.subtract_dot(normdata, t3.T) # VxH
+            bias_covinc.add_sums(t2, axis = 1)
+            # visible bias
+            bias_visinc.add_sums(negdata, axis = 1)
+            # mean part
+            cmt.dot(w_mean.T, negdata, target = feat_mean) # HxP 
+            feat_mean.add_col_vec(bias_mean) # HxP
+            feat_mean.apply_sigmoid() # HxP
+            w_meaninc.add_dot(negdata, feat_mean.T)
+            bias_meaninc.add_sums(feat_mean, axis = 1)
+
+            # update parameters
+            VFinc.add_mult(VF.sign(), weightcost) # L1 regularization
+            VF.add_mult(VFinc, -epsilonVFc/batch_size)
+            # normalize columns of VF: normalize by running average of their norm 
+            VF.mult(VF, target = t8)
+            t8.sum(axis = 0, target = t10)
+            cmt.sqrt(t10)
+            t10.sum(axis=1,target = t5)
+            t5.copy_to_host()
+            normVF = .95*normVF + (.05/num_fac) * t5.numpy_array[0,0] # estimate norm
+            t10.reciprocal()
+            VF.mult_by_row(t10) 
+            VF.mult(normVF) 
+            bias_cov.add_mult(bias_covinc, -epsilonbc/batch_size)
+            bias_vis.add_mult(bias_visinc, -epsilonbc/batch_size)
+
+            if epoch > startFH:
+                FHinc.add_mult(FH.sign(), weightcost) # L1 regularization
+                FH.add_mult(FHinc, -epsilonFHc/batch_size) # update
+                # set to 0 negative entries in FH
+                FH.greater_than(0, target = t9)
+                FH.mult(t9)
+                if apply_mask==1:
+                    FH.mult(mask)
+                # normalize columns of FH: L1 norm set to 1 in each column
+                FH.sum(axis = 0, target = t11)
+                t11.reciprocal()
+                FH.mult_by_row(t11)
+            w_meaninc.add_mult(w_mean.sign(),weightcost)
+            w_mean.add_mult(w_meaninc, -epsilonw_meanc/batch_size)
+            bias_mean.add_mult(bias_meaninc, -epsilonb_meanc/batch_size)
+
+        if verbose == 1:
+            print "VF: " + '%3.2e' % VF.euclid_norm() + ", DVF: " + '%3.2e' % (VFinc.euclid_norm()*(epsilonVFc/batch_size)) + ", FH: " + '%3.2e' % FH.euclid_norm() + ", DFH: " + '%3.2e' % (FHinc.euclid_norm()*(epsilonFHc/batch_size)) + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm() + ", Dbias_cov: " + '%3.2e' % (bias_covinc.euclid_norm()*(epsilonbc/batch_size)) + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm() + ", Dbias_vis: " + '%3.2e' % (bias_visinc.euclid_norm()*(epsilonbc/batch_size)) + ", wm: " + '%3.2e' % w_mean.euclid_norm() + ", Dwm: " + '%3.2e' % (w_meaninc.euclid_norm()*(epsilonw_meanc/batch_size)) + ", bm: " + '%3.2e' % bias_mean.euclid_norm() + ", Dbm: " + '%3.2e' % (bias_meaninc.euclid_norm()*(epsilonb_meanc/batch_size)) + ", step: " + '%3.2e' % hmc_step  +  ", rej: " + '%3.2e' % hmc_ave_rej 
+            sys.stdout.flush()
+        # back-up every once in a while 
+        if np.mod(epoch,10) == 0:
+            VF.copy_to_host()
+            FH.copy_to_host()
+            bias_cov.copy_to_host()
+            w_mean.copy_to_host()
+            bias_mean.copy_to_host()
+            bias_vis.copy_to_host()
+            savemat("ws_temp", {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array,'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})    
+    # final back-up
+    VF.copy_to_host()
+    FH.copy_to_host()
+    bias_cov.copy_to_host()
+    bias_vis.copy_to_host()
+    w_mean.copy_to_host()
+    bias_mean.copy_to_host()
+    savemat("ws_fac" + str(num_fac) + "_cov" + str(num_hid_cov) + "_mean" + str(num_hid_mean), {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array, 'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})