changeset 1284:1817485d586d

mcRBM - many changes incl. adding support for pooling matrix
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 15 Sep 2010 17:49:21 -0400
parents a73db8d65abb
children 976539956475
files pylearn/algorithms/mcRBM.py pylearn/algorithms/tests/test_mcRBM.py
diffstat 2 files changed, 499 insertions(+), 85 deletions(-)
--- a/pylearn/algorithms/mcRBM.py	Wed Sep 15 17:46:21 2010 -0400
+++ b/pylearn/algorithms/mcRBM.py	Wed Sep 15 17:49:21 2010 -0400
@@ -45,9 +45,9 @@
         - \sum_j c_j g_j
 
 For the energy function to correspond to a probability distribution, P must be non-positive.  P
-is initialized to be a diagonal, and in our experience it can be left as such because even in
-the paper it has a very low learning rate, and is only allowed to be updated after the filters
-in U are learned (in effect).
+is initialized to a (negated) diagonal or topological pooling matrix, and in our experience
+it can be left as such: even in the paper it has a very low learning rate, and it is only
+allowed to be updated after the filters in U have (in effect) been learned.
 
 Version in published train_mcRBM code
 -------------------------------------
@@ -90,6 +90,13 @@
         - \sum_j \sum_i W_{ij} g_j v_i
         - \sum_j c_j g_j
 
+    E (v, h, g) =
+        - 0.5 \sum_f \sum_k P_{fk} h_k (\sum_i U_{if} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
+        - \sum_k b_k h_k
+        + 0.5 \sum_i v_i^2
+        - \sum_j \sum_i W_{ij} g_j v_i
+        - \sum_j c_j g_j
+
       
 
 Conventions in this file
@@ -101,12 +108,14 @@
 
 Global functions like `free_energy` work on an mcRBM as parametrized in a particular way.
 Suppose we have 
-I input dimensions, 
-F squared filters, 
-J mean variables, and 
-K covariance variables.
-The mcRBM is parametrized by 5 variables:
+ - I input dimensions, 
+ - F squared filters, 
+ - J mean variables, and
+ - K covariance variables.
 
+The mcRBM is parametrized by 6 variables:
+
+ - `P`, a matrix whose rows indicate covariance filter groups (F x K)
  - `U`, a matrix whose rows are visible covariance directions (I x F)
  - `W`, a matrix whose rows are visible mean directions (I x J)
  - `b`, a vector of hidden covariance biases (K)
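
As a quick sanity check on the formulas above, here is a small NumPy sketch (an editorial
illustration, not part of this changeset; the function name is made up) that evaluates the
pooled energy E(v, h, g) for a single visible vector, using the I/F/J/K shapes listed here:

    import numpy

    def mcrbm_energy_with_P(v, h, g, U, W, P, b, c):
        """Evaluate the pooled E(v, h, g) for one example (illustration only).

        v:(I,)  h:(K,)  g:(J,)  U:(I,F)  W:(I,J)  P:(F,K) non-positive  b:(K,)  c:(J,)
        """
        unit_v = v / numpy.sqrt(numpy.mean(v ** 2) + 0.5)   # the 'adjusted' visible vector
        filt = numpy.dot(unit_v, U) ** 2                    # (F,) squared filter responses
        E = -0.5 * numpy.dot(numpy.dot(filt, P), h)         # covariance term, pooled by P
        E -= numpy.dot(b, h)                                # covariance hidden biases
        E += 0.5 * numpy.sum(v ** 2)                        # quadratic containment of v
        E -= numpy.dot(numpy.dot(v, W), g)                  # mean term
        E -= numpy.dot(c, g)                                # mean hidden biases
        return E
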
@@ -202,8 +211,6 @@
 sharedX = lambda X, name : shared(numpy.asarray(X, dtype=floatX), name=name)
 
 import pylearn
-#TODO: clean up the HMC_sampler code
-#TODO: think of naming convention for acronyms + suffix?
 from pylearn.sampling.hmc import HMC_sampler
 from pylearn.io import image_tiling
 from pylearn.gd.sgd import sgd_updates
@@ -357,7 +364,7 @@
         
         """
         h = TT.nnet.sigmoid(self.hidden_cov_units_preactivation_given_v(v))
-        g = nnet.sigmoid(self.c + dot(v,self.W))
+        g = TT.nnet.sigmoid(self.c + dot(v,self.W))
         return (h, g)
 
     def n_visible_units(self):
@@ -372,6 +379,51 @@
         except AttributeError:
             return self.W.shape[0]
 
+    def n_hidden_cov_units(self):
+        """Return the number of hidden units for the covariance in this RBM
+
+        For an RBM made from shared variables, this will return an integer,
+        for a purely symbolic RBM this will return a theano expression.
+        
+        """
+        try:
+            return self.U.value.shape[1]
+        except AttributeError:
+            return self.U.shape[1]
+
+    def n_hidden_mean_units(self):
+        """Return the number of hidden units for the mean in this RBM
+
+        For an RBM made from shared variables, this will return an integer,
+        for a purely symbolic RBM this will return a theano expression.
+        
+        """
+        try:
+            return self.W.value.shape[1]
+        except AttributeError:
+            return self.W.shape[1]
+
+    def CD1_sampler(self, v, n_particles, n_visible=None, rng=8923984):
+        """Return a symbolic negative-phase particle obtained by simulating the Hamiltonian
+        associated with the energy function.
+        """
+        #TODO: why not expose all HMC arguments somehow?
+        if not hasattr(rng, 'randn'):
+            rng = np.random.RandomState(rng)
+        if n_visible is None:
+            n_visible = self.n_visible_units()
+
+        # create a dummy hmc object because we want to use *some* of it
+        hmc = HMC_sampler.new_from_shared_positions(
+                shared_positions=v, # v is not shared, so some functionality will not work
+                energy_fn=self.free_energy_given_v,
+                seed=int(rng.randint(2**30)),
+                shared_positions_shape=(n_particles,n_visible),
+                compile_simulate=False)
+        updates = dict(hmc.updates())
+        final_p = updates.pop(v)
+        return hmc, final_p, updates
+
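For context, a minimal sketch of how the returned triple is meant to be consumed (editorial;
cd_updates() further down does the real version of this):

    import theano.tensor as TT
    from pylearn.algorithms.mcRBM import mcRBM

    rbm = mcRBM.alloc(n_I=105, n_K=256, n_J=100)
    v_batch = TT.matrix('v_batch')                 # (n_particles, I) training examples
    hmc, neg_v, hmc_updates = rbm.CD1_sampler(v_batch, n_particles=128)
    # neg_v is the negative-phase particle after one HMC simulation started at v_batch;
    # hmc_updates (step size, acceptance-rate estimate, rng state) must be merged into
    # the update dictionary of whatever theano function consumes neg_v.
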
     def sampler(self, n_particles, n_visible=None, rng=7823748):
         """Return an `HMC_sampler` that will draw samples from the distribution over visible
         units specified by this RBM.
@@ -379,6 +431,8 @@
         :param n_particles: this many parallel chains will be simulated.
         :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation.
         """
+        #TODO: why not expose all HMC arguments somehow?
+        #TODO: Consider returning a sample kwargs for passing to HMC_sampler?
         if not hasattr(rng, 'randn'):
             rng = np.random.RandomState(rng)
         if n_visible is None:
@@ -393,25 +447,26 @@
             seed=int(rng.randint(2**30)))
         return rval
 
-    def as_feedforward_layer(self, v):
-        """Return a dictionary with keys: inputs, outputs and params
+    if 0:
+        def as_feedforward_layer(self, v):
+            """Return a dictionary with keys: inputs, outputs and params
 
-        The inputs is [v]
+            The inputs are ``[v]``
 
-        The outputs is :math:`[E[h|v], E[g|v]]` where `h` is the covariance hidden units and `g` is
-        the mean hidden units.
+            The outputs is :math:`[E[h|v], E[g|v]]` where `h` is the covariance hidden units and `g` is
+            the mean hidden units.
 
-        The params are ``[U, W, b, c]``, the model parameters that enter into the conditional
-        expectations.
+            The params are ``[U, W, b, c]``, the model parameters that enter into the conditional
+            expectations.
 
-        :TODO: add an optional parameter to return only one of the expections.
+            :TODO: add an optional parameter to return only one of the expectations.
 
-        """
-        return dict(
-                inputs = [v],
-                outputs = list(self.expected_h_g_given_v(v)),
-                params = [self.U, self.W, self.b, self.c],
-                )
+            """
+            return dict(
+                    inputs = [v],
+                    outputs = list(self.expected_h_g_given_v(v)),
+                    params = [self.U, self.W, self.b, self.c],
+                    )
 
     def params(self):
         """Return the elements of [U,W,a,b,c] that are shared variables
@@ -452,6 +507,118 @@
         rval._params = [rval.U, rval.W, rval.a, rval.b, rval.c]
         return rval
 
+def topological_connectivity(out_shape=(12,12), window_shape=(3,3), window_stride=(2,2),
+        **kwargs):
+
+    in_shape = (window_stride[0] * out_shape[0],
+            window_stride[1] * out_shape[1])
+
+    rval = numpy.zeros(in_shape + out_shape, dtype=theano.config.floatX)
+    A,B,C,D = rval.shape
+
+    # for each output position (out_r, out_c)
+    for out_r in range(out_shape[0]):
+        for out_c in range(out_shape[1]):
+            # for each window position (win_r, win_c)
+            for win_r in range(window_shape[0]):
+                for win_c in range(window_shape[1]):
+                    # add 1 to the corresponding input location
+                    in_r = out_r * window_stride[0] + win_r
+                    in_c = out_c * window_stride[1] + win_c
+                    rval[in_r%A, in_c%B, out_r%C, out_c%D] += 1
+
+    # This normalization algorithm is a guess, based on inspection of the matrix loaded
+    # from CVPR2010paper_material/topo2D_3x3_stride2_576filt.mat
+    rval = rval.reshape((A*B, C*D))
+    rval = (rval.T / rval.sum(axis=1)).T
+
+    rval /= rval.sum(axis=0)
+    return rval
+
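A small worked example of the matrix this helper builds (editorial; the shapes follow
directly from the code above):

    import numpy
    from pylearn.algorithms.mcRBM import topological_connectivity

    P = topological_connectivity(out_shape=(4,4), window_shape=(3,3), window_stride=(2,2))
    # in_shape is (2*4, 2*4), so P is (64, 16): one row per covariance-filter position on
    # the 8x8 input grid, one column per pooled (hidden covariance) unit.  Each hidden unit
    # pools a 3x3 window of filters; windows step by 2 and wrap around the grid edges.
    assert P.shape == (64, 16)
    # after the final normalization every column sums to 1, and all entries are >= 0
    assert numpy.allclose(P.sum(axis=0), 1.0)
    # mcRBM_withP.alloc_topo_P below negates this matrix so that P ends up non-positive.
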
+class mcRBM_withP(mcRBM):
+    """Light-weight class that provides the math related to inference
+
+    Attributes:
+
+      - U - the covariance filters (theano shared variable)
+      - W - the mean filters (theano shared variable)
+      - a - the visible bias (theano shared variable)
+      - b - the covariance bias (theano shared variable)
+      - c - the mean bias (theano shared variable)
+      - P - the pooling matrix over covariance filters (theano shared variable)
+
+    """
+    def __init__(self, U, W, a, b, c, P):
+        self.P = P
+        super(mcRBM_withP, self).__init__(U,W,a,b,c)
+
+    def hidden_cov_units_preactivation_given_v(self, v, small=0.5):
+        """Return argument to the sigmoid that would give mean of covariance hid units
+
+        See the math at the top of this file for what 'adjusted' means.
+
+        return b + 0.5 * dot(dot(adjusted(v), U)**2, P)
+        """
+        unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm
+        return self.b + 0.5 * dot(dot(unit_v, self.U)**2, self.P)
+
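A quick editorial check that the pooled preactivation reduces to the base-class expression
b - 0.5 * dot(adjusted(v), U)**2 when P is a negated identity:

    import numpy

    rng = numpy.random.RandomState(0)
    U = rng.randn(8, 5)                  # 8 'adjusted' input dims, 5 covariance filters
    adj_v = rng.randn(3, 8)              # a batch of 3 already-adjusted visible vectors
    P = -numpy.identity(5)               # diagonal pooling, as built by alloc() below
    pooled = 0.5 * numpy.dot(numpy.dot(adj_v, U) ** 2, P)
    plain = -0.5 * numpy.dot(adj_v, U) ** 2
    assert numpy.allclose(pooled, plain)
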
+    def n_hidden_cov_units(self):
+        """Return the number of hidden units for the covariance in this RBM
+
+        For an RBM made from shared variables, this will return an integer,
+        for a purely symbolic RBM this will return a theano expression.
+        
+        """
+        try:
+            return self.P.value.shape[1]
+        except AttributeError:
+            return self.P.shape[1]
+
+    @classmethod
+    def alloc(cls, n_I, n_K, n_J, *args, **kwargs):
+        """
+        Return an mcRBM_withP instance with randomly-initialized shared variable parameters
+        and a diagonal (negated identity) pooling matrix P.
+
+        :param n_I: input dimensionality
+        :param n_K: number of covariance hidden units
+        :param n_J: number of mean filters (linear)
+        :param rng: seed or numpy RandomState object to initialize parameters
+        
+        :note:
+        Constants for initial ranges and values taken from train_mcRBM.py.
+        """
+        return cls.alloc_with_P(
+            -numpy.identity(n_K).astype(theano.config.floatX),  # diagonal pooling: P = -I
+            n_I,
+            n_J,
+            *args, **kwargs)
+
+    @classmethod
+    def alloc_topo_P(cls, n_I, n_J, p_out_shape=(12,12), p_win_shape=(3,3), p_win_stride=(2,2),
+            **kwargs):
+        return cls.alloc_with_P(
+                -topological_connectivity(p_out_shape, p_win_shape, p_win_stride),
+                n_I=n_I, n_J=n_J, **kwargs)
+
+    @classmethod
+    def alloc_with_P(cls, Pval, n_I, n_J, rng = 8923402190,
+            U_range=0.02,
+            W_range=0.05,
+            a_ival=0,
+            b_ival=2,
+            c_ival=-2):
+        n_F, n_K = Pval.shape
+        if not hasattr(rng, 'randn'):
+            rng = np.random.RandomState(rng)
+        rval =  cls(
+                U = sharedX(U_range * rng.randn(n_I, n_F),'U'),
+                W = sharedX(W_range * rng.randn(n_I, n_J),'W'),
+                a = sharedX(np.ones(n_I)*a_ival,'a'),
+                b = sharedX(np.ones(n_K)*b_ival,'b'),
+                c = sharedX(np.ones(n_J)*c_ival,'c'),
+                P = sharedX(Pval, 'P'),)
+        rval._params = [rval.U, rval.W, rval.a, rval.b, rval.c, rval.P]
+        return rval
+
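For reference, a brief sketch (editorial; shapes read off the constructors above) of the two
ways a pooled RBM can be allocated:

    from pylearn.algorithms.mcRBM import mcRBM_withP

    # diagonal pooling: one covariance hidden unit per filter (P = -I, 256 x 256)
    rbm_diag = mcRBM_withP.alloc(n_I=105, n_K=256, n_J=100)

    # topological pooling: a 12x12 map of hidden units, each pooling a 3x3 window of the
    # 24x24 grid of covariance filters, so U is (105 x 576) and P is (576 x 144)
    rbm_topo = mcRBM_withP.alloc_topo_P(n_I=105, n_J=100)
    assert rbm_topo.U.value.shape == (105, 576)
    assert rbm_topo.P.value.shape == (576, 144)
    assert rbm_topo.n_hidden_cov_units() == 144
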
 class mcRBMTrainer(object):
     """Light-weight class encapsulating math for mcRBM training 
 
@@ -470,17 +637,43 @@
     """
     # TODO: accept a GD algo as an argument?
     @classmethod
-    def alloc(cls, rbm, visible_batch, batchsize, initial_lr=0.075, rng=234,
+    def alloc_for_P(cls, rbm, visible_batch, batchsize, initial_lr_per_example=0.075, rng=234,
             l1_penalty=0,
+            l1_penalty_start=0,
+            learn_rate_multipliers=None,
+            lr_anneal_start=2000,
+            p_training_start=4000,
+            p_training_lr=0.02,
+            persistent_chains=True
+            ):
+        if learn_rate_multipliers is None:
+            p_lr = sharedX(0.0, 'P_lr_multiplier')
+            learn_rate_multipliers = [2, .2, .02, .1, .02, p_lr]
+        else:
+            p_lr = None
+        rval = cls.alloc(rbm, visible_batch, batchsize, initial_lr_per_example, rng, l1_penalty,
+                l1_penalty_start, learn_rate_multipliers, lr_anneal_start, persistent_chains)
+
+        rval.p_lr = p_lr
+        rval.p_training_start=p_training_start
+        rval.p_training_lr=p_training_lr
+        return rval
+
+
+    @classmethod
+    def alloc(cls, rbm, visible_batch, batchsize, initial_lr_per_example=0.075, rng=234,
+            l1_penalty=0,
+            l1_penalty_start=0,
             learn_rate_multipliers=[2, .2, .02, .1, .02],
             lr_anneal_start=2000,
+            persistent_chains=True
             ):
 
         """
         :param rbm: mcRBM instance to train
         :param visible_batch: TensorType variable for training data
         :param batchsize: the number of rows in visible_batch
-        :param initial_lr: the learning rate (may be annealed)
+        :param initial_lr_per_example: the per-example learning rate (divided by batchsize
+            internally; may be annealed)
         :param rng: seed or RandomState to initialze PCD sampler
         :param l1_penalty: see class doc
         :param learn_rate_multipliers: see class doc
@@ -494,16 +687,30 @@
 
         # TODO: should normVF be initialized to match the size of rbm.U ?
 
+        if (l1_penalty_start > 0) and (l1_penalty != 0.0):
+            effective_l1_penalty = sharedX(0.0, 'l1_penalty')
+        else:
+            effective_l1_penalty = l1_penalty
+
+        if persistent_chains:
+            sampler = rbm.sampler(batchsize, rng=rng)
+        else:
+            sampler = None
+
         return cls(
                 rbm=rbm,
+                batchsize=batchsize,
                 visible_batch=visible_batch,
-                sampler=rbm.sampler(batchsize, rng=rng),
+                sampler=sampler,
                 normVF=sharedX(1.0, 'normVF'),
-                learn_rate=sharedX(initial_lr/batchsize, 'learn_rate'),
+                learn_rate=sharedX(initial_lr_per_example/batchsize, 'learn_rate'),
                 iter=sharedX(0, 'iter'),
+                effective_l1_penalty=effective_l1_penalty,
                 l1_penalty=l1_penalty,
+                l1_penalty_start=l1_penalty_start,
                 learn_rate_multipliers=learn_rate_multipliers,
-                lr_anneal_start=lr_anneal_start)
+                lr_anneal_start=lr_anneal_start,
+                persistent_chains=persistent_chains,)
 
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)
@@ -522,14 +729,16 @@
         new_normVF = .95 * self.normVF + .05 * TT.mean(U_norms)
         return (new_U * new_normVF / U_norms), new_normVF
 
-    def contrastive_grads(self):
+    def contrastive_grads(self, neg_v = None):
         """Return the contrastive divergence gradients on the parameters of self.rbm """
+        if neg_v is None:
+            neg_v = self.sampler.positions
         return contrastive_grad(
                 free_energy_fn=self.rbm.free_energy_given_v,
                 pos_v=self.visible_batch, 
-                neg_v=self.sampler.positions,
+                neg_v=neg_v,
                 wrt = self.rbm.params(),
-                other_cost=(l1(self.rbm.U)+l1(self.rbm.W)) * self.l1_penalty)
+                other_cost=(l1(self.rbm.U)+l1(self.rbm.W)) * self.effective_l1_penalty)
 
     def cd_updates(self):
         """
@@ -537,7 +746,19 @@
         learning by stochastic gradient descent with an annealed learning rate.
         """
 
-        grads = self.contrastive_grads()
+        ups = {}
+
+        if self.persistent_chains:
+            grads = self.contrastive_grads()
+            ups.update(dict(self.sampler.updates()))
+        else:
+            cd1_sampler, final_p, cd1_updates = self.rbm.CD1_sampler(self.visible_batch,
+                    self.batchsize)
+            self._last_cd1_sampler = cd1_sampler # hacked in here for the unit test
+            #ignore the cd1_sampler
+            grads = self.contrastive_grads(neg_v = final_p)
+            ups.update(dict(cd1_updates))
+
 
         # contrastive divergence updates
         # TODO: sgd_updates is a particular optization algo (others are possible)
@@ -552,30 +773,29 @@
                 0.0, #min
                 self.learn_rate) #max
 
-        ups = dict(sgd_updates(
+        ups.update(dict(sgd_updates(
                     self.rbm.params(),
                     grads,
-                    stepsizes=[a*lr for a in self.learn_rate_multipliers]))
+                    stepsizes=[a*lr for a in self.learn_rate_multipliers])))
 
         ups[self.iter] = self.iter + 1
 
-        # sampler updates
-        ups.update(dict(self.sampler.updates()))
 
         # add trainer updates (replace CD update of U)
         ups[self.rbm.U], ups[self.normVF] = self.normalize_U(ups[self.rbm.U])
 
+        #l1_updates:
+        if (self.l1_penalty_start > 0) and (self.l1_penalty != 0.0):
+            ups[self.effective_l1_penalty] = TT.switch(
+                    self.iter >= self.l1_penalty_start,
+                    self.l1_penalty,
+                    0.0)
+
+        if getattr(self,'p_lr', None):
+            ups[self.p_lr] = TT.switch(self.iter > self.p_training_start,
+                    self.p_training_lr,
+                    0)
+            ups[self.rbm.P] = TT.clip(ups[self.rbm.P], -5, 0)
+
         return ups
 
-if __name__ == '__main__':
-    import pylearn.algorithms.tests.test_mcRBM
-    rbm,smplr = pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(
-            as_unittest=False,
-            n_train_iters=10)
-    import cPickle
-    print ''
-    print 'Saving rbm...'
-    cPickle.dump(rbm, open('mcRBM.rbm.pkl', 'w'), -1)
-    print 'Saving sampler...'
-    cPickle.dump(smplr, open('mcRBM.smplr.pkl', 'w'), -1)
-
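
Putting the pieces together, the trainer is driven essentially the way the updated test
module below does it. A rough editorial sketch, assuming a hypothetical get_batch(i)
symbolic data op returning a (batchsize, n_vis) minibatch:

    import theano.tensor as TT
    from theano import function
    from pylearn.algorithms.mcRBM import mcRBM_withP, mcRBMTrainer

    batchsize = 128
    batch_idx = TT.iscalar('batch_idx')
    visible_batch = get_batch(batch_idx)       # hypothetical symbolic data op, (128, n_vis)

    rbm = mcRBM_withP.alloc_topo_P(n_I=96, n_J=81)
    trainer = mcRBMTrainer.alloc_for_P(rbm, visible_batch, batchsize,
            initial_lr_per_example=0.05,
            l1_penalty=1e-3, l1_penalty_start=30000,    # L1 decay switched on via TT.switch
            p_training_start=4000, p_training_lr=0.02,  # P is frozen for the first 4000 steps
            persistent_chains=False)                    # CD-1 via CD1_sampler instead of PCD

    learn_fn = function([batch_idx], outputs=[], updates=trainer.cd_updates())
    for jj in xrange(60000):
        learn_fn(jj)
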
--- a/pylearn/algorithms/tests/test_mcRBM.py	Wed Sep 15 17:46:21 2010 -0400
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Wed Sep 15 17:49:21 2010 -0400
@@ -1,33 +1,166 @@
 from pylearn.algorithms.mcRBM import *
+import pylearn.datasets.cifar10
+
+import pylearn.dataset_ops.cifar10
+
+def _mar_train_patches(dtype):
+    R,C=16,16
+    train_data = pylearn.dataset_ops.cifar10.train_data_labels(dtype)[0][:40000]
+    #train_data shape is (40000, 3072)
+    train_data = train_data.reshape((40000,3,32,32)).transpose([0,2,3,1])
+    patches = train_data[:, :R, :C, :].reshape((40000, 3*R*C))
+    patches -= patches.mean(axis=0)
+    wpatches = numpy.dot(patches, d['pcatransf'].T)
+    return wpatches
+
+def mar_centered(s_idx, split, dtype='float64', rasterized=False, color='grey'):
+    """ 
+    Returns a pair (img, label) of theano expressions for cifar-10 samples
+
+    :param s_idx: the indexes
+
+    :param split:
+
+    :param dtype:
+
+    :param rasterized: return examples as vectors (True) or 32x32 matrices (False)
+
+    :param color: control how to deal with the color in the images:
+      - grey   greyscale (with luminance weighting)
+      - rgb    add a trailing dimension of length 3 with rgb colour channels
+
+    """
+
+    split_options = {'train':(train_data, train_labels),
+            'valid': (valid_data, valid_labels),
+            'test': (test_data, test_labels)}
+
+    if split not in split_options:
+        raise ValueError('invalid split option', (split, split_options.keys()))
+
+    color_options = ('grey', 'rgb')
+    if color not in color_options:
+        raise ValueError('invalid color option', (color, color_options))
+
+    x_fn, y_fn = split_options[split]
+
+    x_op = TensorFnDataset(dtype, (False,), (x_fn, (dtype,)), (3072,))
+    y_op = TensorFnDataset('int32', (), y_fn)
+
+    x = x_op(s_idx)
+    y = y_op(s_idx)
+
+    # Y = 0.3R + 0.59G + 0.11B from
+    # http://gimp-savvy.com/BOOK/index.html?node54.html
+    rgb_dtype = 'float32'
+    if dtype == 'float64':
+        rgb_dtype = dtype
+    r = numpy.asarray(.3, dtype=rgb_dtype)
+    g = numpy.asarray(.59, dtype=rgb_dtype)
+    b = numpy.asarray(.11, dtype=rgb_dtype)
 
-def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000):
-    dataset='MAR'
+    if x.ndim == 1:
+        if rasterized:
+            if color=='grey':
+                x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((3,32*32)).T
+            else:
+                raise NotImplementedError('color', color)
+        else:
+            if color=='grey':
+                x = r * x[:1024] + g * x[1024:2048] + b * x[2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+                x = x.reshape((32,32))
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((3,32,32)).dimshuffle(1, 2, 0)
+            else:
+                raise NotImplementedError('color', color)
+    elif x.ndim == 2:
+        N = x.shape[0] # symbolic
+        if rasterized:
+            if color=='grey':
+                x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((N, 3,32*32)).dimshuffle(0, 2, 1)
+            else:
+                raise NotImplementedError('color', color)
+        else:
+            if color=='grey':
+                x = r * x[:,:1024] + g * x[:,1024:2048] + b * x[:,2048:]
+                if dtype=='uint8':
+                    x = theano.tensor.cast(x, 'uint8')
+                x = x.reshape((N, 32, 32))
+            elif color=='rgb':
+                # the strides aren't what you'd expect between channels,
+                # but theano is all about weird strides
+                x = x.reshape((N,3,32,32)).dimshuffle(0, 2, 3, 1)
+            else:
+                raise NotImplementedError('color', color)
+    else:
+        raise ValueError('x has too many dimensions', x.ndim)
+
+    return x, y
+
+
+def _default_rbm_alloc(n_I, n_K=256, n_J=100):
+    return mcRBM.alloc(n_I, n_K, n_J)
+
+def _default_trainer_alloc(rbm, train_batch, batchsize, initial_lr_per_example,
+        l1_penalty, l1_penalty_start, persistent_chains):
+    return mcRBMTrainer.alloc(rbm, train_batch, batchsize,
+            initial_lr_per_example=initial_lr_per_example,
+            l1_penalty=l1_penalty, l1_penalty_start=l1_penalty_start,
+            persistent_chains=persistent_chains)
+
+
+def test_reproduce_ranzato_hinton_2010(dataset='MAR', as_unittest=True, n_train_iters=5000,
+        rbm_alloc=_default_rbm_alloc, trainer_alloc=_default_trainer_alloc,
+        lr_per_example=.075,
+        l1_penalty=1e-3,
+        l1_penalty_start=1000,
+        persistent_chains=True,
+        ):
+
+    batchsize = 128
+
     if dataset == 'MAR':
         n_vis=105
         n_patches=10240
+        epoch_size=n_patches
+    elif dataset=='cifar10patches8x8':
+        R,C= 8,8 # the size of image patches
+        n_vis=96 # pca components
+        epoch_size=batchsize*500
+        n_patches=epoch_size*20
     else:
         R,C= 16,16 # the size of image patches
         n_vis=R*C
         n_patches=100000
-
-    n_burnin_steps=10000
-
-
-    l1_penalty=1e-3
-    no_l1_epochs = 10
-    effective_l1_penalty=0.0
-
-    epoch_size=n_patches
-    batchsize = 128
-    lr = 0.075 / batchsize
-    s_lr = TT.scalar()
-    n_K=256
-    n_J=100
+        epoch_size=n_patches
 
     def l2(X):
         return numpy.sqrt((X**2).sum())
+
     if dataset == 'MAR':
         tile = pylearn.dataset_ops.image_patches.save_filters_of_ranzato_hinton_2010
+    elif dataset == 'cifar10patches8x8':
+        def tile(X, fname):
+            _img = pylearn.datasets.cifar10.tile_rasterized_examples(
+                    pylearn.preprocessing.pca.pca_whiten_inverse(
+                        pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
+                            n_vis, None, 'float32', n_patches, R, C,),
+                        X),
+                    img_shape=(R,C))
+            image_tiling.save_tiled_raster_images(_img, fname)
     else:
         def tile(X, fname):
             _img = image_tiling.tile_raster_images(X,
@@ -39,6 +172,10 @@
 
     if dataset == 'MAR':
         train_batch = pylearn.dataset_ops.image_patches.ranzato_hinton_2010_op(batch_idx * batchsize + np.arange(batchsize))
+    elif dataset == 'cifar10patches8x8':
+        train_batch = pylearn.dataset_ops.cifar10.cifar10_patches(batch_idx * batchsize +
+                np.arange(batchsize), 'train', n_patches=n_patches, patch_size=(R,C),
+                pca_components=n_vis)
     else:
         train_batch = pylearn.dataset_ops.image_patches.image_patches(
                 s_idx = (batch_idx * batchsize + np.arange(batchsize)),
@@ -51,17 +188,34 @@
     if not as_unittest:
         imgs_fn = function([batch_idx], outputs=train_batch)
 
-    trainer = mcRBMTrainer.alloc(
-            mcRBM.alloc(n_I=n_vis, n_K=n_K, n_J=n_J),
+    trainer = trainer_alloc(
+            rbm_alloc(n_I=n_vis),
             train_batch,
-            batchsize, l1_penalty=TT.scalar())
+            batchsize, 
+            initial_lr_per_example=lr_per_example,
+            l1_penalty=l1_penalty,
+            l1_penalty_start=l1_penalty_start,
+            persistent_chains=persistent_chains)
     rbm=trainer.rbm
-    smplr = trainer.sampler
+
+    if persistent_chains:
+        grads = trainer.contrastive_grads()
+        learn_fn = function([batch_idx], 
+                outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
+                updates=trainer.cd_updates())
+    else:
+        learn_fn = function([batch_idx], outputs=[], updates=trainer.cd_updates())
 
-    grads = trainer.contrastive_grads()
-    learn_fn = function([batch_idx, trainer.l1_penalty], 
-            outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
-            updates=trainer.cd_updates())
+    if persistent_chains:
+        smplr = trainer.sampler
+    else:
+        smplr = trainer._last_cd1_sampler
+
+    if dataset == 'cifar10patches8x8':
+        cPickle.dump(
+                pylearn.dataset_ops.cifar10.random_cifar_patches_pca(
+                    n_vis, None, 'float32', n_patches, R, C,),
+                open('test_mcRBM.pca.pkl','w'))
 
     print "Learning..."
     last_epoch = -1
@@ -98,14 +252,20 @@
         if print_jj:
             if not as_unittest:
                 tile(imgs_fn(jj), "imgs_%06i.png"%jj)
-                tile(smplr.positions.value, "sample_%06i.png"%jj)
+                if persistent_chains:
+                    tile(smplr.positions.value, "sample_%06i.png"%jj)
                 tile(rbm.U.value.T, "U_%06i.png"%jj)
                 tile(rbm.W.value.T, "W_%06i.png"%jj)
 
             print 'saving samples', jj, 'epoch', jj/(epoch_size/batchsize)
 
             print 'l2(U)', l2(rbm.U.value),
-            print 'l2(W)', l2(rbm.W.value)
+            print 'l2(W)', l2(rbm.W.value),
+            print 'l1_penalty', 
+            try:
+                print trainer.effective_l1_penalty.value
+            except AttributeError:
+                print trainer.effective_l1_penalty
 
             print 'U min max', rbm.U.value.min(), rbm.U.value.max(),
             print 'W min max', rbm.W.value.min(), rbm.W.value.max(),
@@ -113,15 +273,16 @@
             print 'b min max', rbm.b.value.min(), rbm.b.value.max(),
             print 'c min max', rbm.c.value.min(), rbm.c.value.max()
 
-            print 'parts min', smplr.positions.value.min(), 
-            print 'max',smplr.positions.value.max(),
-            print 'HMC step', smplr.stepsize,
-            print 'arate', smplr.avg_acceptance_rate
+            if persistent_chains:
+                print 'parts min', smplr.positions.value.min(), 
+                print 'max',smplr.positions.value.max(),
+            print 'HMC step', smplr.stepsize.value,
+            print 'arate', smplr.avg_acceptance_rate.value
 
 
-        l2_of_Ugrad = learn_fn(jj, effective_l1_penalty)
+        l2_of_Ugrad = learn_fn(jj)
 
-        if print_jj:
+        if persistent_chains and print_jj:
             print 'l2(U_grad)', float(l2_of_Ugrad[0]),
             print 'l2(U_inc)', float(l2_of_Ugrad[1]),
             print 'l2(W_inc)', float(l2_of_Ugrad[2]),
@@ -131,9 +292,42 @@
             #print 'FE+[2]', float(l2_of_Ugrad[5]),
             #print 'FE+[3]', float(l2_of_Ugrad[6])
 
-        if jj == no_l1_epochs * epoch_size/batchsize:
-            print "Activating L1 weight decay"
-            effective_l1_penalty = 1e-3
+        if not as_unittest:
+            if jj % 2000 == 0:
+                print ''
+                print 'Saving rbm...'
+                cPickle.dump(rbm, open('mcRBM.rbm.%06i.pkl'%jj, 'w'), -1)
+                if persistent_chains:
+                    print 'Saving sampler...'
+                    cPickle.dump(smplr, open('mcRBM.smplr.%06i.pkl'%jj, 'w'), -1)
+
 
     if not as_unittest:
         return rbm, smplr
+
+import pickle as cPickle
+#import cPickle
+if __name__ == '__main__':
+    if 0: 
+        #learning 16 x 16 pinwheel filters from official cifar patches (MAR)
+        rbm,smplr = test_reproduce_ranzato_hinton_2010(
+                as_unittest=False,
+                n_train_iters=5000,
+                rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81),
+                trainer_alloc=mcRBMTrainer.alloc_for_P,
+                dataset='MAR'
+                )
+
+    if 1:
+        rbm,smplr = test_reproduce_ranzato_hinton_2010(
+                as_unittest=False,
+                n_train_iters=60000,
+                rbm_alloc=lambda n_I : mcRBM_withP.alloc_topo_P(n_I, n_J=81),
+                trainer_alloc=mcRBMTrainer.alloc_for_P,
+                lr_per_example=0.05,
+                dataset='cifar10patches8x8',
+                l1_penalty=1e-3,
+                l1_penalty_start=30000,
+                #l1_penalty_start=350, #DEBUG
+                persistent_chains=False
+                )