# HG changeset patch
# User James Bergstra
# Date 1286736662 14400
# Node ID d6726417cf5753183665e1b873a64c0cdff26e48
# Parent  cdda4f98c2a288f3567a1e29318798fdba25f9e9
adding training script for test_mcRBM to reproduce classification results

diff -r cdda4f98c2a2 -r d6726417cf57 pylearn/algorithms/tests/test_mcRBM.py
--- a/pylearn/algorithms/tests/test_mcRBM.py	Sun Oct 10 13:45:21 2010 -0400
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Sun Oct 10 14:51:02 2010 -0400
@@ -1,8 +1,12 @@
+import sys
 from pylearn.algorithms.mcRBM import *
 import pylearn.datasets.cifar10
 import pylearn.dataset_ops.tinyimages
 import pylearn.dataset_ops.cifar10
+from theano import tensor
+from pylearn.shared.layers.logreg import LogisticRegression
+
 
 def _default_rbm_alloc(n_I, n_K=256, n_J=100):
     return mcRBM.alloc(n_I, n_K, n_J)
@@ -205,6 +209,251 @@
     if not as_unittest:
         return rbm, smplr
 
+
+def run_classif_experiment(checkpoint):
+
+    R,C=8,8
+    n_vis=74
+    # PRETRAIN
+    #
+    # extract 1 million 8x8 patches from TinyImages
+    # pre-process them the right way
+    # find 74 dims of PCA
+    # filter patches through PCA
+    whitened_patches, pca_dct = pylearn.dataset_ops.tinyimages.main(n_imgs=100000,
+            max_components=n_vis, seed=234)
+    #
+    # Set up mcRBM Trainer
+    # Initialize P using topological 3x3 overlapping patches thing
+    # start learning P matrix after 2 passes through dataset
+    #
+    rbm_filename = 'mcRBM.rbm.%06i.pkl'%46000
+    try:
+        open(rbm_filename).close()
+        load_mcrbm = True
+    except:
+        load_mcrbm = False
+
+    if load_mcrbm:
+        print 'loading mcRBM from file', rbm_filename
+        rbm = cPickle.load(open(rbm_filename))
+
+    else:
+        print "Training mcRBM"
+        batchsize=128
+        epoch_size=len(whitened_patches)
+        tile = pylearn.dataset_ops.tinyimages.save_filters
+        train_batch = theano.tensor.matrix()
+        trainer = mcRBMTrainer.alloc_for_P(
+                rbm=mcRBM_withP.alloc_topo_P(n_I=n_vis, n_J=81),
+                visible_batch=train_batch,
+                batchsize=batchsize,
+                initial_lr_per_example=0.05,
+                l1_penalty=1e-3,
+                l1_penalty_start=sys.maxint,
+                p_training_start=2*epoch_size//batchsize,
+                persistent_chains=False)
+        rbm=trainer.rbm
+        learn_fn = function([train_batch], outputs=[], updates=trainer.cd_updates())
+        smplr = trainer._last_cd1_sampler
+
+        ii = 0
+        for i_epoch in range(6):
+            for i_batch in xrange(epoch_size // batchsize):
+                batch_vals = whitened_patches[i_batch*batchsize:(i_batch+1)*batchsize]
+                learn_fn(batch_vals)
+
+                if (ii % 1000) == 0:
+                    #tile(imgs_fn(ii), "imgs_%06i.png"%ii)
+                    tile(rbm.U.value.T, "U_%06i.png"%ii)
+                    tile(rbm.W.value.T, "W_%06i.png"%ii)
+
+                    print 'saving samples', ii, 'epoch', i_epoch, i_batch
+
+                    print 'l2(U)', l2(rbm.U.value),
+                    print 'l2(W)', l2(rbm.W.value),
+                    print 'l1_penalty',
+                    try:
+                        print trainer.effective_l1_penalty.value
+                    except:
+                        print trainer.effective_l1_penalty
+
+                    print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
+                    print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
+                    print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
+                    print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
+                    print 'c min max', rbm.c.value.min(), rbm.c.value.max()
+
+                    print 'HMC step', smplr.stepsize.value,
+                    print 'arate', smplr.avg_acceptance_rate.value
+                    print 'P min max', rbm.P.value.min(), rbm.P.value.max(),
+                    print 'P_lr', trainer.p_lr.value
+                    print ''
+                    print 'Saving rbm...'
+                    cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%ii, 'w'), -1)
+
+                ii += 1
+
+
+    # extract convolutional features from the CIFAR10 data
+    feat_filename = 'cifar10.features.46000.npy'
+    feat_filename = 'mcrbm_features.npy'
+    try:
+        open(feat_filename).close()
+        load_features = True
+    except:
+        load_features = False
+
+    if load_features:
+        print 'Loading features from', feat_filename
+        all_features = numpy.load(feat_filename, mmap_mode='r')
+    else:
+        batchsize=100
+        feat_idx = tensor.lscalar()
+        feat_idx_range = feat_idx * batchsize + tensor.arange(batchsize)
+        train_batch_x, train_batch_y = pylearn.dataset_ops.cifar10.cifar10(
+                feat_idx_range,
+                split='all',
+                dtype='uint8',
+                rasterized=False,
+                color='rgb')
+
+        WINDOW_SIZE=8
+        WINDOW_STRIDE=4
+
+        # put these into shared vars because support for big matrix constants is bad
+        # (comparing them is slow)
+        pca_eigvecs = shared(pca_dct['eig_vecs'].astype('float32'))
+        pca_eigvals = shared(pca_dct['eig_vals'].astype('float32'))
+        pca_mean = shared(pca_dct['mean'].astype('float32'))
+
+        def theano_pca_whiten(X):
+            #copying preprocessing.pca.pca_whiten
+            return tensor.true_div(
+                    tensor.dot(X-pca_mean, pca_eigvecs),
+                    tensor.sqrt(pca_eigvals)+1e-8)
+
+        h_list = []
+        g_list = []
+        for r_offset in range(0, 32-WINDOW_SIZE+1, WINDOW_STRIDE):
+            for c_offset in range(0, 32-WINDOW_SIZE+1, WINDOW_STRIDE):
+                window = train_batch_x[:, r_offset:r_offset+WINDOW_SIZE,
+                        c_offset:c_offset+WINDOW_SIZE, :]
+                assert window.dtype=='uint8'
+
+                #rasterize the patches
+                raster_window = tensor.flatten(tensor.cast(window, 'float32'),2)
+
+                #subtract off the mean of each image
+                raster_window = raster_window - raster_window.mean(axis=1).reshape((batchsize,1))
+
+                h,g = rbm.expected_h_g_given_v(theano_pca_whiten(raster_window))
+
+                h_list.append(h)
+                g_list.append(g)
+
+        hg = tensor.concatenate(h_list + g_list, axis=1)
+
+        feat_fn = function([feat_idx], hg)
+        features = numpy.empty((60000, 11025), dtype='float32')
+        for i in xrange(60000//batchsize):
+            if i % 100 == 0:
+                print("feature batch %i"%i)
+            features[i*batchsize:(i+1)*batchsize] = feat_fn(i)
+
+        print("saving features to %s"%feat_filename)
+        numpy.save(feat_filename, features)
+        all_features = features
+        del features
+
+
+    # CLASSIFY FEATURES
+
+    if 0:
+        # nothing to load
+        pass
+    else:
+        batchsize=10
+        x_i = tensor.matrix()
+        y_i = tensor.ivector()
+        lr = tensor.scalar()
+        l1_regularization = float(sys.argv[1]) #1.e-3
+        l2_regularization = float(sys.argv[2]) #1.e-3*0
+
+        feature_logreg = LogisticRegression.new(x_i,
+                n_in = 11025, n_out=10,
+                dtype=x_i.dtype)
+
+        traincost = feature_logreg.nll(y_i).sum()
+        traincost = traincost + abs(feature_logreg.w).sum() * l1_regularization
+        traincost = traincost + (feature_logreg.w**2).sum() * l2_regularization
+        train_logreg_fn = function([x_i, y_i, lr],
+                [feature_logreg.nll(y_i).mean(),
+                    feature_logreg.errors(y_i).mean()],
+                updates=pylearn.gd.sgd.sgd_updates(
+                    params=feature_logreg.params,
+                    grads=tensor.grad(traincost, feature_logreg.params),
+                    stepsizes=[lr,lr]))
+
+        all_labels = pylearn.dataset_ops.cifar10.all_data_labels('uint8')[1]
+        pylearn.dataset_ops.cifar10.all_data_labels.forget() # clear memo cache
+        assert len(all_labels)==60000
+        train_labels = all_labels[:40000]
+        valid_labels = all_labels[40000:50000]
+        test_labels = all_labels[50000:60000]
+        train_features = all_features[:40000]
+        valid_features = all_features[40000:50000]
+        test_features = all_features[50000:60000]
+
+        print "Computing mean and std.dev"
+        train_mean = train_features.mean(axis=0)
+        train_std = train_features.std(axis=0)+1e-4
+
+
+        for epoch in xrange(20):
+            print 'epoch', epoch
+            # validate
+
+            l01s = []
+            nlls = []
+            for i in xrange(10000/batchsize):
+                x_i = valid_features[i*batchsize:(i+1)*batchsize]
+                y_i = valid_labels[i*batchsize:(i+1)*batchsize]
+
+                #lr=0.0 -> no learning, safe for validation set
+                nll, l01 = train_logreg_fn((x_i-train_mean)/train_std, y_i, 0.0)
+                nlls.append(nll)
+                l01s.append(l01)
+            print 'validate log_reg', numpy.mean(nlls), numpy.mean(l01s)
+
+            # test
+
+            l01s = []
+            nlls = []
+            for i in xrange(10000/batchsize):
+                x_i = test_features[i*batchsize:(i+1)*batchsize]
+                y_i = test_labels[i*batchsize:(i+1)*batchsize]
+
+                #lr=0.0 -> no learning, safe for test set
+                nll, l01 = train_logreg_fn((x_i-train_mean)/train_std, y_i, 0.0)
+                nlls.append(nll)
+                l01s.append(l01)
+            print 'test log_reg', numpy.mean(nlls), numpy.mean(l01s)
+
+            #train
+            l01s = []
+            nlls = []
+            for i in xrange(40000/batchsize):
+                x_i = train_features[i*batchsize:(i+1)*batchsize]
+                y_i = train_labels[i*batchsize:(i+1)*batchsize]
+                nll, l01 = train_logreg_fn((x_i-train_mean)/train_std, y_i, 0.00003)
+                nlls.append(nll)
+                l01s.append(l01)
+            print 'train log_reg', numpy.mean(nlls), numpy.mean(l01s)
+
+
+
 import pickle as cPickle #import cPickle
 
 if __name__ == '__main__':
@@ -218,7 +467,7 @@
             dataset='MAR'
             )
 
-    if 1:
+    if 0:
         # pretraining settings
         rbm,smplr = test_reproduce_ranzato_hinton_2010(
                 as_unittest=False,
@@ -232,3 +481,8 @@
                 #l1_penalty_start=350, #DEBUG
                 persistent_chains=False
                 )
+
+    if 1:
+        def checkpoint():
+            return checkpoint
+        run_classif_experiment(checkpoint=checkpoint)
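
Usage note: the classification stage reads its regularization strengths from the command line (sys.argv[1] is the L1 penalty, sys.argv[2] the L2 penalty), so an invocation along the lines of "python pylearn/algorithms/tests/test_mcRBM.py 1e-3 0.0" matches the values suggested by the inline comments. The pretrained RBM and the extracted CIFAR-10 features are cached to mcRBM.rbm.*.pkl and mcrbm_features.npy in the working directory, so re-runs skip those stages. The sketch below is not part of the changeset; it restates the window bookkeeping of run_classif_experiment without Theano, only to make explicit where the 11025-column feature buffer comes from (the per-window feature count is inferred from the buffer size rather than stated in the patch):

    # Illustrative sketch; assumption: the windows fill the 11025-column buffer exactly.
    WINDOW_SIZE = 8      # patch edge length, as in the patch
    WINDOW_STRIDE = 4    # stride between windows, as in the patch
    IMG_SIZE = 32        # CIFAR-10 images are 32x32
    N_FEATURES = 11025   # width of the feature matrix allocated in the patch

    offsets = range(0, IMG_SIZE - WINDOW_SIZE + 1, WINDOW_STRIDE)  # [0, 4, 8, 12, 16, 20, 24]
    n_windows = len(offsets) ** 2                                  # 7 * 7 = 49 windows per image
    feats_per_window = N_FEATURES // n_windows                     # 225 = dim(h) + dim(g) per window
    assert n_windows * feats_per_window == N_FEATURES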