# view of pylearn/algorithms/tests/test_mcRBM.py @ 1334:6fd2610c1706 (merge)
# author: James Bergstra <bergstrj@iro.umontreal.ca>
# date:   Mon, 18 Oct 2010 14:58:52 -0400
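"""Training and test script for the mean-covariance RBM (mcRBM) of Ranzato & Hinton (2010).

test_reproduce_ranzato_hinton_2010 trains an mcRBM on image patches and, in unittest
mode, checks a few weight statistics against previously recorded values.
run_classif_experiment pretrains an mcRBM on TinyImages patches, extracts
convolutional features from CIFAR-10, and trains a logistic regression on them.
"""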

import sys

import numpy
import numpy as np
import theano
import theano.tensor as TT
from theano import tensor, function, shared
import pickle as cPickle
from pylearn.algorithms.mcRBM import *
import pylearn.datasets.cifar10
import pylearn.dataset_ops.cifar10
import pylearn.dataset_ops.image_patches
import pylearn.dataset_ops.tinyimages
import pylearn.gd.sgd
import pylearn.preprocessing.pca
from pylearn.shared.layers.logreg import LogisticRegression


def _default_rbm_alloc(n_I, n_K=256, n_J=100):
    return mcRBM.alloc(n_I, n_K, n_J)

def _default_trainer_alloc(rbm, train_batch, batchsize, initial_lr_per_example,
        l1_penalty, l1_penalty_start, persistent_chains):
    # keyword interface mirrors the call in test_reproduce_ranzato_hinton_2010 below
    return mcRBMTrainer.alloc(rbm, train_batch, batchsize,
            initial_lr_per_example=initial_lr_per_example, l1_penalty=l1_penalty,
            l1_penalty_start=l1_penalty_start, persistent_chains=persistent_chains)

def l2(X):
    return numpy.sqrt((X ** 2).sum())


def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000,
        rbm_alloc=_default_rbm_alloc, trainer_alloc=_default_trainer_alloc,
        lr_per_example=.075,
        l1_penalty=1e-3,
        l1_penalty_start=1000,
        persistent_chains=True,
        ):
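    """Train an mcRBM on image patches, roughly following Ranzato & Hinton (2010).

    With as_unittest=True the run stops after 5 epochs and asserts that a few weight
    statistics match reference values (apparently recorded from an earlier run of this
    script); otherwise filters and samples are periodically saved as images and the
    trained rbm and sampler are returned.
    """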

    batchsize = 128

    if dataset == 'MAR':
        n_vis=105
        n_patches=10240
        epoch_size=n_patches
    elif dataset=='cifar10patches8x8':
        R,C= 8,8 # the size of image patches
        n_vis=96 # pca components
        epoch_size=batchsize*500
        n_patches=epoch_size*20
    elif dataset=='tinyimages_patches':
        R,C=8,8
        n_vis=81
        epoch_size=batchsize*500
        n_patches=epoch_size*20
    else:
        R,C= 16,16 # the size of image patches
        n_vis=R*C
        n_patches=100000
        epoch_size=n_patches

    if dataset == 'MAR':
        tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
    elif dataset == 'cifar10patches8x8':
        def tile(X, fname):
            _img = pylearn.datasets.cifar10.tile_rasterized_examples(
                    pylearn.preprocessing.pca.pca_whiten_inverse(
                        pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
                            n_vis, None, 'float32', n_patches, R, C,),
                        X),
                    img_shape=(R,C))
            image_tiling.save_tiled_raster_images(_img, fname)
    elif dataset == 'tinyimages_patches':
        tile = pylearn.dataset_ops.tinyimages.save_filters
    else:
        def tile(X, fname):
            _img = image_tiling.tile_raster_images(X,
                    img_shape=(R,C),
                    min_dynamic_range=1e-2)
            image_tiling.save_tiled_raster_images(_img, fname)

    batch_idx = TT.iscalar()
    batch_range = batch_idx * batchsize + np.arange(batchsize)
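    # symbolic expression for one minibatch of training examples; the dataset_op
    # used depends on the `dataset` argument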

    if dataset == 'MAR':
        train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_range)
    elif dataset == 'cifar10patches8x8':
        train_batch = pylearn.dataset_ops.cifar10.cifar10_patches(
                batch_range, 'train', n_patches=n_patches, patch_size=(R,C),
                pca_components=n_vis)
    elif dataset == 'tinyimages_patches':
        train_batch = pylearn.dataset_ops.tinyimages.tinydataset_op(batch_range)
    else:
        train_batch = pylearn.dataset_ops.image_patches.image_patches(
                s_idx = (batch_idx * batchsize + np.arange(batchsize)),
                dims = (n_patches,R,C),
                center=True,
                unitvar=True,
                dtype=floatX,
                rasterized=True)

    if not as_unittest:
        imgs_fn = function([batch_idx], outputs=train_batch)

    trainer = trainer_alloc(
            rbm_alloc(n_I=n_vis),
            train_batch,
            batchsize, 
            initial_lr_per_example=lr_per_example,
            l1_penalty=l1_penalty,
            l1_penalty_start=l1_penalty_start,
            persistent_chains=persistent_chains)
    rbm=trainer.rbm
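    # Build the per-minibatch learning function.  With persistent chains it also
    # returns gradient norms for monitoring (note that the first two outputs are both
    # grads[0].norm(2), so the 'l2(U_inc)' print below repeats 'l2(U_grad)'); without
    # persistent chains it just applies the CD-1 updates.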

    if persistent_chains:
        grads = trainer.contrastive_grads()
        learn_fn = function([batch_idx], 
                outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
                updates=trainer.cd_updates())
    else:
        learn_fn = function([batch_idx], outputs=[], updates=trainer.cd_updates())

    if persistent_chains:
        smplr = trainer.sampler
    else:
        smplr = trainer._last_cd1_sampler

    if dataset == 'cifar10patches8x8':
        cPickle.dump(
                pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
                    n_vis, None, 'float32', n_patches, R, C,),
                open('test_mcRBM.pca.pkl','w'))
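    # Main training loop: one learn_fn call per minibatch, with monitoring printed at
    # each new epoch.  In unittest mode the loop stops at epoch 5 and compares weight
    # statistics against the hard-coded reference values below.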

    print "Learning..."
    last_epoch = -1
    for jj in xrange(n_train_iters):
        epoch = jj*batchsize / epoch_size

        print_jj = epoch != last_epoch
        last_epoch = epoch

        if as_unittest and epoch == 5:
            U = rbm.U.value
            W = rbm.W.value
            def allclose(a,b):
                return numpy.allclose(a,b,rtol=1.01,atol=1e-3)
            print ""
            print "--------------"
            print "assert allclose(l2(U), %f)"%l2(U)
            print "assert allclose(l2(W), %f)"%l2(W)
            print "assert allclose(U.min(), %f)"%U.min()
            print "assert allclose(U.max(), %f)"%U.max()
            print "assert allclose(W.min(),%f)"%W.min()
            print "assert allclose(W.max(), %f)"%W.max()
            print "--------------"

            assert allclose(l2(U), 21.351664)
            assert allclose(l2(W), 6.275828)
            assert allclose(U.min(), -1.176703)
            assert allclose(U.max(), 0.859802)
            assert allclose(W.min(),-0.223128)
            assert allclose(W.max(), 0.227558 )

            break

        if print_jj:
            if not as_unittest:
                tile(imgs_fn(jj), "imgs_%06i.png"%jj)
                if persistent_chains:
                    tile(smplr.positions.value, "sample_%06i.png"%jj)
                tile(rbm.U.value.T, "U_%06i.png"%jj)
                tile(rbm.W.value.T, "W_%06i.png"%jj)

            print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)

            print 'l2(U)', l2(rbm.U.value),
            print 'l2(W)', l2(rbm.W.value),
            print 'l1_penalty', 
            try:
                print trainer.effective_l1_penalty.value
            except AttributeError:
                # effective_l1_penalty may be a plain float rather than a shared variable
                print trainer.effective_l1_penalty

            print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
            print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
            print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
            print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
            print 'c min max', rbm.c.value.min(), rbm.c.value.max()

            if persistent_chains:
                print 'parts min', smplr.positions.value.min(), 
                print 'max',smplr.positions.value.max(),
            print 'HMC step', smplr.stepsize.value,
            print 'arate', smplr.avg_acceptance_rate.value


        l2_of_Ugrad = learn_fn(jj)

        if persistent_chains and print_jj:
            print 'l2(U_grad)', float(l2_of_Ugrad[0]),
            print 'l2(U_inc)', float(l2_of_Ugrad[1]),
            print 'l2(W_inc)', float(l2_of_Ugrad[2]),
            #print 'FE+', float(l2_of_Ugrad[2]),
            #print 'FE+[0]', float(l2_of_Ugrad[3]),
            #print 'FE+[1]', float(l2_of_Ugrad[4]),
            #print 'FE+[2]', float(l2_of_Ugrad[5]),
            #print 'FE+[3]', float(l2_of_Ugrad[6])

        if not as_unittest:
            if jj % 2000 == 0:
                print ''
                print 'Saving rbm...'
                cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%jj, 'w'), -1)
                if persistent_chains:
                    print 'Saving sampler...'
                    cPickle.dump(smplr, open('mcRBM.smplr.%06i.pkl'%jj, 'w'), -1)


    if not as_unittest:
        return rbm, smplr


def run_classif_experiment(checkpoint):
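    """Pretrain an mcRBM on TinyImages patches, then classify CIFAR-10 from its features.

    Rough pipeline: (1) load or train an mcRBM (with topological P) on PCA-whitened
    8x8 TinyImages patches; (2) load or compute features by applying the RBM to 8x8
    windows (stride 4) of every CIFAR-10 image; (3) train an L1-regularized logistic
    regression on those features.  The `checkpoint` argument is accepted but not
    currently used.
    """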

    R,C=8,8
    n_vis=74
    # PRETRAIN
    #
    # extract 1 million 8x8 patches from TinyImages
    # pre-process them the right way
    # find 74 dims of PCA
    # filter patches through PCA
    whitened_patches, pca_dct = pylearn.dataset_ops.tinyimages.main(n_imgs=100000,
            max_components=n_vis, seed=234)
    #
    # Set up mcRBM Trainer
    # Initialize P using topological 3x3 overlapping patches thing
    # start learning P matrix after 2 passes through dataset
    # 
    rbm_filename = 'mcRBM.rbm.%06i.pkl'%46000
    try:
        open(rbm_filename).close()
        load_mcrbm = True
    except IOError:
        load_mcrbm = False

    if load_mcrbm:
        print 'loading mcRBM from file', rbm_filename
        rbm = cPickle.load(open(rbm_filename))

    else:
        print "Training mcRBM"
        batchsize=128
        epoch_size=len(whitened_patches)
        tile = pylearn.dataset_ops.tinyimages.save_filters
        train_batch = theano.tensor.matrix()
        trainer = mcRBMTrainer.alloc_for_P(
                rbm=mcRBM_withP.alloc_topo_P(n_I=n_vis, n_J=81),
                visible_batch=train_batch,
                batchsize=batchsize, 
                initial_lr_per_example=0.05,
                l1_penalty=1e-3,
                l1_penalty_start=sys.maxint,
                p_training_start=2*epoch_size//batchsize,
                persistent_chains=False)
        rbm=trainer.rbm
        learn_fn = function([train_batch], outputs=[], updates=trainer.cd_updates())
        smplr = trainer._last_cd1_sampler

        ii = 0
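        # Six passes over the whitened patches; every 1000 minibatches the filters are
        # saved as images and the rbm is pickled.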
        for i_epoch in range(6):
            for i_batch in xrange(epoch_size // batchsize):
                batch_vals = whitened_patches[i_batch*batchsize:(i_batch+1)*batchsize]
                learn_fn(batch_vals)

                if (ii % 1000) == 0:
                    #tile(imgs_fn(ii), "imgs_%06i.png"%ii)
                    tile(rbm.U.value.T, "U_%06i.png"%ii)
                    tile(rbm.W.value.T, "W_%06i.png"%ii)

                    print 'saving samples', ii, 'epoch', i_epoch, i_batch

                    print 'l2(U)', l2(rbm.U.value),
                    print 'l2(W)', l2(rbm.W.value),
                    print 'l1_penalty', 
                    try:
                        print trainer.effective_l1_penalty.value
                    except AttributeError:
                        # may be a plain float rather than a shared variable
                        print trainer.effective_l1_penalty

                    print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
                    print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
                    print 'a min max', rbm.a.value.min(), rbm.a.value.max(),
                    print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
                    print 'c min max', rbm.c.value.min(), rbm.c.value.max()

                    print 'HMC step', smplr.stepsize.value,
                    print 'arate', smplr.avg_acceptance_rate.value
                    print 'P min max', rbm.P.value.min(), rbm.P.value.max(),
                    print 'P_lr', trainer.p_lr.value
                    print ''
                    print 'Saving rbm...'
                    cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%ii, 'w'), -1)

                ii += 1


    # extract convolutional features from the CIFAR10 data
    #feat_filename = 'mcrbm_features.npy'    # unused alternative
    feat_filename = 'cifar10.features.46000.npy'
    try:
        open(feat_filename).close()
        load_features = True
    except IOError:
        load_features = False

    if load_features:
        print 'Loading features from', feat_filename
        all_features = numpy.load(feat_filename, mmap_mode='r')
    else:
        batchsize=100
        feat_idx = tensor.lscalar()
        feat_idx_range = feat_idx * batchsize + tensor.arange(batchsize)
        train_batch_x, train_batch_y = pylearn.dataset_ops.cifar10.cifar10(
                feat_idx_range, 
                split='all', 
                dtype='uint8', 
                rasterized=False,
                color='rgb')

        WINDOW_SIZE=8
        WINDOW_STRIDE=4

        # put these into shared vars because support for big matrix constants is bad,
        # (comparing them is slow)
        pca_eigvecs = shared(pca_dct['eig_vecs'].astype('float32'))
        pca_eigvals = shared(pca_dct['eig_vals'].astype('float32'))
        pca_mean    = shared(pca_dct['mean'].astype('float32'))

        def theano_pca_whiten(X):
            # mirrors pylearn.preprocessing.pca.pca_whiten
            return tensor.true_div(
                tensor.dot(X-pca_mean, pca_eigvecs),
                tensor.sqrt(pca_eigvals)+1e-8)
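        # Slide an 8x8 window with stride 4 over each 32x32 image (7x7 = 49 positions).
        # Each window is mean-centered, PCA-whitened, and pushed through the RBM to get
        # its expected h and g hidden-unit activations.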

        h_list = []
        g_list = []
        for r_offset in range(0, 32-WINDOW_SIZE+1, WINDOW_STRIDE):
            for c_offset in range(0, 32-WINDOW_SIZE+1, WINDOW_STRIDE):
                window = train_batch_x[:, r_offset:r_offset+WINDOW_SIZE,
                        c_offset:c_offset+WINDOW_SIZE, :]
                assert window.dtype=='uint8'

                #rasterize the patches
                raster_window = tensor.flatten(tensor.cast(window, 'float32'),2)

                #subtract off the mean of each image
                raster_window = raster_window - raster_window.mean(axis=1).reshape((batchsize,1))

                h,g = rbm.expected_h_g_given_v(theano_pca_whiten(raster_window))

                h_list.append(h)
                g_list.append(g)

        hg = tensor.concatenate(h_list + g_list, axis=1)
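        # The h and g activations of all 49 windows are concatenated into one feature
        # vector per image; with this model that comes to 11025 features, matching the
        # array allocated below and the logistic regression's n_in.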

        feat_fn = function([feat_idx], hg)
        features = numpy.empty((60000, 11025), dtype='float32')
        for i in xrange(60000//batchsize):
            if i % 100 == 0:
                print("feature batch %i"%i)
            features[i*batchsize:(i+1)*batchsize] = feat_fn(i)

        print("saving features to %s"%feat_filename)
        numpy.save(feat_filename, features)
        all_features = features
        del features


    # CLASSIFY FEATURES

    if 0:
        # nothing to load
        pass
    else:
        batchsize=100

        if feat_filename.startswith('cifar'):
            learnrate = 0.002
            l1_regularization = 0.004
            anneal_epoch=100
            n_epochs = 500
        else:
            learnrate = 0.005
            l1_regularization = 0.004
            n_epochs = 100
            anneal_epoch=20

        x_i = tensor.matrix()
        y_i = tensor.ivector()
        lr = tensor.scalar()
        #l1_regularization = float(sys.argv[1]) #1.e-3
        #l2_regularization = float(sys.argv[2]) #1.e-3*0

        feature_logreg = LogisticRegression.new(x_i, 
                n_in = 11025, n_out=10,
                dtype=x_i.dtype)

        # marc'aurelio does this...
        feature_logreg.w.value = numpy.random.RandomState(44).randn(11025,10)*.02

        traincost = feature_logreg.nll(y_i).sum()
        traincost = traincost + abs(feature_logreg.w).sum() * l1_regularization
        #traincost = traincost + (feature_logreg.w**2).sum() * l2_regularization
        train_logreg_fn = function([x_i, y_i, lr], 
                [feature_logreg.nll(y_i).mean(),
                    feature_logreg.errors(y_i).mean()],
                updates=pylearn.gd.sgd.sgd_updates(
                    params=feature_logreg.params,
                    grads=tensor.grad(traincost, feature_logreg.params),
                    stepsizes=[lr,lr/10.]))

        all_labels = pylearn.dataset_ops.cifar10.all_data_labels('uint8')[1]
        pylearn.dataset_ops.cifar10.all_data_labels.forget() # clear memo cache
        assert len(all_labels)==60000
        if 0:
            print "Using validation set"
            train_labels = all_labels[:40000]
            valid_labels = all_labels[40000:50000]
            test_labels = all_labels[50000:60000]
            train_features = all_features[:40000]
            valid_features = all_features[40000:50000]
            test_features = all_features[50000:60000]
        else:
            print "NOT USING validation set"
            train_labels = all_labels[:50000]
            valid_labels = None
            test_labels = all_labels[50000:60000]
            train_features = all_features[:50000]
            valid_features = None
            test_features = all_features[50000:60000]

        if 1:
            print "Computing mean and std.dev"
            train_mean = train_features.mean(axis=0)
            train_std = train_features.std(axis=0)+1e-4
            preproc = lambda x: (x-train_mean)/(0.1+train_std)
        else:
            print "Not centering data"
            preproc = lambda x:x

        for epoch in xrange(n_epochs):
            print 'epoch', epoch
            # validate
            # Marc'Aurelio, you crazy!!
            # the division by batchsize is done in the cost function
            e_lr = learnrate / (batchsize*max(1.0, numpy.floor(max(1., epoch/float(anneal_epoch))-2)))
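            # the divisor stays at 1 until epoch reaches roughly 4*anneal_epoch, then
            # grows by one per anneal_epoch, gradually annealing the step size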

            if valid_features is not None:
                l01s = []
                nlls = []
                for i in xrange(10000/batchsize):
                    x_i = valid_features[i*batchsize:(i+1)*batchsize]
                    y_i = valid_labels[i*batchsize:(i+1)*batchsize]

                    #lr=0.0 -> no learning, safe for validation set
                    nll, l01 = train_logreg_fn(preproc(x_i), y_i, 0.0) 
                    nlls.append(nll)
                    l01s.append(l01)
                print 'validate log_reg', numpy.mean(nlls), numpy.mean(l01s)

            # test

            l01s = []
            nlls = []
            for i in xrange(len(test_features)//batchsize):
                x_i = test_features[i*batchsize:(i+1)*batchsize]
                y_i = test_labels[i*batchsize:(i+1)*batchsize]

                #lr=0.0 -> no learning, safe for validation set
                nll, l01 = train_logreg_fn(preproc(x_i), y_i, 0.0) 
                nlls.append(nll)
                l01s.append(l01)
            print 'test log_reg', numpy.mean(nlls), numpy.mean(l01s)

            #train
            l01s = []
            nlls = []
            for i in xrange(len(train_features)//batchsize):
                x_i = train_features[i*batchsize:(i+1)*batchsize]
                y_i = train_labels[i*batchsize:(i+1)*batchsize]
                nll, l01 = train_logreg_fn(preproc(x_i), y_i, e_lr)
                nlls.append(nll)
                l01s.append(l01)
            print 'train log_reg', numpy.mean(nlls), numpy.mean(l01s)




if __name__ == '__main__':
    if 0: 
        #learning 16 x 16 pinwheel filters from official cifar patches (MAR)
        rbm,smplr = test_reproduce_ranzato_hinton_2010(
                as_unittest=False,
                n_train_iters=5000,
                rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81),
                trainer_alloc=mcRBMTrainer.alloc_for_P,
                dataset='MAR'
                )

    if 0:
        # pretraining settings
        rbm,smplr = test_reproduce_ranzato_hinton_2010(
                as_unittest=False,
                n_train_iters=60000,
                rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81),
                trainer_alloc=mcRBMTrainer.alloc_for_P,
                lr_per_example=0.05,
                dataset='tinyimages_patches',
                l1_penalty=1e-3,
                l1_penalty_start=30000,
                #l1_penalty_start=350, #DEBUG
                persistent_chains=False
                )

    if 1:
        def checkpoint():
            # placeholder: run_classif_experiment accepts it but does not currently use it
            return checkpoint
        run_classif_experiment(checkpoint=checkpoint)



if 0: # TEST IDEA OUT HERE


    class doc_db(dict):
        """A key->document dictionary.

        A "document" is itself a dictionary.

        A "document" can be a small or large object, but it cannot be partially
        retrieved.

        This simple data structure is used in pylearn to cache intermediate results
        between several process invocations.
        """

    class UNSPECIFIED(object): pass

    class CtrlObj(object):

        def get(self, key, default_val=UNSPECIFIED, copy=True):
            # Default to return a COPY because a set() is required to make a change persistent.
            # Inplace changes that the CtrlObj does not know about (via set) will not be saved.
            pass

        def get_key(self, val):
            """Return the key that retrieved `val`.
            
            This is useful for specifying cache keys for unhashable (e.g. numpy) objects that
            happen to be stored in the db.
            """
            # if 
            # lookup whether val is an obj
            pass
        def set(self, key, val):
            pass
        def delete(self, key):
            pass
        def checkpoint(self):
            pass

        @staticmethod
        def cache_pickle(pass_ctrl=False):
            def decorator(f):
                # cache rval using pickle mechanism
                def rval(*args, **kwargs):
                    pass
                return rval
            return decorator

        @staticmethod
        def cache_dict(pass_ctrl=False):
            def decorator(f):
                # cache rval dict directly
                def rval(*args, **kwargs):
                    pass
                return rval
            return decorator

        @staticmethod
        def cache_numpy(pass_ctrl=False, memmap_thresh=100*1000*1000):
            def decorator(f):
                # cache rval ndarray directly (presumably memory-mapped above memmap_thresh)
                def rval(*args, **kwargs):
                    pass
                return rval
            return decorator

    @CtrlObj.cache_numpy()
    def get_whitened_dataset(pca_parameters):
        # do computations
        return None

    @CtrlObj.cache_pickle(pass_ctrl=True)
    def train_mcRBM(data, lr, n_hid, ctrl):

        rbm = 45
        for i in xrange(10000):
            # do some training
            rbm += 1
            ctrl.checkpoint()
        return rbm

    def run_experiment(args):

        ctrl = CtrlObj.factory(args)
        # Could use db, or filesystem, or both, etc.
        # There would be generic ones, but the experimenter should be very aware of what is being
        # cached where, when, and how.  This is how results are stored and retrieved after all.
        # Cluster-friendly jobs should not use local files directly, but should store cached
        # computations and results to such a database.
        #  Different jobs should avoid using the same keys in the database because coordinating
        #  writes is difficult, and conflicts will inevitably arise.

        raw_data = get_raw_data(ctrl=ctrl)
        raw_data_key = ctrl.get_key(raw_data)
        pca = get_pca(raw_data, max_energy=.05, ctrl=ctrl, 
                _ctrl_raw_data_key=raw_data_key)
        whitened_data = get_whitened_dataset(pca, ctrl=ctrl,
                _ctrl_data_key=raw_data_key)

        rbm = train_mcRBM(
                data=whitened_data,
                lr=0.01,
                n_hid=100,
                ctrl=ctrl,
                _ctrl_data_key=raw_data_key
                )