changeset 1000:d4a14c6c36e0
mcRBM - post code-review #1 with Guillaume
author      James Bergstra <bergstrj@iro.umontreal.ca>
date        Tue, 24 Aug 2010 19:24:54 -0400
parents     c6d08a760960
children    660d784d14c7
files       pylearn/algorithms/mcRBM.py pylearn/algorithms/tests/__init__.py pylearn/algorithms/tests/test_mcRBM.py pylearn/sandbox/train_mcRBM.py
diffstat    3 files changed, 695 insertions(+), 163 deletions(-)
--- a/pylearn/algorithms/mcRBM.py Tue Aug 24 17:01:09 2010 -0400 +++ b/pylearn/algorithms/mcRBM.py Tue Aug 24 19:24:54 2010 -0400 @@ -190,8 +190,7 @@ # + 0.5 \sum_i v_i^2 # - \sum_i a_i v_i -import sys -import logging +import sys, os, logging import numpy as np import numpy @@ -201,21 +200,54 @@ floatX = theano.config.floatX import pylearn +#TODO: clean up the HMC_sampler code +#TODO: think of naming convention for acronyms + suffix? from pylearn.sampling.hmc import HMC_sampler from pylearn.io import image_tiling from pylearn.gd.sgd import sgd_updates +import pylearn.dataset_ops.image_patches -#TODO: This should be in the datasets folder -import pylearn.datasets.config -import pylearn.dataset_ops.image_patches -from pylearn.dataset_ops.protocol import TensorFnDataset -from pylearn.dataset_ops.memo import memo -import pylearn -import scipy.io -import os +########################################### +# +# Candidates for factoring +# +########################################### + +#TODO: Document, move to pylearn's math lib +def l1(X): + return abs(X).sum() + +#TODO: Document, move to pylearn's math lib +def l2(X): + return TT.sqrt((X**2).sum()) + +#TODO: Document, move to pylearn's math lib +def contrastive_cost(free_energy_fn, pos_v, neg_v): + return (free_energy_fn(pos_v) - free_energy_fn(neg_v)).sum() +#TODO: Typical use of contrastive_cost is to later use tensor.grad, but in that case we want to +# block gradient going through neg_v +def contrastive_grad(free_energy_fn, pos_v, neg_v, params, other_cost=0): + """ + :param pos_v: positive-phase sample of visible units + :param neg_v: negative-phase sample of visible units + """ + #block the grad through neg_v + cost=contrastive_cost(free_energy_fn, pos_v, neg_v) + if other_cost: + cost = cost + other_cost + return theano.tensor.grad(cost, + wrt=params, + consider_constant=[neg_v]) -#TODO: This should be in the nnet part of the library +########################################### +# +# Expressions that are mcRBM-specific +# +########################################### + +#TODO: make global function to initialize parameter + def hidden_cov_units_preactivation_given_v(rbm, v, small=0.5): """Return argument to the sigmoid that would give mean of covariance hid units @@ -248,24 +280,6 @@ """ return sum(free_energy_terms_given_v(rbm,v)) -def contrastive_gradient(rbm, pos_v, neg_v, U_l1_penalty=0, W_l1_penalty=0): - """Return a list of gradient expressions for the rbm parameters - - :param pos_v: positive-phase sample of visible units - :param neg_v: negative-phase sample of visible units - :param U_l1_penalty: a scalar-valued multiplier on the L1 penalty on U - :param W_l1_penalty: a scalar-valued multiplier on the L1 penalty on W - """ - U, W, a, b, c = rbm - pos_FE = free_energy_given_v(rbm, pos_v) - neg_FE = free_energy_given_v(rbm, neg_v) - c0 = (pos_FE - neg_FE).sum() - c1 = abs(U).sum()*U_l1_penalty - c2 = abs(W).sum()*W_l1_penalty - cost = c0 + c1 + c2 - rval = theano.tensor.grad(cost, list(rbm)) - return rval - def expected_h_g_given_v(rbm, v): """Returns tuple (`h`, `g`) of theano expression conditional expectations in an mcRBM. @@ -291,7 +305,6 @@ except AttributeError: return W.shape[0] - def sampler(rbm, n_particles, n_visible=None, rng=7823748): """Return an `HMC_sampler` that will draw samples from the distribution over visible units specified by this RBM. 
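The hunks above replace the old monolithic contrastive_gradient helper with the generic contrastive_cost / contrastive_grad pair; the important detail is that the gradient is blocked from flowing through the negative-phase sample by passing it via consider_constant. A minimal sketch of the intended usage, with a toy free energy standing in for free_energy_given_v (the variable names and the toy energy are illustrative, not part of the diff):

import numpy
import theano
import theano.tensor as TT

floatX = theano.config.floatX
W = theano.shared(0.1 * numpy.ones((4, 3), dtype=floatX), name='W')
pos_v = TT.matrix('pos_v')   # positive-phase (data) visibles
neg_v = TT.matrix('neg_v')   # negative-phase (sampler) visibles

def toy_free_energy(v):
    # stand-in for free_energy_given_v(rbm, v)
    return -TT.log(1 + TT.exp(TT.dot(v, W))).sum(axis=1)

# contrastive_cost as defined in the hunk above
cost = contrastive_cost(toy_free_energy, pos_v, neg_v)
# contrastive_grad does the same thing internally; consider_constant keeps the
# gradient from propagating through the negative-phase sample
gW, = theano.tensor.grad(cost, wrt=[W], consider_constant=[neg_v])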
@@ -313,6 +326,12 @@ seed=int(rng.randint(2**30))) return rval +############################# +# +# Convenient data container +# +############################# + class MeanCovRBM(object): """Container for mcRBM parameters @@ -380,140 +399,12 @@ d[key] = shared(d[key], name=key) self.__init__(**d) -if __name__ == '__main__': - - dataset='MAR' - if dataset == 'MAR': - n_vis=105 - n_patches=10240 - else: - R,C= 16,16 # the size of image patches - n_vis=R*C - n_patches=100000 - - n_train_iters=5000 - - n_burnin_steps=10000 - - l1_penalty=1e-3 - no_l1_epochs = 10 - effective_l1_penalty=0.0 - - epoch_size=n_patches - batchsize = 128 - lr = 0.075 / batchsize - s_lr = TT.scalar() - s_l1_penalty=TT.scalar() - n_K=256 - n_J=100 - rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J) - - smplr = sampler(rbm, n_particles=batchsize) - - def l2(X): - return numpy.sqrt((X**2).sum()) - if dataset == 'MAR': - tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010 - else: - def tile(X, fname): - _img = image_tiling.tile_raster_images(X, - img_shape=(R,C), - min_dynamic_range=1e-2) - image_tiling.save_tiled_raster_images(_img, fname) +#TODO: put the normalization of U as a global function - batch_idx = TT.iscalar() - - if dataset == 'MAR': - train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize)) - else: - train_batch = pylearn.dataset_ops.image_patches.image_patches( - s_idx = (batch_idx * batchsize + np.arange(batchsize)), - dims = (n_patches,R,C), - center=True, - unitvar=True, - dtype=floatX, - rasterized=True) - - imgs_fn = function([batch_idx], outputs=train_batch) - grads = contrastive_gradient(rbm, - pos_v=train_batch, - neg_v=smplr.positions[0], - U_l1_penalty=s_l1_penalty, - W_l1_penalty=s_l1_penalty) - sgd_ups = sgd_updates( - rbm.params, - grads, - stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ]) - learn_fn = function([batch_idx, s_lr, s_l1_penalty], - outputs=[ - grads[0].norm(2), - (sgd_ups[0][1] - sgd_ups[0][0]).norm(2), - (sgd_ups[1][1] - sgd_ups[1][0]).norm(2), - ], - updates = sgd_ups) - - print "Learning..." - normVF=1 - last_epoch = -1 - for jj in xrange(n_train_iters): - epoch = jj*batchsize / epoch_size - - print_jj = epoch != last_epoch - last_epoch = epoch - - if epoch > 10: - break - - if print_jj: - tile(imgs_fn(jj), "imgs_%06i.png"%jj) - tile(smplr.positions[0].value, "sample_%06i.png"%jj) - tile(rbm.U.value.T, "U_%06i.png"%jj) - tile(rbm.W.value.T, "W_%06i.png"%jj) - - print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize) - - print 'l2(U)', l2(rbm.U.value), - print 'l2(W)', l2(rbm.W.value) +#TODO: put the learning loop as a global function or class, so that someone could load and *TRAIN* an mcRBM!!! 
- print 'U min max', rbm.U.value.min(), rbm.U.value.max(), - print 'W min max', rbm.W.value.min(), rbm.W.value.max(), - print 'a min max', rbm.a.value.min(), rbm.a.value.max(), - print 'b min max', rbm.b.value.min(), rbm.b.value.max(), - print 'c min max', rbm.c.value.min(), rbm.c.value.max() - - print 'parts min', smplr.positions[0].value.min(), - print 'max',smplr.positions[0].value.max(), - print 'HMC step', smplr.stepsize, - print 'arate', smplr.avg_acceptance_rate - - smplr.simulate() - - l2_of_Ugrad = learn_fn(jj, - lr/max(1, jj/(20*epoch_size/batchsize)), - effective_l1_penalty) - - if print_jj: - print 'l2(U_grad)', float(l2_of_Ugrad[0]), - print 'l2(U_inc)', float(l2_of_Ugrad[1]), - print 'l2(W_inc)', float(l2_of_Ugrad[2]), - #print 'FE+', float(l2_of_Ugrad[2]), - #print 'FE+[0]', float(l2_of_Ugrad[3]), - #print 'FE+[1]', float(l2_of_Ugrad[4]), - #print 'FE+[2]', float(l2_of_Ugrad[5]), - #print 'FE+[3]', float(l2_of_Ugrad[6]) - - if jj == no_l1_epochs * epoch_size/batchsize: - print "Activating L1 weight decay" - effective_l1_penalty = 1e-3 - - # weird normalization technique... - # It constrains all the columns of the matrix to have the same length - # But the matrix itself is re-scaled to have an arbitrary abslute size. - U = rbm.U.value - U_norms = np.sqrt((U*U).sum(axis=0)) - assert len(U_norms) == n_K - normVF = .95 * normVF + .05 * np.mean(U_norms) - rbm.U.value = rbm.U.value * normVF/U_norms - +if __name__ == '__main__': + import pylearn.algorithms.tests.test_mcRBM + pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(as_unittest=True)
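The new __main__ simply delegates to the regression test added below. The two objects that test exercises, the MeanCovRBM parameter container and the persistent HMC sampler returned by sampler(), are used roughly as follows (a hedged usage sketch based only on the calls that appear in the test; dimensions follow its 'MAR' configuration):

rbm = MeanCovRBM.new_from_dims(n_I=105, n_K=256, n_J=100)   # U, W, a, b, c as shared variables
smplr = sampler(rbm, n_particles=128)                       # persistent negative-phase particles

smplr.simulate()                    # advance the persistent HMC chain
neg_v = smplr.positions[0].value    # current particles as a numpy array (one row per particle)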
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/algorithms/tests/test_mcRBM.py Tue Aug 24 19:24:54 2010 -0400 @@ -0,0 +1,169 @@ + + +from pylearn.algorithms.mcRBM import * + + +def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True): + dataset='MAR' + if dataset == 'MAR': + n_vis=105 + n_patches=10240 + else: + R,C= 16,16 # the size of image patches + n_vis=R*C + n_patches=100000 + + n_train_iters=5000 + + n_burnin_steps=10000 + + l1_penalty=1e-3 + no_l1_epochs = 10 + effective_l1_penalty=0.0 + + epoch_size=n_patches + batchsize = 128 + lr = 0.075 / batchsize + s_lr = TT.scalar() + s_l1_penalty=TT.scalar() + n_K=256 + n_J=100 + + rbm = MeanCovRBM.new_from_dims(n_I=n_vis, n_K=n_K, n_J=n_J) + + smplr = sampler(rbm, n_particles=batchsize) + + def l2(X): + return numpy.sqrt((X**2).sum()) + if dataset == 'MAR': + tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010 + else: + def tile(X, fname): + _img = image_tiling.tile_raster_images(X, + img_shape=(R,C), + min_dynamic_range=1e-2) + image_tiling.save_tiled_raster_images(_img, fname) + + batch_idx = TT.iscalar() + + if dataset == 'MAR': + train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize)) + else: + train_batch = pylearn.dataset_ops.image_patches.image_patches( + s_idx = (batch_idx * batchsize + np.arange(batchsize)), + dims = (n_patches,R,C), + center=True, + unitvar=True, + dtype=floatX, + rasterized=True) + + if not as_unittest: + imgs_fn = function([batch_idx], outputs=train_batch) + + grads = contrastive_grad( + free_energy_fn=lambda v: free_energy_given_v(rbm, v), + pos_v=train_batch, + neg_v=smplr.positions[0], + params=list(rbm), + other_cost=(l1(rbm.U)+l1(rbm.W)) * s_l1_penalty) + sgd_ups = sgd_updates( + rbm.params, + grads, + stepsizes=[2*s_lr, .2*s_lr, .02*s_lr, .1*s_lr, .02*s_lr ]) + learn_fn = function([batch_idx, s_lr, s_l1_penalty], + outputs=[ + grads[0].norm(2), + (sgd_ups[0][1] - sgd_ups[0][0]).norm(2), + (sgd_ups[1][1] - sgd_ups[1][0]).norm(2), + ], + updates = sgd_ups) + + print "Learning..." 
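# A note on the graph built above: `grads` comes back in the same order as
# list(rbm) (U, W, a, b, c), and the test inspects sgd_ups[i][1] - sgd_ups[i][0],
# i.e. the increment applied to parameter i.  sgd_updates is therefore assumed
# to return one (param, new_value) pair per parameter, each with its own
# stepsize; a hedged sketch of that shape (an assumption, not a copy of
# pylearn.gd.sgd):
def _sgd_updates_sketch(params, grads, stepsizes):
    # one update pair per parameter, usable as theano.function(..., updates=...)
    return [(p, p - step * g) for p, g, step in zip(params, grads, stepsizes)]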
+ normVF=1 + last_epoch = -1 + for jj in xrange(n_train_iters): + epoch = jj*batchsize / epoch_size + + print_jj = epoch != last_epoch + last_epoch = epoch + + if epoch > 10: + break + + if as_unittest and epoch == 5: + U = rbm.U.value + W = rbm.W.value + def allclose(a,b): + return numpy.allclose(a,b,rtol=1.01,atol=1e-3) + print "" + print "--------------" + print "assert allclose(l2(U), %f)"%l2(U) + print "assert allclose(l2(W), %f)"%l2(W) + print "assert allclose(U.min(), %f)"%U.min() + print "assert allclose(U.max(), %f)"%U.max() + print "assert allclose(W.min(),%f)"%W.min() + print "assert allclose(W.max(), %f)"%W.max() + print "--------------" + + assert allclose(l2(U), 21.351664) + assert allclose(l2(W), 6.275828) + assert allclose(U.min(), -1.176703) + assert allclose(U.max(), 0.859802) + assert allclose(W.min(),-0.223128) + assert allclose(W.max(), 0.227558 ) + + break + + if print_jj: + if not as_unittest: + tile(imgs_fn(jj), "imgs_%06i.png"%jj) + tile(smplr.positions[0].value, "sample_%06i.png"%jj) + tile(rbm.U.value.T, "U_%06i.png"%jj) + tile(rbm.W.value.T, "W_%06i.png"%jj) + + print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize) + + print 'l2(U)', l2(rbm.U.value), + print 'l2(W)', l2(rbm.W.value) + + print 'U min max', rbm.U.value.min(), rbm.U.value.max(), + print 'W min max', rbm.W.value.min(), rbm.W.value.max(), + print 'a min max', rbm.a.value.min(), rbm.a.value.max(), + print 'b min max', rbm.b.value.min(), rbm.b.value.max(), + print 'c min max', rbm.c.value.min(), rbm.c.value.max() + + print 'parts min', smplr.positions[0].value.min(), + print 'max',smplr.positions[0].value.max(), + print 'HMC step', smplr.stepsize, + print 'arate', smplr.avg_acceptance_rate + + # Continue HMC chain + smplr.simulate() + + # Do CD update + l2_of_Ugrad = learn_fn(jj, + lr/max(1, jj/(20*epoch_size/batchsize)), + effective_l1_penalty) + + if print_jj: + print 'l2(U_grad)', float(l2_of_Ugrad[0]), + print 'l2(U_inc)', float(l2_of_Ugrad[1]), + print 'l2(W_inc)', float(l2_of_Ugrad[2]), + #print 'FE+', float(l2_of_Ugrad[2]), + #print 'FE+[0]', float(l2_of_Ugrad[3]), + #print 'FE+[1]', float(l2_of_Ugrad[4]), + #print 'FE+[2]', float(l2_of_Ugrad[5]), + #print 'FE+[3]', float(l2_of_Ugrad[6]) + + if jj == no_l1_epochs * epoch_size/batchsize: + print "Activating L1 weight decay" + effective_l1_penalty = 1e-3 + + # weird normalization technique... + # It constrains all the columns of the matrix to have the same length + # But the matrix itself is re-scaled to have an arbitrary abslute size. + U = rbm.U.value + U_norms = np.sqrt((U*U).sum(axis=0)) + assert len(U_norms) == n_K + normVF = .95 * normVF + .05 * np.mean(U_norms) + rbm.U.value = rbm.U.value * normVF/U_norms
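The block at the end of the loop is the column-norm constraint flagged by the "#TODO: put the normalization of U as a global function" note in mcRBM.py: every column of U is rescaled to a common length, and that common length is itself a running average of the observed column norms. Pulled out as a standalone NumPy function it would look roughly like this (the 0.05 mixing rate is the hard-coded value from the loop):

import numpy as np

def renormalize_U_columns(U, normVF, rate=0.05):
    """Rescale each column of U to the running-average column norm.

    Returns the rescaled U and the updated running average normVF.
    """
    U_norms = np.sqrt((U * U).sum(axis=0))                  # per-column L2 norms
    normVF = (1.0 - rate) * normVF + rate * U_norms.mean()  # running average of the norms
    return U * (normVF / U_norms), normVF

In the loop, normVF starts at 1 and the rescaled matrix is written back to rbm.U.value after every learning step.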
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/pylearn/sandbox/train_mcRBM.py Tue Aug 24 19:24:54 2010 -0400 @@ -0,0 +1,472 @@ +""" +This is a copy of mcRBM training that James modified to print out more information, visualize +filters, etc. Once mcRBM is stable, it can be deleted. +""" +# mcRBM training +# Refer to Ranzato and Hinton CVPR 2010 "Modeling Pixel Means and Covariances Using Factorized Third-Order BMs" +# +# Marc'Aurelio Ranzato +# 28 July 2010 + +import sys +import numpy as np +import cudamat as cmt +from scipy.io import loadmat, savemat +#import gpu_lock # put here you locking system package, if any +from ConfigParser import * + +demodata = None + +from pylearn.io import image_tiling +def tile(X, fname): + X = np.dot(X, demodata['invpcatransf'].T) + R=16 + C=16 + X = (X[:,:256], X[:,256:512], X[:,512:], None) + #X = (X[:,0::3], X[:,1::3], X[:,2::3], None) + _img = image_tiling.tile_raster_images(X, + img_shape=(R,C), + min_dynamic_range=1e-2) + image_tiling.save_tiled_raster_images(_img, fname) + +def save_imshow(X, fname): + image_tiling.Image.fromarray( + (image_tiling.scale_to_unit_interval(X)*255).astype('uint8'), + 'L').save(fname) + +###################################################################### +# compute the value of the free energy at a given input +# F = - sum log(1+exp(- .5 FH (VF data/norm(data))^2 + bias_cov)) +... +# - sum log(1+exp(w_mean data + bias_mean)) + ... +# - bias_vis data + 0.5 data^2 +# NOTE: FH is constrained to be positive +# (in the paper the sign is negative but the sign in front of it is also flipped) +def compute_energy_mcRBM(data,normdata,vel,energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis): + # normalize input data vectors + data.mult(data, target = t6) # DxP (nr input dims x nr samples) + t6.sum(axis = 0, target = lengthsq) # 1xP + lengthsq.mult(0.5, target = energy) # energy of quadratic regularization term + lengthsq.mult(1./num_vis) # normalize by number of components (like std) + lengthsq.add(small) # small prevents division by 0 + cmt.sqrt(lengthsq, target = length) + length.reciprocal(target = normcoeff) # 1xP + data.mult_by_row(normcoeff, target = normdata) # normalized data + ## potential + # covariance contribution + cmt.dot(VF.T, normdata, target = feat) # HxP (nr factors x nr samples) + feat.mult(feat, target = featsq) # HxP + cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples) + t1.mult(-0.5) + t1.add_col_vec(bias_cov) # OxP + cmt.exp(t1) # OxP + t1.add(1, target = t2) # OxP + cmt.log(t2) + t2.mult(-1) + energy.add_sums(t2, axis=0) + # mean contribution + cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples) + feat_mean.add_col_vec(bias_mean) # HxP + cmt.exp(feat_mean) + feat_mean.add(1) + cmt.log(feat_mean) + feat_mean.mult(-1) + energy.add_sums(feat_mean, axis=0) + # visible bias term + data.mult_by_col(bias_vis, target = t6) + t6.mult(-1) # DxP + energy.add_sums(t6, axis=0) # 1xP + # kinetic + vel.mult(vel, target = t6) + energy.add_sums(t6, axis = 0, mult = .5) + +################################################################# +# compute the derivative if the free energy at a given input +def compute_gradient_mcRBM(data,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis): + # normalize input data + data.mult(data, target = t6) # DxP + t6.sum(axis = 0, target = lengthsq) # 1xP + 
lengthsq.mult(1./num_vis) # normalize by number of components (like std) + lengthsq.add(small) + cmt.sqrt(lengthsq, target = length) + length.reciprocal(target = normcoeff) # 1xP + data.mult_by_row(normcoeff, target = normdata) # normalized data + cmt.dot(VF.T, normdata, target = feat) # HxP + feat.mult(feat, target = featsq) # HxP + cmt.dot(FH.T,featsq, target = t1) # OxP + t1.mult(-.5) + t1.add_col_vec(bias_cov) # OxP + t1.apply_sigmoid(target = t2) # OxP + cmt.dot(FH,t2, target = t3) # HxP + t3.mult(feat) + cmt.dot(VF, t3, target = normgradient) # VxP + # final bprop through normalization + length.mult(lengthsq, target = normcoeff) + normcoeff.reciprocal() # 1xP + normgradient.mult(data, target = gradient) # VxP + gradient.sum(axis = 0, target = t4) # 1xP + t4.mult(-1./num_vis) + data.mult_by_row(t4, target = gradient) + normgradient.mult_by_row(lengthsq, target = t6) + gradient.add(t6) + gradient.mult_by_row(normcoeff) + # add quadratic term gradient + gradient.add(data) + # add visible bias term + gradient.add_col_mult(bias_vis, -1) + # add MEAN contribution to gradient + cmt.dot(w_mean.T, data, target = feat_mean) # HxP + feat_mean.add_col_vec(bias_mean) # HxP + feat_mean.apply_sigmoid() # HxP + gradient.subtract_dot(w_mean,feat_mean) # VxP + +############################################################3 +# Hybrid Monte Carlo sampler +def draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis): + vel.fill_with_randn() + negdata.assign(data) + compute_energy_mcRBM(negdata,normdata,vel,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis) + compute_gradient_mcRBM(negdata,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis) + # half step + vel.add_mult(gradient, -0.5*hmc_step) + negdata.add_mult(vel,hmc_step) + # full leap-frog steps + for ss in range(hmc_step_nr - 1): + ## re-evaluate the gradient + compute_gradient_mcRBM(negdata,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis) + # update variables + vel.add_mult(gradient, -hmc_step) + negdata.add_mult(vel,hmc_step) + # final half-step + compute_gradient_mcRBM(negdata,normdata,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t3,t4,t6,feat,featsq,feat_mean,gradient,normgradient,length,lengthsq,normcoeff,small,num_vis) + vel.add_mult(gradient, -0.5*hmc_step) + # compute new energy + compute_energy_mcRBM(negdata,normdata,vel,new_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,t1,t2,t6,feat,featsq,feat_mean,length,lengthsq,normcoeff,small,num_vis) + # rejecton + old_energy.subtract(new_energy, target = thresh) + cmt.exp(thresh) + t4.fill_with_rand() + t4.less_than(thresh) + # update negdata and rejection rate + t4.mult(-1) + t4.add(1) # now 1's detect rejections + t4.sum(axis = 1, target = t5) + t5.copy_to_host() + rej = t5.numpy_array[0,0]/batch_size + data.mult_by_row(t4, target = t6) + negdata.mult_by_row(t4, target = t7) + negdata.subtract(t7) + negdata.add(t6) + hmc_ave_rej = 0.9*hmc_ave_rej + 0.1*rej + if hmc_ave_rej < hmc_target_ave_rej: + hmc_step = min(hmc_step*1.01,0.25) + else: + hmc_step = max(hmc_step*0.99,.001) + return hmc_step, hmc_ave_rej + 
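draw_HMC_samples is a standard Hamiltonian Monte Carlo update: draw a Gaussian velocity, run hmc_step_nr leap-frog steps under the gradient of the free energy, accept or reject each particle with probability min(1, exp(old_energy - new_energy)), and adapt the step size toward the target rejection rate. Stripped of the cudamat temporaries, the structure is roughly the following NumPy sketch (energy_fn and grad_fn stand in for compute_energy_mcRBM / compute_gradient_mcRBM, and the step-size adaptation is left to the caller):

import numpy as np

def hmc_update_sketch(x, energy_fn, grad_fn, stepsize, n_steps, rng):
    """One HMC update per column of x (num_dims x num_particles)."""
    vel = rng.randn(*x.shape)
    x_new = x.copy()
    old_E = energy_fn(x_new) + 0.5 * (vel ** 2).sum(axis=0)   # potential + kinetic energy
    vel -= 0.5 * stepsize * grad_fn(x_new)                    # initial half step on velocity
    x_new += stepsize * vel
    for _ in range(n_steps - 1):                              # full leap-frog steps
        vel -= stepsize * grad_fn(x_new)
        x_new += stepsize * vel
    vel -= 0.5 * stepsize * grad_fn(x_new)                    # final half step
    new_E = energy_fn(x_new) + 0.5 * (vel ** 2).sum(axis=0)
    accept = rng.rand(x.shape[1]) < np.exp(old_E - new_E)     # per-particle Metropolis test
    x[:, accept] = x_new[:, accept]                           # rejected particles keep their old state
    return x, 1.0 - accept.mean()                             # particles, rejection rate

The real code then nudges hmc_step up by 1% when the running rejection rate is below hmc_target_ave_rej and down by 1% otherwise, clipped to the range [0.001, 0.25].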
+ +###################################################### +# mcRBM trainer: sweeps over the training set. +# For each batch of samples compute derivatives to update the parameters +# at the training samples and at the negative samples drawn calling HMC sampler. +def train_mcRBM(): + + config = ConfigParser() + config.read('input_configuration') + + verbose = config.getint('VERBOSITY','verbose') + + num_epochs = config.getint('MAIN_PARAMETER_SETTING','num_epochs') + batch_size = config.getint('MAIN_PARAMETER_SETTING','batch_size') + startFH = config.getint('MAIN_PARAMETER_SETTING','startFH') + startwd = config.getint('MAIN_PARAMETER_SETTING','startwd') + doPCD = config.getint('MAIN_PARAMETER_SETTING','doPCD') + + # model parameters + num_fac = config.getint('MODEL_PARAMETER_SETTING','num_fac') + num_hid_cov = config.getint('MODEL_PARAMETER_SETTING','num_hid_cov') + num_hid_mean = config.getint('MODEL_PARAMETER_SETTING','num_hid_mean') + apply_mask = config.getint('MODEL_PARAMETER_SETTING','apply_mask') + + # load data + data_file_name = config.get('DATA','data_file_name') + d = loadmat(data_file_name) # input in the format PxD (P vectorized samples with D dimensions) + global demodata + demodata = d + totnumcases = d["whitendata"].shape[0] + d = d["whitendata"][0:np.floor(totnumcases/batch_size)*batch_size,:].copy() + totnumcases = d.shape[0] + num_vis = d.shape[1] + num_batches = int(totnumcases/batch_size) + dev_dat = cmt.CUDAMatrix(d.T) # VxP + + tile(d[:100], "100_whitened_data.png") + + # training parameters + epsilon = config.getfloat('OPTIMIZER_PARAMETERS','epsilon') + epsilonVF = 2*epsilon + epsilonFH = 0.02*epsilon + epsilonb = 0.02*epsilon + epsilonw_mean = 0.2*epsilon + epsilonb_mean = 0.1*epsilon + weightcost_final = config.getfloat('OPTIMIZER_PARAMETERS','weightcost_final') + + # HMC setting + hmc_step_nr = config.getint('HMC_PARAMETERS','hmc_step_nr') + hmc_step = 0.01 + hmc_target_ave_rej = config.getfloat('HMC_PARAMETERS','hmc_target_ave_rej') + hmc_ave_rej = hmc_target_ave_rej + + # initialize weights + VF = cmt.CUDAMatrix(np.array(0.02 * np.random.randn(num_vis, num_fac), dtype=np.float32, order='F')) # VxH + if apply_mask == 0: + FH = cmt.CUDAMatrix( np.array( np.eye(num_fac,num_hid_cov), dtype=np.float32, order='F') ) # HxO + else: + dd = loadmat('your_FHinit_mask_file.mat') # see CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat for an example + FH = cmt.CUDAMatrix( np.array( dd["FH"], dtype=np.float32, order='F') ) + bias_cov = cmt.CUDAMatrix( np.array(2.0*np.ones((num_hid_cov, 1)), dtype=np.float32, order='F') ) + bias_vis = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F') ) + w_mean = cmt.CUDAMatrix( np.array( 0.05 * np.random.randn(num_vis, num_hid_mean), dtype=np.float32, order='F') ) # VxH + bias_mean = cmt.CUDAMatrix( np.array( -2.0*np.ones((num_hid_mean,1)), dtype=np.float32, order='F') ) + + # initialize variables to store derivatives + VFinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_fac)), dtype=np.float32, order='F')) + FHinc = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F')) + bias_covinc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_cov, 1)), dtype=np.float32, order='F')) + bias_visinc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, 1)), dtype=np.float32, order='F')) + w_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_vis, num_hid_mean)), dtype=np.float32, order='F')) + bias_meaninc = cmt.CUDAMatrix( np.array(np.zeros((num_hid_mean, 1)), dtype=np.float32, order='F')) + + # 
initialize temporary storage + data = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + normdata = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + negdataini = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + feat = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F')) + featsq = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F')) + negdata = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F')) + old_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) + new_energy = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) + gradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + normgradient = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) # VxP + thresh = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) + feat_mean = cmt.CUDAMatrix( np.array(np.empty((num_hid_mean, batch_size)), dtype=np.float32, order='F')) + vel = cmt.CUDAMatrix( np.array(np.random.randn(num_vis, batch_size), dtype=np.float32, order='F')) + length = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP + lengthsq = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP + normcoeff = cmt.CUDAMatrix( np.array(np.zeros((1, batch_size)), dtype=np.float32, order='F')) # 1xP + if apply_mask==1: # this used to constrain very large FH matrices only allowing to change values in a neighborhood + dd = loadmat('your_FHinit_mask_file.mat') + mask = cmt.CUDAMatrix( np.array(dd["mask"], dtype=np.float32, order='F')) + normVF = 1 + small = 0.5 + + # other temporary vars + t1 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F')) + t2 = cmt.CUDAMatrix( np.array(np.empty((num_hid_cov, batch_size)), dtype=np.float32, order='F')) + t3 = cmt.CUDAMatrix( np.array(np.empty((num_fac, batch_size)), dtype=np.float32, order='F')) + t4 = cmt.CUDAMatrix( np.array(np.empty((1,batch_size)), dtype=np.float32, order='F')) + t5 = cmt.CUDAMatrix( np.array(np.empty((1,1)), dtype=np.float32, order='F')) + t6 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) + t7 = cmt.CUDAMatrix( np.array(np.empty((num_vis, batch_size)), dtype=np.float32, order='F')) + t8 = cmt.CUDAMatrix( np.array(np.empty((num_vis, num_fac)), dtype=np.float32, order='F')) + t9 = cmt.CUDAMatrix( np.array(np.zeros((num_fac, num_hid_cov)), dtype=np.float32, order='F')) + t10 = cmt.CUDAMatrix( np.array(np.empty((1,num_fac)), dtype=np.float32, order='F')) + t11 = cmt.CUDAMatrix( np.array(np.empty((1,num_hid_cov)), dtype=np.float32, order='F')) + # start training + for epoch in range(num_epochs): + + def print_stuff(): + print "VF: " + '%3.2e' % VF.euclid_norm() \ + + ", DVF: " + '%3.2e' % (VFinc.euclid_norm()*(epsilonVFc/batch_size))\ + + ", VF_inc: " + '%3.2e' % (VFinc.euclid_norm())\ + + ", FH: " + '%3.2e' % FH.euclid_norm() \ + + ", DFH: " + '%3.2e' % (FHinc.euclid_norm()*(epsilonFHc/batch_size)) \ + + ", bias_cov: " + '%3.2e' % bias_cov.euclid_norm() \ + + ", Dbias_cov: " + '%3.2e' % (bias_covinc.euclid_norm()*(epsilonbc/batch_size)) \ + + ", bias_vis: " + '%3.2e' % bias_vis.euclid_norm() \ + + ", Dbias_vis: " + 
'%3.2e' % (bias_visinc.euclid_norm()*(epsilonbc/batch_size)) \ + + ", wm: " + '%3.2e' % w_mean.euclid_norm() \ + + ", Dwm: " + '%3.2e' % (w_meaninc.euclid_norm()*(epsilonw_meanc/batch_size)) \ + + ", bm: " + '%3.2e' % bias_mean.euclid_norm() \ + + ", Dbm: " + '%3.2e' % (bias_meaninc.euclid_norm()*(epsilonb_meanc/batch_size)) \ + + ", step: " + '%3.2e' % hmc_step \ + + ", rej: " + '%3.2e' % hmc_ave_rej + sys.stdout.flush() + + def save_stuff(): + VF.copy_to_host() + FH.copy_to_host() + bias_cov.copy_to_host() + w_mean.copy_to_host() + bias_mean.copy_to_host() + bias_vis.copy_to_host() + savemat("ws_temp", { + 'VF':VF.numpy_array, + 'FH':FH.numpy_array, + 'bias_cov': bias_cov.numpy_array, + 'bias_vis': bias_vis.numpy_array, + 'w_mean': w_mean.numpy_array, + 'bias_mean': bias_mean.numpy_array, + 'epoch':epoch}) + + tile(VF.numpy_array.T, 'VF_%000i.png'%epoch) + tile(w_mean.numpy_array.T, 'w_mean_%000i.png'%epoch) + save_imshow(FH.numpy_array, 'FH_%000i.png'%epoch) + + # anneal learning rates + epsilonVFc = epsilonVF/max(1,epoch/20) + epsilonFHc = epsilonFH/max(1,epoch/20) + epsilonbc = epsilonb/max(1,epoch/20) + epsilonw_meanc = epsilonw_mean/max(1,epoch/20) + epsilonb_meanc = epsilonb_mean/max(1,epoch/20) + weightcost = weightcost_final + + if epoch <= startFH: + epsilonFHc = 0 + if epoch <= startwd: + weightcost = 0 + + print "Epoch " + str(epoch + 1), 'num_batches', num_batches + if epoch == 0: + print_stuff() + + for batch in range(num_batches): + + # get current minibatch + data = dev_dat.slice(batch*batch_size,(batch + 1)*batch_size) # DxP (nr dims x nr samples) + + # normalize input data + data.mult(data, target = t6) # DxP + t6.sum(axis = 0, target = lengthsq) # 1xP + lengthsq.mult(1./num_vis) # normalize by number of components (like std) + lengthsq.add(small) # small avoids division by 0 + cmt.sqrt(lengthsq, target = length) + length.reciprocal(target = normcoeff) # 1xP + data.mult_by_row(normcoeff, target = normdata) # normalized data + ## compute positive sample derivatives + # covariance part + cmt.dot(VF.T, normdata, target = feat) # HxP (nr facs x nr samples) + feat.mult(feat, target = featsq) # HxP + cmt.dot(FH.T,featsq, target = t1) # OxP (nr cov hiddens x nr samples) + t1.mult(-0.5) + t1.add_col_vec(bias_cov) # OxP + t1.apply_sigmoid(target = t2) # OxP + cmt.dot(featsq, t2.T, target = FHinc) # HxO + cmt.dot(FH,t2, target = t3) # HxP + t3.mult(feat) + cmt.dot(normdata, t3.T, target = VFinc) # VxH + t2.sum(axis = 1, target = bias_covinc) + bias_covinc.mult(-1) + # visible bias + data.sum(axis = 1, target = bias_visinc) + bias_visinc.mult(-1) + # mean part + cmt.dot(w_mean.T, data, target = feat_mean) # HxP (nr mean hiddens x nr samples) + feat_mean.add_col_vec(bias_mean) # HxP + feat_mean.apply_sigmoid() # HxP + feat_mean.mult(-1) + cmt.dot(data, feat_mean.T, target = w_meaninc) + feat_mean.sum(axis = 1, target = bias_meaninc) + + # HMC sampling: draw an approximate sample from the model + if doPCD == 0: # CD-1 (set negative data to current training samples) + hmc_step, hmc_ave_rej = draw_HMC_samples(data,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis) + else: # PCD-1 (use previous negative data as starting point for chain) + negdataini.assign(negdata) + hmc_step, hmc_ave_rej = 
draw_HMC_samples(negdataini,negdata,normdata,vel,gradient,normgradient,new_energy,old_energy,VF,FH,bias_cov,bias_vis,w_mean,bias_mean,hmc_step,hmc_step_nr,hmc_ave_rej,hmc_target_ave_rej,t1,t2,t3,t4,t5,t6,t7,thresh,feat,featsq,batch_size,feat_mean,length,lengthsq,normcoeff,small,num_vis) + + # compute derivatives at the negative samples + # normalize input data + negdata.mult(negdata, target = t6) # DxP + t6.sum(axis = 0, target = lengthsq) # 1xP + lengthsq.mult(1./num_vis) # normalize by number of components (like std) + lengthsq.add(small) + cmt.sqrt(lengthsq, target = length) + length.reciprocal(target = normcoeff) # 1xP + negdata.mult_by_row(normcoeff, target = normdata) # normalized data + # covariance part + cmt.dot(VF.T, normdata, target = feat) # HxP + feat.mult(feat, target = featsq) # HxP + cmt.dot(FH.T,featsq, target = t1) # OxP + t1.mult(-0.5) + t1.add_col_vec(bias_cov) # OxP + t1.apply_sigmoid(target = t2) # OxP + FHinc.subtract_dot(featsq, t2.T) # HxO + FHinc.mult(0.5) + cmt.dot(FH,t2, target = t3) # HxP + t3.mult(feat) + VFinc.subtract_dot(normdata, t3.T) # VxH + bias_covinc.add_sums(t2, axis = 1) + # visible bias + bias_visinc.add_sums(negdata, axis = 1) + # mean part + cmt.dot(w_mean.T, negdata, target = feat_mean) # HxP + feat_mean.add_col_vec(bias_mean) # HxP + feat_mean.apply_sigmoid() # HxP + w_meaninc.add_dot(negdata, feat_mean.T) + bias_meaninc.add_sums(feat_mean, axis = 1) + + # update parameters + VFinc.add_mult(VF.sign(), weightcost) # L1 regularization + VF.add_mult(VFinc, -epsilonVFc/batch_size) + # normalize columns of VF: normalize by running average of their norm + VF.mult(VF, target = t8) + t8.sum(axis = 0, target = t10) + cmt.sqrt(t10) + t10.sum(axis=1,target = t5) + t5.copy_to_host() + normVF = .95*normVF + (.05/num_fac) * t5.numpy_array[0,0] # estimate norm + t10.reciprocal() + VF.mult_by_row(t10) + VF.mult(normVF) + bias_cov.add_mult(bias_covinc, -epsilonbc/batch_size) + bias_vis.add_mult(bias_visinc, -epsilonbc/batch_size) + + if epoch > startFH: + FHinc.add_mult(FH.sign(), weightcost) # L1 regularization + FH.add_mult(FHinc, -epsilonFHc/batch_size) # update + # set to 0 negative entries in FH + FH.greater_than(0, target = t9) + FH.mult(t9) + if apply_mask==1: + FH.mult(mask) + # normalize columns of FH: L1 norm set to 1 in each column + FH.sum(axis = 0, target = t11) + t11.reciprocal() + FH.mult_by_row(t11) + w_meaninc.add_mult(w_mean.sign(),weightcost) + w_mean.add_mult(w_meaninc, -epsilonw_meanc/batch_size) + bias_mean.add_mult(bias_meaninc, -epsilonb_meanc/batch_size) + + if verbose == 1: + print_stuff() + # back-up every once in a while + if np.mod(epoch,1) == 0: + save_stuff() + # final back-up + VF.copy_to_host() + FH.copy_to_host() + bias_cov.copy_to_host() + bias_vis.copy_to_host() + w_mean.copy_to_host() + bias_mean.copy_to_host() + savemat("ws_fac" + str(num_fac) + "_cov" + str(num_hid_cov) + "_mean" + str(num_hid_mean), {'VF':VF.numpy_array,'FH':FH.numpy_array,'bias_cov': bias_cov.numpy_array, 'bias_vis': bias_vis.numpy_array, 'w_mean': w_mean.numpy_array, 'bias_mean': bias_mean.numpy_array, 'epoch':epoch}) + + + +###################################33 +# main +if __name__ == "__main__": + # initialize CUDA + #cmt.cuda_set_device(gpu_lock.obtain_lock_id()) # uncomment if you have a locking system or desire to choose the GPU board number + cmt.cublas_init() + cmt.CUDAMatrix.init_random(1) + train_mcRBM() + cmt.cublas_shutdown() + + + + + + + +
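train_mcRBM() reads all of its hyperparameters from an 'input_configuration' file via ConfigParser. That file is not part of this changeset; a minimal sketch of one that defines every option the script reads is below. The section and option names come from the config.get* calls above; the values are placeholders, not the settings used in the paper, and 'patches.mat' is a hypothetical data path.

from ConfigParser import ConfigParser

cfg = ConfigParser()
for section, options in [
        ('VERBOSITY',               [('verbose', '1')]),
        ('MAIN_PARAMETER_SETTING',  [('num_epochs', '50'), ('batch_size', '128'),
                                     ('startFH', '2'), ('startwd', '2'), ('doPCD', '1')]),
        ('MODEL_PARAMETER_SETTING', [('num_fac', '256'), ('num_hid_cov', '256'),
                                     ('num_hid_mean', '100'), ('apply_mask', '0')]),
        ('DATA',                    [('data_file_name', 'patches.mat')]),
        ('OPTIMIZER_PARAMETERS',    [('epsilon', '0.075'), ('weightcost_final', '0.001')]),
        ('HMC_PARAMETERS',          [('hmc_step_nr', '20'), ('hmc_target_ave_rej', '0.1')]),
    ]:
    cfg.add_section(section)
    for name, value in options:
        cfg.set(section, name, value)   # placeholder values; replace with your own settings

with open('input_configuration', 'w') as f:
    cfg.write(f)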