# HG changeset patch # User James Bergstra # Date 1282327136 14400 # Node ID 90e11d5d0a4123ddd006897071a06e77b63ec293 # Parent e88d7b7d53edcd9b87ac49972cc97e42198b8dc1 adding algorithms/mcRBM, but it is not done yet diff -r e88d7b7d53ed -r 90e11d5d0a41 pylearn/algorithms/mcRBM.py --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/algorithms/mcRBM.py Fri Aug 20 13:58:56 2010 -0400 @@ -0,0 +1,657 @@ +""" +This file implements the Mean & Covariance RBM discussed in + + Ranzato, M. and Hinton, G. E. (2010) + Modeling pixel means and covariances using factored third-order Boltzmann machines. + IEEE Conference on Computer Vision and Pattern Recognition. + +and performs one of the experiments on CIFAR-10 discussed in that paper. + + +Math +==== + +Energy of "covariance RBM" + + E = -0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i C_{if} v_i )^2 + = -0.5 \sum_f (\sum_k P_{fk} h_k) ( \sum_i C_{if} v_i )^2 + "vector element f" "vector element f" + +In some parts of the paper, the P matrix is chosen to be a diagonal matrix with non-positive +diagonal entries, so it is helpful to see this as a simpler equation: + + E = \sum_f h_f ( \sum_i C_{if} v_i )^2 + + + +Full Energy of mean and Covariance RBM, with +:math:`h_k = h_k^{(c)}`, +:math:`g_j = h_j^{(m)}`, +:math:`b_k = b_k^{(c)}`, +:math:`c_j = b_j^{(m)}`, +:math:`U_{if} = C_{if}`, + +: + + E (v, h, g) = + - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / |U_{*f}|^2 |v|^2 + - \sum_k b_k h_k + + 0.5 \sum_i v_i^2 + - \sum_j \sum_i W_{ij} g_j v_i + - \sum_j c_j g_j + + +Conventions in this file +======================== + +This file contains some global functions, as well as a class (MeanCovRBM) that makes using them a little +more convenient. + + +Global functions like `free_energy` work on an mcRBM as parametrized in a particular way. +Suppose we have +I input dimensions, +F squared filters, +J mean variables, and +K covariance variables. +The mcRBM is parametrized by 5 variables: + + - `P`, a matrix (probably sparse) of pooling (F x K) + - `U`, a matrix whose rows are visible covariance directions (I x F) + - `W`, a matrix whose rows are visible mean directions (I x J) + - `b`, a vector of hidden covariance biases (K) + - `c`, a vector of hidden mean biases (J) + +Matrices are generally layed out according to a C-order convention. + +""" + +# Free energy is the marginal energy of visible units +# Recall: +# Q(x) = exp(-E(x))/Z ==> -log(Q(x)) - log(Z) = E(x) +# +# Derivation, in which partition functions are ignored. +# +# E(v) = -\log(Q(v)) +# = -\log( \sum_{h,g} Q(v,h,g)) +# = -\log( \sum_{h,g} exp(-E(v,h,g))) +# = -\log( \sum_{h,g} exp(- +# - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}| * |v|) +# - \sum_k b_k h_k +# + 0.5 \sum_i v_i^2 +# - \sum_j \sum_i W_{ij} g_j v_i +# - \sum_j c_j g_j )) +# = -\log( \sum_{h} exp( +# + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}| * |v|) +# + \sum_k b_k h_k +# - 0.5 \sum_i v_i^2 +# ) * \sum_{g} exp( +# + \sum_j \sum_i W_{ij} g_j v_i +# + \sum_j c_j g_j ))) +# = -\log( \sum_{h} exp( +# + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|*|v|) +# + \sum_k b_k h_k +# )) +# -\log( \sum_{g} exp( +# + \sum_j \sum_i W_{ij} g_j v_i +# + \sum_j c_j g_j ))) +# + 0.5 \sum_i v_i^2 +# = -\log(\sum_{h} exp( +# + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|* |v|) +# + \sum_k b_k h_k +# )) +# - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j )) +# + 0.5 \sum_i v_i^2 +# = - \sum_{k} \log(1 + exp(b_k + 0.5 \sum_f P_{fk}( \sum_i U_{if} v_i )^2 / (|U_{*f}|* # |v|))) +# - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j )) +# + 0.5 \sum_i v_i^2 + +import sys +import logging +import numpy as np +from theano import function, shared, dot +from theano import tensor as TT +import theano.sparse #installs the sparse shared var handler +floatX = theano.config.floatX + +from pylearn.sampling.hmc import HMC_sampler +from pylearn.io import image_tiling + +from sparse_coding import numpy_project_onto_ball + +#TODO: This should be in the nnet part of the library +def sgd_updates(params, grads, lr): + try: + float(lr) + lr = [lr for p in params] + except TypeError: + pass + updates = [(p, p + plr * gp) for (plr, p, gp) in zip(lr, params, grads)] + return updates + +def as_shared(x, name=None, dtype=floatX): + if hasattr(x, 'type'): + return x + else: + if 'float' in str(x.dtype): + return shared(x.astype(floatX), name=name) + else: + return shared(x, name=name) + +def hidden_cov_units_preactivation_given_v(rbm, v, small=1e-8): + (U,W,a,b,c) = rbm + unit_v = v / (TT.sqrt(TT.sum(v**2, axis=1))+small).dimshuffle(0,'x') # unit rows + unit_U = U # assuming unit cols! + #unit_U = U / (TT.sqrt(TT.sum(U**2, axis=0))+small) #unit cols + return b - 0.5 * dot(unit_v, unit_U)**2 + +def free_energy_given_v(rbm, v): + """Returns theano expression for free energy of visible vector `v` in an mcRBM + + An mcRBM is parametrized + by `U`, `W`, `b`, `c`. + See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters. + + + The free energy of v is what we need for learning and hybrid Monte-carlo negative-phase + sampling. + + """ + U, W, a, b, c = rbm + + t0 = -TT.sum(TT.log1p(TT.exp(hidden_cov_units_preactivation_given_v(rbm, v))),axis=1) + t1 = -TT.sum(TT.log1p(TT.exp(c + dot(v,W))), axis=1) + t2 = 0.5 * TT.sum(v**2, axis=1) + t3 = -TT.dot(v, a) + return t0 + t1 + t2 + t3 + +def expected_h_g_given_v(P, U, W, b, c, v): + """Returns theano expression conditional expectations (`h`, `g`) in an mcRBM. + + An mcRBM is parametrized + by `U`, `W`, `b`, `c`. + See module - level documentation for explanations of the `U`, `W`, `b` and `c` parameters. + + + The conditional E[h, g | v] is what we need to classify images. + """ + raise NotImplementedError() + + #TODO: check to see if these args should be negated? + + if P is None: + h = nnet.sigmoid(b + 0.5 * cosines(v,U)) + else: + h = nnet.sigmoid(b + 0.5 * dot(cosines(v,U), P)) + g = nnet.sigmoid(c + dot(v,W)) + return (h, g) + +class MeanCovRBM(object): + """Container for mcRBM parameters that gives more convenient access to mcRBM methods. + """ + + params = property(lambda s: [s.U, s.W, s.a, s.b, s.c]) + + n_visible = property(lambda s: s.W.value.shape[0]) + + def __init__(self, U, W, a, b, c): + self.U = as_shared(U, 'U') + self.W = as_shared(W, 'W') + self.a = as_shared(a, 'a') + self.b = as_shared(b, 'b') + self.c = as_shared(c, 'c') + + assert self.b.type.dtype == 'float32' + + @classmethod + def new_from_dims(cls, + n_I, # input dimensionality + n_K, # number of covariance hidden units + n_F, # number of covariance filters (squared) + n_J, # number of mean filters (linear) + seed = 8923402190, + ): + """ + Return a MeanCovRBM instance with randomly-initialized parameters. + """ + + + if 0: + if P_init == 'diag': + if n_K != n_F: + raise ValueError('cannot use diagonal initialization of non-square P matrix') + import scipy.sparse + P = -scipy.sparse.identity(n_K).tocsr() + else: + raise NotImplementedError() + + rng = np.random.RandomState(seed) + + # initialization taken from Marc'Aurelio + + return cls( + U = numpy_project_onto_ball(rng.randn(n_I, n_F).T).T, + W = rng.randn(n_I, n_J)/np.sqrt((n_I+n_J)/2), + a = np.ones(n_I)*(-2), + b = np.ones(n_K)*2, + c = np.zeros(n_J),) + + def __getstate__(self): + # unpack shared containers, which may have references to Theano stuff + # and are not a long-term stable data type. + return dict( + U = self.U.value, + W = self.W.value, + b = self.b.value, + c = self.c.value) + + def __setstate__(self, dct): + self.__init__(**dct) # calls as_shared on pickled arrays + + def hmc_sampler(self, n_particles=100, seed=7823748): + return HMC_sampler( + positions = [as_shared( + np.random.RandomState(seed^20893).rand( + n_particles, + self.n_visible ))], + energy_fn = lambda p : self.free_energy_given_v(p[0]), + seed=seed) + + def free_energy_given_v(self, v): + return free_energy_given_v(self.params, v) + + def contrastive_gradient(self, pos_v, neg_v): + """Return a list of gradient expressions for self.params + + :param pos_v: positive-phase sample of visible units + :param neg_v: negative-phase sample of visible units + """ + pos_FE = self.free_energy_given_v(pos_v) + neg_FE = self.free_energy_given_v(neg_v) + + gpos_FE = theano.tensor.grad(pos_FE.sum(), self.params) + gneg_FE = theano.tensor.grad(neg_FE.sum(), self.params) + return [ gn - gp for (gp,gn) in zip(gpos_FE, gneg_FE)] + +if __name__ == '__main__': + + print >> sys.stderr, "TODO: use P matrix (aka FH matrix)" + + R,C= 8,8 # the size of image patches + l1_penalty=1e-3 + no_l1_epochs = 10 + + epoch_size=50000 + batchsize = 128 + lr = 0.075 / batchsize + s_lr = TT.scalar() + n_K=256 + n_F=256 + n_J=100 + + rbm = MeanCovRBM.new_from_dims(n_I=R*C, + n_K=n_K, + n_J=n_J, + n_F=n_F, + ) + + sampler = rbm.hmc_sampler(n_particles=100) + + from pylearn.dataset_ops import image_patches + + batch_idx = TT.iscalar() + train_batch = image_patches.image_patches( + s_idx = (batch_idx * batchsize + np.arange(batchsize)), + dims = (1000,R,C), + dtype=floatX, + rasterized=True) + + grads = rbm.contrastive_gradient(pos_v=train_batch, neg_v=sampler.positions[0]) + + learn_fn = function([batch_idx, s_lr], + outputs=[ + grads[0].norm(2), + rbm.U.norm(2) + ], + updates = sgd_updates( + rbm.params, + grads, + lr=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ])) + + for jj in xrange(10000): + sampler.simulate() + l2_of_Ugrad = learn_fn(jj, lr/max(1, jj/(20*epoch_size/batchsize))) + + if jj > no_l1_epochs * epoch_size/batchsize: + rbm.U.value -= l1_penalty * np.sign(rbm.U.value) + rbm.W.value -= l1_penalty * np.sign(rbm.W.value) + + if jj % 5 == 0: + rbm.U.value = numpy_project_onto_ball(rbm.U.value.T).T + + if ((jj < 10) + or (jj < 100 and 0==jj%10) + or (jj < 1000 and 0==jj%100) + or (jj < 10000 and 0==jj%1000)): + print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize), l2_of_Ugrad + print 'neg particles', sampler.positions[0].value.min(), sampler.positions[0].value.max() + image_tiling.save_tiled_raster_images( + image_tiling.tile_raster_images(sampler.positions[0].value, (R,C)), + "sample_%06i.png"%jj) + image_tiling.save_tiled_raster_images( + image_tiling.tile_raster_images(rbm.U.value.T, (R,C)), + "U_%06i.png"%jj) + image_tiling.save_tiled_raster_images( + image_tiling.tile_raster_images(rbm.W.value.T, (R,C)), + "W_%06i.png"%jj) + + + +# +# +# Marc'Aurelio Ranzato's code +# +###################################################################### +# compute the value of the free energy at a given input +# F = - sum log(1+exp(- .5 FH (VF data/norm(data))^2 + bias_cov)) +... +# - sum log(1+exp(w_mean data + bias_mean)) + ... +# - bias_vis data + 0.5 data^2 +# NOTE: FH is constrained to be positive +# (in the paper the sign is negative but the sign in front of it is also flipped) +def compute_energy_mcRBM(data,normdata,vel,energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis): + # normalize input data vectors + data.mult(data, target = t6) # DxP (nr input dims x nr samples) + t6.sum(axis = 0, target = lengthsq) # 1xP + lengthsq.mult(0.5, target = energy) # energy of quadratic regularization term + lengthsq.mult(1./num_vis) # normalize by number of components (like std) + + lengthsq.add(small) # small prevents division by 0 + # energy_j = \sum_i 0.5 data_ij ^2 + # lengthsq_j = 1/ (\sum_i data_ij ^2 + small) + cmt.sqrt(lengthsq, target = length) + # length_j = sqrt(lengthsq_j) + length.reciprocal(target = normcoeff) # 1xP + # normcoef_j = 1/sqrt(lengthsq_j) + data.mult_by_row(normcoeff, target = normdata) # normalized data + # normdata is like data, but cols have unit L2 norm + + ## potential + # covariance contribution + cmt.dot(VF.T, normdata, target = feat) # HxP (nr factors x nr samples) + feat.mult(feat, target = featsq) # HxP + + # featsq is the squared cosines (VF with data) + cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples) + t1.mult(-0.5) + t1.add_col_vec(bias_cov) # OxP + cmt.exp(t1) # OxP + t1.add(1, target = t2) # OxP + cmt.log(t2) + t2.mult(-1) + energy.add_sums(t2, axis=0) + # mean contribution + cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples) + feat_mean.add_col_vec(bias_mean) # HxP + cmt.exp(feat_mean) + feat_mean.add(1) + cmt.log(feat_mean) + feat_mean.mult(-1) + energy.add_sums(feat_mean, axis=0) + # visible bias term + data.mult_by_col(bias_vis, target = t6) + t6.mult(-1) # DxP + energy.add_sums(t6, axis=0) # 1xP + # kinetic + vel.mult(vel, target = t6) + energy.add_sums(t6, axis = 0, mult = .5) + +###################################################### +# mcRBM trainer: sweeps over the training set. +# For each batch of samples compute derivatives to update the parameters +# at the training samples and at the negative samples drawn calling HMC sampler. +def train_mcRBM(): + + config = ConfigParser() + config.read('input_configuration') + + verbose = config.getint('VERBOSITY','verbose') + + num_epochs = config.getint('MAIN_PARAMETER_SETTING','num_epochs') + batch_size = config.getint('MAIN_PARAMETER_SETTING','batch_size') + startFH = config.getint('MAIN_PARAMETER_SETTING','startFH') + startwd = config.getint('MAIN_PARAMETER_SETTING','startwd') + doPCD = config.getint('MAIN_PARAMETER_SETTING','doPCD') + + # model parameters + num_fac = config.getint('MODEL_PARAMETER_SETTING','num_fac') + num_hid_cov = config.getint('MODEL_PARAMETER_SETTING','num_hid_cov') + num_hid_mean = config.getint('MODEL_PARAMETER_SETTING','num_hid_mean') + apply_mask = config.getint('MODEL_PARAMETER_SETTING','apply_mask') + + # load data + data_file_name = config.get('DATA','data_file_name') + d = loadmat(data_file_name) # input in the format PxD (P vectorized samples with D dimensions) + totnumcases = d["whitendata"].shape[0] + d = d["whitendata"][0:floor(totnumcases/batch_size)*batch_size,:].copy() + totnumcases = d.shape[0] + num_vis = d.shape[1] + num_batches = int(totnumcases/batch_size) + dev_dat = cmt.CUDAMatrix(d.T) # VxP + + # training parameters + epsilon = config.getfloat('OPTIMIZER_PARAMETERS','epsilon') + epsilonVF = 2*epsilon + epsilonFH = 0.02*epsilon + epsilonb = 0.02*epsilon + epsilonw_mean = 0.2*epsilon + epsilonb_mean = 0.1*epsilon + weightcost_final = config.getfloat('OPTIMIZER_PARAMETERS','weightcost_final') + + # HMC setting + hmc_step_nr = config.getint('HMC_PARAMETERS','hmc_step_nr') + hmc_step = 0.01 + hmc_target_ave_rej = config.getfloat('HMC_PARAMETERS','hmc_target_ave_rej') + hmc_ave_rej = hmc_target_ave_rej + + # initialize weights + VF = cmt.CUDAMatrix(np.array(0.02 * np.random.randn(num_vis, num_fac), dtype=np.float32, order='F')) # VxH + if apply_mask == 0: + FH = cmt.CUDAMatrix( np.array( np.eye(num_fac,num_hid_cov), dtype=np.float32, order='F') ) # HxO + else: + dd = loadmat('your_FHinit_mask_file.mat') # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example + FH = cmt.CUDAMatrix( np.array( dd["FH"], dtype=np.float32, order='F') ) + bias_cov = cmt.CUDAMatrix( np.array(2.0*np.ones((num_hid_cov, 1)), dtype=np.float32, order='F') ) + bias_vis = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F') ) + w_mean = cmt.CUDAMatrix( np.array( 0.05 * np.random.randn(num_vis, num_hid_mean), dtype=np.float32, order='F') ) # VxH + bias_mean = cmt.CUDAMatrix( np.array( -2.0*np.ones((num_hid_mean,1)), dtype=np.float32, order='F') ) + + # initialize variables to store derivatives + VFinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_fac)), dtype=np.float32, order='F')) + FHinc = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F')) + bias_covinc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_cov, 1)), dtype=np.float32, order='F')) + bias_visinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F')) + w_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_hid_mean)), dtype=np.float32, order='F')) + bias_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_mean, 1)), dtype=np.float32, order='F')) + + # initialize temporary storage + data = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + normdata = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + negdataini = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + feat = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F')) + featsq = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F')) + negdata = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F')) + old_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) + new_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) + gradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + normgradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + thresh = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) + feat_mean = cmt.CUDAMatrix( np.array(np.empty((num_hid_mean, batch_size)), dtype=np.float32, order='F')) + vel = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F')) + length = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP + lengthsq = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP + normcoeff = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP + if apply_mask==1: # this used to constrain very large FH matrices only allowing to change values in a neighborhood + dd = loadmat('your_FHinit_mask_file.mat') + mask = cmt.CUDAMatrix( np.array(dd["mask"], dtype=np.float32, order='F')) + normVF = 1 + small = 0.5 + + # other temporary vars + t1 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F')) + t2 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F')) + t3 = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F')) + t4 = cmt.CUDAMatrix( np.array(np.empty((1,batch_size)), dtype=np.float32, order='F')) + t5 = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F')) + t6 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) + t7 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) + t8 = cmt.CUDAMatrix( np.array(np.empty((num_vis, num_fac)), dtype=np.float32, order='F')) + t9 = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F')) + t10 = cmt.CUDAMatrix( np.array(np.empty((1,num_fac)), dtype=np.float32, order='F')) + t11 = cmt.CUDAMatrix( np.array(np.empty((1,num_hid_cov)), dtype=np.float32, order='F')) + + # start training + for epoch in range(num_epochs): + + print "Epoch " + str(epoch + 1) + + # anneal learning rates + epsilonVFc = epsilonVF/max(1,epoch/20) + epsilonFHc = epsilonFH/max(1,epoch/20) + epsilonbc = epsilonb/max(1,epoch/20) + epsilonw_meanc = epsilonw_mean/max(1,epoch/20) + epsilonb_meanc = epsilonb_mean/max(1,epoch/20) + weightcost = weightcost_final + + if epoch <= startFH: + epsilonFHc = 0 + if epoch <= startwd: + weightcost = 0 + + for batch in range(num_batches): + + # get current minibatch + data = dev_dat.slice(batch*batch_size,(batch + 1)*batch_size) # DxP (nr dims x nr samples) + + # normalize input data + data.mult(data, target = t6) # DxP + t6.sum(axis = 0, target = lengthsq) # 1xP + lengthsq.mult(1./num_vis) # normalize by number of components (like std) + lengthsq.add(small) # small avoids division by 0 + cmt.sqrt(lengthsq, target = length) + length.reciprocal(target = normcoeff) # 1xP + data.mult_by_row(normcoeff, target = normdata) # normalized data + ## compute positive sample derivatives + # covariance part + cmt.dot(VF.T, normdata, target = feat) # HxP (nr facs x nr samples) + feat.mult(feat, target = featsq) # HxP + cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples) + t1.mult(-0.5) + t1.add_col_vec(bias_cov) # OxP + t1.apply_sigmoid(target = t2) # OxP + cmt.dot(featsq, t2.T, target = FHinc) # HxO + cmt.dot(FH,t2, target = t3) # HxP + t3.mult(feat) + cmt.dot(normdata, t3.T, target = VFinc) # VxH + t2.sum(axis = 1, target = bias_covinc) + bias_covinc.mult(-1) + # visible bias + data.sum(axis = 1, target = bias_visinc) + bias_visinc.mult(-1) + # mean part + cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples) + feat_mean.add_col_vec(bias_mean) # HxP + feat_mean.apply_sigmoid() # HxP + feat_mean.mult(-1) + cmt.dot(data, feat_mean.T, target = w_meaninc) + feat_mean.sum(axis = 1, target = bias_meaninc) + + # HMC sampling: draw an approximate sample from the model + if doPCD == 0: # CD-1 (set negative data to current training samples) + hmc_step, hmc_ave_rej = draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis) + else: # PCD-1 (use previous negative data as starting point for chain) + negdataini.assign(negdata) + hmc_step, hmc_ave_rej = draw_HMC_samples(negdataini,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis) + + # compute derivatives at the negative samples + # normalize input data + negdata.mult(negdata, target = t6) # DxP + t6.sum(axis = 0, target = lengthsq) # 1xP + lengthsq.mult(1./num_vis) # normalize by number of components (like std) + lengthsq.add(small) + cmt.sqrt(lengthsq, target = length) + length.reciprocal(target = normcoeff) # 1xP + negdata.mult_by_row(normcoeff, target = normdata) # normalized data + # covariance part + cmt.dot(VF.T, normdata, target = feat) # HxP + feat.mult(feat, target = featsq) # HxP + cmt.dot(FH.T,featsq, target = t1) # OxP + t1.mult(-0.5) + t1.add_col_vec(bias_cov) # OxP + t1.apply_sigmoid(target = t2) # OxP + FHinc.subtract_dot(featsq, t2.T) # HxO + FHinc.mult(0.5) + cmt.dot(FH,t2, target = t3) # HxP + t3.mult(feat) + VFinc.subtract_dot(normdata, t3.T) # VxH + bias_covinc.add_sums(t2, axis = 1) + # visible bias + bias_visinc.add_sums(negdata, axis = 1) + # mean part + cmt.dot(w_mean.T, negdata, target = feat_mean) # HxP + feat_mean.add_col_vec(bias_mean) # HxP + feat_mean.apply_sigmoid() # HxP + w_meaninc.add_dot(negdata, feat_mean.T) + bias_meaninc.add_sums(feat_mean, axis = 1) + + # update parameters + VFinc.add_mult(VF.sign(), weightcost) # L1 regularization + VF.add_mult(VFinc, -epsilonVFc/batch_size) + # normalize columns of VF: normalize by running average of their norm + VF.mult(VF, target = t8) + t8.sum(axis = 0, target = t10) + cmt.sqrt(t10) + t10.sum(axis=1,target = t5) + t5.copy_to_host() + normVF = .95*normVF + (.05/num_fac) * t5.numpy_array[0,0] # estimate norm + t10.reciprocal() + VF.mult_by_row(t10) + VF.mult(normVF) + bias_cov.add_mult(bias_covinc, -epsilonbc/batch_size) + bias_vis.add_mult(bias_visinc, -epsilonbc/batch_size) + + if epoch > startFH: + FHinc.add_mult(FH.sign(), weightcost) # L1 regularization + FH.add_mult(FHinc, -epsilonFHc/batch_size) # update + # set to 0 negative entries in FH + FH.greater_than(0, target = t9) + FH.mult(t9) + if apply_mask==1: + FH.mult(mask) + # normalize columns of FH: L1 norm set to 1 in each column + FH.sum(axis = 0, target = t11) + t11.reciprocal() + FH.mult_by_row(t11) + w_meaninc.add_mult(w_mean.sign(),weightcost) + w_mean.add_mult(w_meaninc, -epsilonw_meanc/batch_size) + bias_mean.add_mult(bias_meaninc, -epsilonb_meanc/batch_size) + + if verbose == 1: + print "VF: " + '%3.2e' % VF.euclid_norm() + ", DVF: " + '%3.2e' % (VFinc.euclid_norm()*(epsilonVFc/batch_size)) + ", FH: " + '%3.2e' % FH.euclid_norm() + ", DFH: " + '%3.2e' % (FHinc.euclid_norm()*(epsilonFHc/batch_size)) + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm() + ", Dbias_cov: " + '%3.2e' % (bias_covinc.euclid_norm()*(epsilonbc/batch_size)) + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm() + ", Dbias_vis: " + '%3.2e' % (bias_visinc.euclid_norm()*(epsilonbc/batch_size)) + ", wm: " + '%3.2e' % w_mean.euclid_norm() + ", Dwm: " + '%3.2e' % (w_meaninc.euclid_norm()*(epsilonw_meanc/batch_size)) + ", bm: " + '%3.2e' % bias_mean.euclid_norm() + ", Dbm: " + '%3.2e' % (bias_meaninc.euclid_norm()*(epsilonb_meanc/batch_size)) + ", step: " + '%3.2e' % hmc_step + ", rej: " + '%3.2e' % hmc_ave_rej + sys.stdout.flush() + # back-up every once in a while + if np.mod(epoch,10) == 0: + VF.copy_to_host() + FH.copy_to_host() + bias_cov.copy_to_host() + w_mean.copy_to_host() + bias_mean.copy_to_host() + bias_vis.copy_to_host() + savemat("ws_temp", {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array,'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch}) + # final back-up + VF.copy_to_host() + FH.copy_to_host() + bias_cov.copy_to_host() + bias_vis.copy_to_host() + w_mean.copy_to_host() + bias_mean.copy_to_host() + savemat("ws_fac" + str(num_fac) + "_cov" + str(num_hid_cov) + "_mean" + str(num_hid_mean), {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array, 'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch})