changeset 1509:b709f6b53b17
auto fix white space.
| author | Frederic Bastien <nouiz@nouiz.org> |
| --- | --- |
| date | Mon, 12 Sep 2011 11:46:27 -0400 |
| parents | b28e8730c948 |
| children | 07b48bd449cd |
| files | pylearn/algorithms/tests/test_mcRBM.py |
| diffstat | 1 files changed, 15 insertions(+), 16 deletions(-) |
--- a/pylearn/algorithms/tests/test_mcRBM.py Mon Sep 12 11:45:56 2011 -0400
+++ b/pylearn/algorithms/tests/test_mcRBM.py Mon Sep 12 11:46:27 2011 -0400
@@ -100,7 +100,7 @@
     trainer = trainer_alloc(
             rbm_alloc(n_I=n_vis),
             train_batch,
-            batchsize, 
+            batchsize,
             initial_lr_per_example=lr_per_example,
             l1_penalty=l1_penalty,
             l1_penalty_start=l1_penalty_start,
@@ -109,7 +109,7 @@
 
     if persistent_chains:
         grads = trainer.contrastive_grads()
-        learn_fn = theano.function([batch_idx], 
+        learn_fn = theano.function([batch_idx],
                 outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
                 updates=trainer.cd_updates())
     else:
@@ -170,7 +170,7 @@
 
             print 'l2(U)', l2(rbm.U.value),
             print 'l2(W)', l2(rbm.W.value),
-            print 'l1_penalty', 
+            print 'l1_penalty',
             try:
                 print trainer.effective_l1_penalty.value
             except:
@@ -183,7 +183,7 @@
             print 'c min max', rbm.c.value.min(), rbm.c.value.max()
 
         if persistent_chains:
-            print 'parts min', smplr.positions.value.min(), 
+            print 'parts min', smplr.positions.value.min(),
             print 'max',smplr.positions.value.max(),
             print 'HMC step', smplr.stepsize.value,
             print 'arate', smplr.avg_acceptance_rate.value
@@ -231,7 +231,7 @@
     # Set up mcRBM Trainer
     # Initialize P using topological 3x3 overlapping patches thing
     # start learning P matrix after 2 passes through dataset
-    # 
+    #
     rbm_filename = 'mcRBM.rbm.%06i.pkl'%46000
     try:
         open(rbm_filename).close()
@@ -252,7 +252,7 @@
         trainer = mcRBMTrainer.alloc_for_P(
                 rbm=mcRBM_withP.alloc_topo_P(n_I=n_vis, n_J=81),
                 visible_batch=train_batch,
-                batchsize=batchsize, 
+                batchsize=batchsize,
                 initial_lr_per_example=0.05,
                 l1_penalty=1e-3,
                 l1_penalty_start=sys.maxint,
@@ -277,7 +277,7 @@
 
             print 'l2(U)', l2(rbm.U.value),
             print 'l2(W)', l2(rbm.W.value),
-            print 'l1_penalty', 
+            print 'l1_penalty',
             try:
                 print trainer.effective_l1_penalty.value
             except:
@@ -317,9 +317,9 @@
     feat_idx = tensor.lscalar()
     feat_idx_range = feat_idx * batchsize + tensor.arange(batchsize)
     train_batch_x, train_batch_y = pylearn.dataset_ops.cifar10.cifar10(
-            feat_idx_range, 
-            split='all', 
-            dtype='uint8', 
+            feat_idx_range,
+            split='all',
+            dtype='uint8',
             rasterized=False,
             color='rgb')
 
@@ -397,7 +397,7 @@
 
     #l1_regularization = float(sys.argv[1]) #1.e-3
     #l2_regularization = float(sys.argv[2]) #1.e-3*0
-    feature_logreg = LogisticRegression.new(x_i, 
+    feature_logreg = LogisticRegression.new(x_i,
             n_in = 11025, n_out=10,
             dtype=x_i.dtype)
 
@@ -407,7 +407,7 @@
     traincost = feature_logreg.nll(y_i).sum()
     traincost = traincost + abs(feature_logreg.w).sum() * l1_regularization
     #traincost = traincost + (feature_logreg.w**2).sum() * l2_regularization
-    train_logreg_fn = theano.function([x_i, y_i, lr], 
+    train_logreg_fn = theano.function([x_i, y_i, lr],
             [feature_logreg.nll(y_i).mean(),
                 feature_logreg.errors(y_i).mean()],
             updates=pylearn.gd.sgd.sgd_updates(
@@ -459,7 +459,7 @@
 
             y_i = valid_labels[i*batchsize:(i+1)*batchsize]
             #lr=0.0 -> no learning, safe for validation set
-            nll, l01 = train_logreg_fn(preproc(x_i), y_i, 0.0) 
+            nll, l01 = train_logreg_fn(preproc(x_i), y_i, 0.0)
             nlls.append(nll)
             l01s.append(l01)
         print 'validate log_reg', numpy.mean(nlls), numpy.mean(l01s)
@@ -473,7 +473,7 @@
 
             y_i = test_labels[i*batchsize:(i+1)*batchsize]
             #lr=0.0 -> no learning, safe for validation set
-            nll, l01 = train_logreg_fn(preproc(x_i), y_i, 0.0) 
+            nll, l01 = train_logreg_fn(preproc(x_i), y_i, 0.0)
             nlls.append(nll)
             l01s.append(l01)
         print 'test log_reg', numpy.mean(nlls), numpy.mean(l01s)
@@ -495,7 +495,7 @@
 import pickle as cPickle
 #import cPickle
 if __name__ == '__main__':
-    if 0: 
+    if 0:
         #learning 16 x 16 pinwheel filters from official cifar patches (MAR)
         rbm,smplr = test_reproduce_ranzato_hinton_2010(
                 as_unittest=False,
@@ -524,4 +524,3 @@
         def checkpoint():
             return checkpoint
         run_classif_experiment(checkpoint=checkpoint)
-
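The commit message describes an automated whitespace cleanup, and every hunk above only removes trailing spaces (plus one blank line at the end of the file). As a rough illustration of what such a pass does, here is a minimal, hypothetical Python sketch that rewrites files with trailing whitespace stripped. It is not the tool actually used for this changeset; the script and its helper function are assumptions for illustration only.

```python
# Hypothetical sketch of an automated trailing-whitespace cleanup, similar in
# effect to this changeset. Not the actual tool used by the pylearn authors.
import sys


def strip_trailing_whitespace(path):
    """Rewrite `path` in place, removing trailing spaces/tabs from each line."""
    with open(path) as src:
        lines = src.readlines()
    cleaned = [line.rstrip() + '\n' for line in lines]
    # Also drop empty lines at the end of the file, as the final hunk does.
    while cleaned and cleaned[-1] == '\n':
        cleaned.pop()
    with open(path, 'w') as dst:
        dst.writelines(cleaned)


if __name__ == '__main__':
    for filename in sys.argv[1:]:
        strip_trailing_whitespace(filename)
```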