Mercurial > pylearn
diff pylearn/algorithms/tests/test_mcRBM.py @ 1284:1817485d586d
mcRBM - many changes incl. adding support for pooling matrix
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 15 Sep 2010 17:49:21 -0400 |
parents | 7bb5dd98e671 |
children | 8905186b176c |
line wrap: on
line diff
--- a/pylearn/algorithms/tests/test_mcRBM.py Wed Sep 15 17:46:21 2010 -0400 +++ b/pylearn/algorithms/tests/test_mcRBM.py Wed Sep 15 17:49:21 2010 -0400 @@ -1,33 +1,166 @@ from pylearn.algorithms.mcRBM import * +import pylearn.datasets.cifar10 + +import pylearn.dataset_ops.cifar10 + +def _mar_train_patches(dtype): + R,C=16,16 + train_data = pylearn.dataset_ops.cifar10.train_data_labels(dtype)[0][:40000] + #train_data shape is (40000, 3072) + train_data = train_data.reshape((40000,3,32,32)).transpose([0,2,3,1]) + patches = train_data[:, :R, :C, :].reshape((40000, 3*R*C)) + patches -= patches.mean(axis=0) + wpatches = numpy.dot(patches, d['pcatransf'].T) + return wpatches + +def mar_centered(s_idx, split, dtype='float64', rasterized=False, color='grey'): + """ + Returns a pair (img, label) of theano expressions for cifar-10 samples + + :param s_idx: the indexes + + :param split: + + :param dtype: + + :param rasterized: return examples as vectors (True) or 28x28 matrices (False) + + :param color: control how to deal with the color in the images' + - grey greyscale (with luminance weighting) + - rgb add a trailing dimension of length 3 with rgb colour channels + + """ + + split_options = {'train':(train_data, train_labels), + 'valid': (valid_data, valid_labels), + 'test': (test_data, test_labels)} + + if split not in split_options: + raise ValueError('invalid split option', (split, split_options.keys())) + + color_options = ('grey', 'rgb') + if color not in color_options: + raise ValueError('invalid color option', (color, color_options)) + + x_fn, y_fn = split_options[split] + + x_op = TensorFnDataset(dtype, (False,), (x_fn, (dtype,)), (3072,)) + y_op = TensorFnDataset('int32', (), y_fn) + + x = x_op(s_idx) + y = y_op(s_idx) + + # Y = 0.3R + 0.59G + 0.11B from + # http://gimp-savvy.com/BOOK/index.html?node54.html + rgb_dtype = 'float32' + if dtype == 'float64': + rgb_dtype = dtype + r = numpy.asarray(.3, dtype=rgb_dtype) + g = numpy.asarray(.59, dtype=rgb_dtype) + b = numpy.asarray(.11, dtype=rgb_dtype) -def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000): - dataset='MAR' + if x.ndim == 1: + if rasterized: + if color=='grey': + x = r * x[:1024] + g * x[1024:2048] + b * x[2048:] + if dtype=='uint8': + x = theano.tensor.cast(x, 'uint8') + elif color=='rgb': + # the strides aren't what you'd expect between channels, + # but theano is all about weird strides + x = x.reshape((3,32*32)).T + else: + raise NotImplemented('color', color) + else: + if color=='grey': + x = r * x[:1024] + g * x[1024:2048] + b * x[2048:] + if dtype=='uint8': + x = theano.tensor.cast(x, 'uint8') + x = x.reshape((32,32)) + elif color=='rgb': + # the strides aren't what you'd expect between channels, + # but theano is all about weird strides + x = x.reshape((3,32,32)).dimshuffle(1, 2, 0) + else: + raise NotImplemented('color', color) + elif x.ndim == 2: + N = x.shape[0] # symbolic + if rasterized: + if color=='grey': + x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:] + if dtype=='uint8': + x = theano.tensor.cast(x, 'uint8') + elif color=='rgb': + # the strides aren't what you'd expect between channels, + # but theano is all about weird strides + x = x.reshape((N, 3,32*32)).dimshuffle(0, 2, 1) + else: + raise NotImplemented('color', color) + else: + if color=='grey': + x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:] + if dtype=='uint8': + x = theano.tensor.cast(x, 'uint8') + x.reshape((N, 32, 32)) + elif color=='rgb': + # the strides aren't what you'd expect between channels, + # but theano is all about weird strides + x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1) + else: + raise NotImplemented('color', color) + else: + raise ValueError('x has too many dimensions', x.ndim) + + return x, y + + +def _default_rbm_alloc(n_I, n_K=256, n_J=100): + return mcRBM.alloc(n_I, n_K, n_J) + +def _default_trainer_alloc(rbm, train_batch, batchsize, l1_penalty, l1_penalty_start): + return mcRBMTrainer.alloc(rbm, train_batch, batchsize, l1_penalty=l1_penalty, + l1_penalty_start=l1_penalty_start,persistent_chains=persistent_chains) + + +def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000, + rbm_alloc=_default_rbm_alloc, trainer_alloc=_default_trainer_alloc, + lr_per_example=.075, + l1_penalty=1e-3, + l1_penalty_start=1000, + persistent_chains=True, + ): + + batchsize = 128 + if dataset == 'MAR': n_vis=105 n_patches=10240 + epoch_size=n_patches + elif dataset=='cifar10patches8x8': + R,C= 8,8 # the size of image patches + n_vis=96 # pca components + epoch_size=batchsize*500 + n_patches=epoch_size*20 else: R,C= 16,16 # the size of image patches n_vis=R*C n_patches=100000 - - n_burnin_steps=10000 - - - l1_penalty=1e-3 - no_l1_epochs = 10 - effective_l1_penalty=0.0 - - epoch_size=n_patches - batchsize = 128 - lr = 0.075 / batchsize - s_lr = TT.scalar() - n_K=256 - n_J=100 + epoch_size=n_patches def l2(X): return numpy.sqrt((X**2).sum()) + if dataset == 'MAR': tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010 + elif dataset == 'cifar10patches8x8': + def tile(X, fname): + _img = pylearn.datasets.cifar10.tile_rasterized_examples( + pylearn.preprocessing.pca.pca_whiten_inverse( + pylearn.dataset_ops.cifar10.random_cifar_patches_pca( + n_vis, None, 'float32', n_patches, R, C,), + X), + img_shape=(R,C)) + image_tiling.save_tiled_raster_images(_img, fname) else: def tile(X, fname): _img = image_tiling.tile_raster_images(X, @@ -39,6 +172,10 @@ if dataset == 'MAR': train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize)) + elif dataset == 'cifar10patches8x8': + train_batch = pylearn.dataset_ops.cifar10.cifar10_patches(batch_idx * batchsize + + np.arange(batchsize), 'train', n_patches=n_patches, patch_size=(R,C), + pca_components=n_vis) else: train_batch = pylearn.dataset_ops.image_patches.image_patches( s_idx = (batch_idx * batchsize + np.arange(batchsize)), @@ -51,17 +188,34 @@ if not as_unittest: imgs_fn = function([batch_idx], outputs=train_batch) - trainer = mcRBMTrainer.alloc( - mcRBM.alloc(n_I=n_vis, n_K=n_K, n_J=n_J), + trainer = trainer_alloc( + rbm_alloc(n_I=n_vis), train_batch, - batchsize, l1_penalty=TT.scalar()) + batchsize, + initial_lr_per_example=lr_per_example, + l1_penalty=l1_penalty, + l1_penalty_start=l1_penalty_start, + persistent_chains=persistent_chains) rbm=trainer.rbm - smplr = trainer.sampler + + if persistent_chains: + grads = trainer.contrastive_grads() + learn_fn = function([batch_idx], + outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)], + updates=trainer.cd_updates()) + else: + learn_fn = function([batch_idx], outputs=[], updates=trainer.cd_updates()) - grads = trainer.contrastive_grads() - learn_fn = function([batch_idx, trainer.l1_penalty], - outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)], - updates=trainer.cd_updates()) + if persistent_chains: + smplr = trainer.sampler + else: + smplr = trainer._last_cd1_sampler + + if dataset == 'cifar10patches8x8': + cPickle.dump( + pylearn.dataset_ops.cifar10.random_cifar_patches_pca( + n_vis, None, 'float32', n_patches, R, C,), + open('test_mcRBM.pca.pkl','w')) print "Learning..." last_epoch = -1 @@ -98,14 +252,20 @@ if print_jj: if not as_unittest: tile(imgs_fn(jj), "imgs_%06i.png"%jj) - tile(smplr.positions.value, "sample_%06i.png"%jj) + if persistent_chains: + tile(smplr.positions.value, "sample_%06i.png"%jj) tile(rbm.U.value.T, "U_%06i.png"%jj) tile(rbm.W.value.T, "W_%06i.png"%jj) print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize) print 'l2(U)', l2(rbm.U.value), - print 'l2(W)', l2(rbm.W.value) + print 'l2(W)', l2(rbm.W.value), + print 'l1_penalty', + try: + print trainer.effective_l1_penalty.value + except: + print trainer.effective_l1_penalty print 'U min max', rbm.U.value.min(), rbm.U.value.max(), print 'W min max', rbm.W.value.min(), rbm.W.value.max(), @@ -113,15 +273,16 @@ print 'b min max', rbm.b.value.min(), rbm.b.value.max(), print 'c min max', rbm.c.value.min(), rbm.c.value.max() - print 'parts min', smplr.positions.value.min(), - print 'max',smplr.positions.value.max(), - print 'HMC step', smplr.stepsize, - print 'arate', smplr.avg_acceptance_rate + if persistent_chains: + print 'parts min', smplr.positions.value.min(), + print 'max',smplr.positions.value.max(), + print 'HMC step', smplr.stepsize.value, + print 'arate', smplr.avg_acceptance_rate.value - l2_of_Ugrad = learn_fn(jj, effective_l1_penalty) + l2_of_Ugrad = learn_fn(jj) - if print_jj: + if persistent_chains and print_jj: print 'l2(U_grad)', float(l2_of_Ugrad[0]), print 'l2(U_inc)', float(l2_of_Ugrad[1]), print 'l2(W_inc)', float(l2_of_Ugrad[2]), @@ -131,9 +292,42 @@ #print 'FE+[2]', float(l2_of_Ugrad[5]), #print 'FE+[3]', float(l2_of_Ugrad[6]) - if jj == no_l1_epochs * epoch_size/batchsize: - print "Activating L1 weight decay" - effective_l1_penalty = 1e-3 + if not as_unittest: + if jj % 2000 == 0: + print '' + print 'Saving rbm...' + cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%jj, 'w'), -1) + if persistent_chains: + print 'Saving sampler...' + cPickle.dump(smplr, open('mcRBM.smplr.%06i.pkl'%jj, 'w'), -1) + if not as_unittest: return rbm, smplr + +import pickle as cPickle +#import cPickle +if __name__ == '__main__': + if 0: + #learning 16 x 16 pinwheel filters from official cifar patches (MAR) + rbm,smplr = test_reproduce_ranzato_hinton_2010( + as_unittest=False, + n_train_iters=5000, + rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81), + trainer_alloc=mcRBMTrainer.alloc_for_P, + dataset='MAR' + ) + + if 1: + rbm,smplr = test_reproduce_ranzato_hinton_2010( + as_unittest=False, + n_train_iters=60000, + rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81), + trainer_alloc=mcRBMTrainer.alloc_for_P, + lr_per_example=0.05, + dataset='cifar10patches8x8', + l1_penalty=1e-3, + l1_penalty_start=30000, + #l1_penalty_start=350, #DEBUG + persistent_chains=False + )