diff pylearn/algorithms/tests/test_mcRBM.py @ 1284:1817485d586d

mcRBM - many changes incl. adding support for pooling matrix
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 15 Sep 2010 17:49:21 -0400
parents 7bb5dd98e671
children 8905186b176c
line wrap: on
line diff
--- a/pylearn/algorithms/tests/test_mcRBM.py	Wed Sep 15 17:46:21 2010 -0400
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Wed Sep 15 17:49:21 2010 -0400
@@ -1,33 +1,166 @@
 from pylearn.algorithms.mcRBM import *
+import pylearn.datasets.cifar10
+
+import pylearn.dataset_ops.cifar10
+
+def _mar_train_patches(dtype):  # 16x16 corner patches of the first 40000 rows of cifar10.train_data_labels, centred and projected
+    R,C=16,16
+    train_data = pylearn.dataset_ops.cifar10.train_data_labels(dtype)[0][:40000]
+    #train_data shape is (40000, 3072)
+    train_data = train_data.reshape((40000,3,32,32)).transpose([0,2,3,1])  # -> (40000, 32, 32, 3), channels last
+    patches = train_data[:, :R, :C, :].reshape((40000, 3*R*C))  # top-left RxC corner of every image, flattened
+    patches -= patches.mean(axis=0)  # subtract the per-column mean in place
+    wpatches = numpy.dot(patches, d['pcatransf'].T)  # NOTE(review): `d` is not bound in this function or in the visible hunks -- confirm where 'pcatransf' comes from
+    return wpatches
+
+def mar_centered(s_idx, split, dtype='float64', rasterized=False, color='grey'):
+    """
+    Returns a pair (img, label) of theano expressions for cifar-10 samples
+
+    :param s_idx: the indexes
+
+    :param split: which data split to read: 'train', 'valid' or 'test'
+
+    :param dtype: dtype handed to the image dataset op; greyscale output is cast back to 'uint8' when dtype is 'uint8'
+
+    :param rasterized: return examples as vectors (True) or 32x32 matrices (False)
+
+    :param color: control how to deal with the color in the images
+      - grey   greyscale (with luminance weighting)
+      - rgb    add a trailing dimension of length 3 with rgb colour channels
+    :raises ValueError: on an unknown split or color, or when the indexed images have more than 2 dimensions
+    """
+    # NOTE(review): train_data/valid_data/test_data and the *_labels functions are not defined in the visible hunks -- confirm they are in scope
+    split_options = {'train':(train_data, train_labels),
+            'valid': (valid_data, valid_labels),
+            'test': (test_data, test_labels)}
+
+    if split not in split_options:
+        raise ValueError('invalid split option', (split, split_options.keys()))
+
+    color_options = ('grey', 'rgb')
+    if color not in color_options:
+        raise ValueError('invalid color option', (color, color_options))
+
+    x_fn, y_fn = split_options[split]
+
+    x_op = TensorFnDataset(dtype, (False,), (x_fn, (dtype,)), (3072,))
+    y_op = TensorFnDataset('int32', (), y_fn)
+
+    x = x_op(s_idx)
+    y = y_op(s_idx)
+
+    # Y = 0.3R + 0.59G + 0.11B from
+    # http://gimp-savvy.com/BOOK/index.html?node54.html
+    rgb_dtype = 'float32'
+    if dtype == 'float64':
+        rgb_dtype = dtype
+    r = numpy.asarray(.3, dtype=rgb_dtype)
+    g = numpy.asarray(.59, dtype=rgb_dtype)
+    b = numpy.asarray(.11, dtype=rgb_dtype)
 
-def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000):
-    dataset='MAR'
+    if x.ndim == 1:
+        if rasterized:
+            if color=='grey':
+                x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((3,32*32)).T
+            else:
+                raise NotImplementedError('color', color)
+        else:
+            if color=='grey':
+                x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+                x = x.reshape((32,32))
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((3,32,32)).dimshuffle(1, 2, 0)
+            else:
+                raise NotImplementedError('color', color)
+    elif x.ndim == 2:
+        N = x.shape[0] # symbolic
+        if rasterized:
+            if color=='grey':
+                x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((N, 3,32*32)).dimshuffle(0, 2, 1)
+            else:
+                raise NotImplementedError('color', color)
+        else:
+            if color=='grey':
+                x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+                x = x.reshape((N, 32, 32))  # reshape returns a new expression; it must be rebound to take effect
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1)
+            else:
+                raise NotImplementedError('color', color)
+    else:
+        raise ValueError('x has too many dimensions', x.ndim)
+
+    return x, y
+
+
+def _default_rbm_alloc(n_I, n_K=256, n_J=100):  # default `rbm_alloc` hook of test_reproduce_ranzato_hinton_2010
+    return mcRBM.alloc(n_I, n_K, n_J)  # defaults match the n_K=256, n_J=100 the test used to hard-code
+
+def _default_trainer_alloc(rbm, train_batch, batchsize, l1_penalty, l1_penalty_start, **alloc_kwargs):
+    # Default `trainer_alloc` hook: the test also passes initial_lr_per_example= and persistent_chains=, so extra keywords are forwarded. NOTE(review): assumes mcRBMTrainer.alloc accepts both -- confirm.
+    return mcRBMTrainer.alloc(rbm, train_batch, batchsize, l1_penalty=l1_penalty, l1_penalty_start=l1_penalty_start, **alloc_kwargs)
+
+
+def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000,
+        rbm_alloc=_default_rbm_alloc, trainer_alloc=_default_trainer_alloc,
+        lr_per_example=.075,
+        l1_penalty=1e-3,
+        l1_penalty_start=1000,
+        persistent_chains=True,
+        ):
+
+    batchsize = 128
+
     if dataset == 'MAR':
         n_vis=105
         n_patches=10240
+        epoch_size=n_patches
+    elif dataset=='cifar10patches8x8':
+        R,C= 8,8 # the size of image patches
+        n_vis=96 # pca components
+        epoch_size=batchsize*500
+        n_patches=epoch_size*20
     else:
         R,C= 16,16 # the size of image patches
         n_vis=R*C
         n_patches=100000
-
-    n_burnin_steps=10000
-
-
-    l1_penalty=1e-3
-    no_l1_epochs = 10
-    effective_l1_penalty=0.0
-
-    epoch_size=n_patches
-    batchsize = 128
-    lr = 0.075 / batchsize
-    s_lr = TT.scalar()
-    n_K=256
-    n_J=100
+        epoch_size=n_patches
 
     def l2(X):
         return numpy.sqrt((X**2).sum())
+
     if dataset == 'MAR':
         tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
+    elif dataset == 'cifar10patches8x8':
+        def tile(X, fname):
+            _img = pylearn.datasets.cifar10.tile_rasterized_examples(
+                    pylearn.preprocessing.pca.pca_whiten_inverse(
+                        pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
+                            n_vis, None, 'float32', n_patches, R, C,),
+                        X),
+                    img_shape=(R,C))
+            image_tiling.save_tiled_raster_images(_img, fname)
     else:
         def tile(X, fname):
             _img = image_tiling.tile_raster_images(X,
@@ -39,6 +172,10 @@
 
     if dataset == 'MAR':
         train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
+    elif dataset == 'cifar10patches8x8':
+        train_batch = pylearn.dataset_ops.cifar10.cifar10_patches(batch_idx * batchsize +
+                np.arange(batchsize), 'train', n_patches=n_patches, patch_size=(R,C),
+                pca_components=n_vis)
     else:
         train_batch = pylearn.dataset_ops.image_patches.image_patches(
                 s_idx = (batch_idx * batchsize + np.arange(batchsize)),
@@ -51,17 +188,34 @@
     if not as_unittest:
         imgs_fn = function([batch_idx], outputs=train_batch)
 
-    trainer = mcRBMTrainer.alloc(
-            mcRBM.alloc(n_I=n_vis, n_K=n_K, n_J=n_J),
+    trainer = trainer_alloc(
+            rbm_alloc(n_I=n_vis),
             train_batch,
-            batchsize, l1_penalty=TT.scalar())
+            batchsize, 
+            initial_lr_per_example=lr_per_example,
+            l1_penalty=l1_penalty,
+            l1_penalty_start=l1_penalty_start,
+            persistent_chains=persistent_chains)
     rbm=trainer.rbm
-    smplr = trainer.sampler
+
+    if persistent_chains:
+        grads = trainer.contrastive_grads()
+        learn_fn = function([batch_idx], 
+                outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
+                updates=trainer.cd_updates())
+    else:
+        learn_fn = function([batch_idx], outputs=[], updates=trainer.cd_updates())
 
-    grads = trainer.contrastive_grads()
-    learn_fn = function([batch_idx, trainer.l1_penalty], 
-            outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
-            updates=trainer.cd_updates())
+    if persistent_chains:
+        smplr = trainer.sampler
+    else:
+        smplr = trainer._last_cd1_sampler
+
+    if dataset == 'cifar10patches8x8':
+        cPickle.dump(
+                pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
+                    n_vis, None, 'float32', n_patches, R, C,),
+                open('test_mcRBM.pca.pkl','w'))
 
     print "Learning..."
     last_epoch = -1
@@ -98,14 +252,20 @@
         if print_jj:
             if not as_unittest:
                 tile(imgs_fn(jj), "imgs_%06i.png"%jj)
-                tile(smplr.positions.value, "sample_%06i.png"%jj)
+                if persistent_chains:
+                    tile(smplr.positions.value, "sample_%06i.png"%jj)
                 tile(rbm.U.value.T, "U_%06i.png"%jj)
                 tile(rbm.W.value.T, "W_%06i.png"%jj)
 
             print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)
 
             print 'l2(U)', l2(rbm.U.value),
-            print 'l2(W)', l2(rbm.W.value)
+            print 'l2(W)', l2(rbm.W.value),
+            print 'l1_penalty', 
+            try:
+                print trainer.effective_l1_penalty.value
+            except:
+                print trainer.effective_l1_penalty
 
             print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
             print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
@@ -113,15 +273,16 @@
             print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
             print 'c min max', rbm.c.value.min(), rbm.c.value.max()
 
-            print 'parts min', smplr.positions.value.min(), 
-            print 'max',smplr.positions.value.max(),
-            print 'HMC step', smplr.stepsize,
-            print 'arate', smplr.avg_acceptance_rate
+            if persistent_chains:
+                print 'parts min', smplr.positions.value.min(), 
+                print 'max',smplr.positions.value.max(),
+            print 'HMC step', smplr.stepsize.value,
+            print 'arate', smplr.avg_acceptance_rate.value
 
 
-        l2_of_Ugrad = learn_fn(jj, effective_l1_penalty)
+        l2_of_Ugrad = learn_fn(jj)
 
-        if print_jj:
+        if persistent_chains and print_jj:
             print 'l2(U_grad)', float(l2_of_Ugrad[0]),
             print 'l2(U_inc)', float(l2_of_Ugrad[1]),
             print 'l2(W_inc)', float(l2_of_Ugrad[2]),
@@ -131,9 +292,42 @@
             #print 'FE+[2]', float(l2_of_Ugrad[5]),
             #print 'FE+[3]', float(l2_of_Ugrad[6])
 
-        if jj == no_l1_epochs * epoch_size/batchsize:
-            print "Activating L1 weight decay"
-            effective_l1_penalty = 1e-3
+        if not as_unittest:
+            if jj % 2000 == 0:
+                print ''
+                print 'Saving rbm...'
+                cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%jj, 'w'), -1)
+                if persistent_chains:
+                    print 'Saving sampler...'
+                    cPickle.dump(smplr, open('mcRBM.smplr.%06i.pkl'%jj, 'w'), -1)
+
 
     if not as_unittest:
         return rbm, smplr
+
+import pickle as cPickle
+#import cPickle
+if __name__ == '__main__':  # script entry point: runs the test function as a training experiment (as_unittest=False)
+    if 0:  # disabled configuration; change to 1 to run it
+        #learning 16 x 16 pinwheel filters from official cifar patches (MAR)
+        rbm,smplr = test_reproduce_ranzato_hinton_2010(
+                as_unittest=False,
+                n_train_iters=5000,
+                rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81),
+                trainer_alloc=mcRBMTrainer.alloc_for_P,
+                dataset='MAR'
+                )
+
+    if 1:  # active configuration: 60000 iterations on 'cifar10patches8x8' with persistent_chains=False
+        rbm,smplr = test_reproduce_ranzato_hinton_2010(
+                as_unittest=False,
+                n_train_iters=60000,
+                rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81),
+                trainer_alloc=mcRBMTrainer.alloc_for_P,
+                lr_per_example=0.05,
+                dataset='cifar10patches8x8',
+                l1_penalty=1e-3,
+                l1_penalty_start=30000,
+                #l1_penalty_start=350, #DEBUG
+                persistent_chains=False
+                )