changeset 1272:ba25c6e4f55d

mcRBM working with whole learning algo in theano
author James Bergstra <bergstrj@iro.umontreal.ca>
date Sat, 04 Sep 2010 19:32:27 -0400
parents cc6c6d7234a7
children 7bb5dd98e671
files pylearn/algorithms/mcRBM.py pylearn/algorithms/tests/test_mcRBM.py
diffstat 2 files changed, 176 insertions(+), 89 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/algorithms/mcRBM.py	Sat Sep 04 19:31:16 2010 -0400
+++ b/pylearn/algorithms/mcRBM.py	Sat Sep 04 19:32:27 2010 -0400
@@ -199,6 +199,8 @@
 from theano import tensor as TT
 floatX = theano.config.floatX
 
+sharedX = lambda X, name : shared(numpy.asarray(X, dtype=floatX), name=name)
+
 import pylearn
 #TODO: clean up the HMC_sampler code
 #TODO: think of naming convention for acronyms + suffix?
@@ -213,31 +215,83 @@
 #
 ###########################################
 
-#TODO: Document, move to pylearn's math lib
 def l1(X):
+    """
+    :param X: TensorType variable
+
+    :rtype: TensorType scalar
+
+    :returns: the sum of absolute values of the terms in X
+
+    :math: \sum_i |X_i|
+
+    Where i is an appropriately dimensioned index.
+
+    """
     return abs(X).sum()
 
-#TODO: Document, move to pylearn's math lib
 def l2(X):
+    """
+    :param X: TensorType variable
+
+    :rtype: TensorType scalar
+
+    :returns: the sum of absolute values of the terms in X
+
+    :math: \sqrt{ \sum_i X_i^2 }
+
+    Where i is an appropriately dimensioned index.
+
+    """
     return TT.sqrt((X**2).sum())
 
-#TODO: Document, move to pylearn's math lib
 def contrastive_cost(free_energy_fn, pos_v, neg_v):
+    """
+    :param free_energy_fn: lambda (TensorType matrix MxN) ->  TensorType vector of M free energies
+    :param pos_v: TensorType matrix MxN of M "positive phase" particles
+    :param neg_v: TensorType matrix MxN of M "negative phase" particles
+
+    :returns: TensorType scalar that's the sum of the difference of free energies
+
+    :math: \sum_i free_energy(pos_v[i]) - free_energy(neg_v[i])
+
+    """
     return (free_energy_fn(pos_v) - free_energy_fn(neg_v)).sum()
 
-#TODO: Typical use of contrastive_cost is to later use tensor.grad, but in that case we want to
-#      block  gradient going through neg_v
-def contrastive_grad(free_energy_fn, pos_v, neg_v, params, other_cost=0):
+def contrastive_grad(free_energy_fn, pos_v, neg_v, wrt, other_cost=0):
     """
+    :param free_energy_fn: lambda (TensorType matrix MxN) ->  TensorType vector of M free energies
     :param pos_v: positive-phase sample of visible units
     :param neg_v: negative-phase sample of visible units
+    :param wrt: TensorType variables with respect to which we want gradients (similar to the
+        'wrt' argument to tensor.grad)
+    :param other_cost: TensorType scalar 
+
+    :returns: TensorType variables for the gradient on each of the 'wrt' arguments
+
+
+    :math: Cost = other_cost + \sum_i free_energy(pos_v[i]) - free_energy(neg_v[i])
+    :math: d Cost / dW for W in `wrt`
+
+
+    This function is similar to tensor.grad - it returns the gradient[s] on a cost with respect
+    to one or more parameters.  The difference between tensor.grad and this function is that
+    the negative phase term (`neg_v`) is considered constant, i.e. d `Cost` / d `neg_v` = 0.
+    This is desirable because `neg_v` might be the result of a sampling expression involving
+    some of the parameters, but the contrastive divergence algorithm does not call for
+    backpropagating through the sampling procedure.
+
+    Warning - if other_cost depends on pos_v or neg_v and you *do* want to backpropagate from
+    the `other_cost` through those terms, then this function is inappropriate.  In that case,
+    you should call tensor.grad separately for the other_cost and add the gradient expressions
+    you get from ``contrastive_grad(..., other_cost=0)``
+
     """
-    #block the grad through neg_v
     cost=contrastive_cost(free_energy_fn, pos_v, neg_v)
     if other_cost:
         cost = cost + other_cost
     return theano.tensor.grad(cost,
-            wrt=params,
+            wrt=wrt,
             consider_constant=[neg_v])
 
 ###########################################
@@ -256,6 +310,7 @@
       - a - the visible bias (theano shared variable)
       - b - the covariance bias (theano shared variable)
       - c - the mean bias (theano shared variable)
+
     """
     def __init__(self, U, W, a, b, c):
         self.U = U
@@ -329,84 +384,152 @@
         if n_visible is None:
             n_visible = self.n_visible_units()
         rval = HMC_sampler.new_from_shared_positions(
-            shared_positions = shared(
+            shared_positions = sharedX(
                 rng.randn(
                     n_particles,
-                    n_visible).astype(floatX),
+                    n_visible),
                 name='particles'),
             energy_fn=self.free_energy_given_v,
             seed=int(rng.randint(2**30)))
         return rval
 
     def as_feedforward_layer(self, v):
+        """Return a dictionary with keys: inputs, outputs and params
+
+        The inputs is [v]
+
+        The outputs is :math:`[E[h|v], E[g|v]]` where `h` is the covariance hidden units and `g` is
+        the mean hidden units.
+
+        The params are ``[U, W, b, c]``, the model parameters that enter into the conditional
+        expectations.
+
+        :TODO: add an optional parameter to return only one of the expections.
+
+        """
         return dict(
-                outputs = self.expected_h_g_given_v(v),
+                inputs = [v],
+                outputs = list(self.expected_h_g_given_v(v)),
                 params = [self.U, self.W, self.b, self.c],
                 )
 
     @classmethod
-    def alloc(cls, n_I, n_K, n_J, rng = 8923402190):
+    def alloc(cls, n_I, n_K, n_J, rng = 8923402190,
+            U_range=0.02,
+            W_range=0.05,
+            a_ival=0,
+            b_ival=2,
+            c_ival=-2):
         """
-        Return a MeanCovRBM instance with randomly-initialized parameters.
+        Return a MeanCovRBM instance with randomly-initialized shared variable parameters.
 
         :param n_I: input dimensionality
         :param n_K: number of covariance hidden units
         :param n_J: number of mean filters (linear)
-        :param rng: seed or numpy RandomState object to initialize params
+        :param rng: seed or numpy RandomState object to initialize parameters
+        
+        :note:
+        Constants for initial ranges and values taken from train_mcRBM.py.
         """
         if not hasattr(rng, 'randn'):
             rng = np.random.RandomState(rng)
 
-        def shrd(X,name):
-            return shared(X.astype(floatX), name=name)
+        rval =  cls(
+                U = sharedX(U_range * rng.randn(n_I, n_K),'U'),
+                W = sharedX(W_range * rng.randn(n_I, n_J),'W'),
+                a = sharedX(np.ones(n_I)*a_ival,'a'),
+                b = sharedX(np.ones(n_K)*b_ival,'b'),
+                c = sharedX(np.ones(n_J)*c_ival,'c'),)
 
-        # initialization taken from train_mcRBM.py
-        rval =  cls(
-                U = shrd(0.02 * rng.randn(n_I, n_K),'U'),
-                W = shrd(0.05 * rng.randn(n_I, n_J),'W'),
-                a = shrd(np.ones(n_I)*(0),'a'),
-                b = shrd(np.ones(n_K)*2,'b'),
-                c = shrd(np.ones(n_J)*(-2),'c'))
-
-        rval.params = [rval.U, rval.W, rval.a, rval.b, rval.c]
+        rval.params = lambda : [rval.U, rval.W, rval.a, rval.b, rval.c]
         return rval
 
 class mcRBMTrainer(object):
-    """
+    """Light-weight class encapsulating math for mcRBM training 
 
     Attributes:
-      - rbm 
-      - sampler
-      - normVF
-      - learn_rate
-      - learn_rate_multipliers
+      - rbm  - an mcRBM instance
+      - sampler - an HMC_sampler instance
+      - normVF - geometrically updated norm of U matrix columns (shared var)
+      - learn_rate - SGD learning rate [un-annealed]
+      - learn_rate_multipliers - the learning rates for each of the parameters of the rbm (in
+        order corresponding to what's returned by ``rbm.params()``)
+      - l1_penalty - float or TensorType scalar to modulate l1 penalty of rbm.U and rbm.W
+      - iter - number of cd_updates (shared var) - used to anneal the effective learn_rate
+      - lr_anneal_start - scalar or TensorType scalar - iter at which time to start decreasing
+            the learning rate proportional to 1/iter
 
     """
+    # TODO: accept a GD algo as an argument?
+    @classmethod
+    def alloc(cls, rbm, visible_batch, batchsize, initial_lr=0.075, rng=234,
+            l1_penalty=0,
+            learn_rate_multipliers=[2, .2, .02, .1, .02],
+            lr_anneal_start=2000,
+            ):
+
+        """
+        :param rbm: mcRBM instance to train
+        :param visible_batch: TensorType variable for training data
+        :param batchsize: the number of rows in visible_batch
+        :param initial_lr: the learning rate (may be annealed)
+        :param rng: seed or RandomState to initialze PCD sampler
+        :param l1_penalty: see class doc
+        :param learn_rate_multipliers: see class doc
+        :param lr_anneal_start: see class doc
+        """
+        #TODO: :param lr_anneal_iter: the iteration at which 1/t annealing will begin
+
+        #TODO: get batchsize from visible_batch??
+        # allocates shared var for negative phase particles
+
+
+        # TODO: should normVF be initialized to match the size of rbm.U ?
+
+        return cls(
+                rbm=rbm,
+                visible_batch=visible_batch,
+                sampler=rbm.sampler(batchsize, rng=rng),
+                normVF=sharedX(1.0, 'normVF'),
+                learn_rate=sharedX(initial_lr/batchsize, 'learn_rate'),
+                iter=sharedX(0, 'iter'),
+                l1_penalty=l1_penalty,
+                learn_rate_multipliers=learn_rate_multipliers,
+                lr_anneal_start=lr_anneal_start)
+
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)
 
     def normalize_U(self, new_U):
-        #TODO: write the docstring
+        """
+        :param new_U: a proposed new value for rbm.U
+
+        :returns: a pair of TensorType variables: 
+            a corrected new value for U, and a new value for self.normVF
+
+        This is a weird normalization procedure, but the sample code for the paper has it, and
+        it seems to be important.
+        """
         U_norms = TT.sqrt((new_U**2).sum(axis=0))
         new_normVF = .95 * self.normVF + .05 * TT.mean(U_norms)
-        return (new_U * this_normVF / U_norms), new_normVF
+        return (new_U * new_normVF / U_norms), new_normVF
 
-    def contrastive_grads(self, visible_batch, params=None):
-        if params is not None:
-            params = self.rbm.params
+    def contrastive_grads(self):
+        """Return the contrastive divergence gradients on the parameters of self.rbm """
         return contrastive_grad(
                 free_energy_fn=self.rbm.free_energy_given_v,
-                pos_v=visible_batch, 
+                pos_v=self.visible_batch, 
                 neg_v=self.sampler.positions,
-                params=params,
+                wrt = self.rbm.params(),
                 other_cost=(l1(self.rbm.U)+l1(self.rbm.W)) * self.l1_penalty)
 
+    def cd_updates(self):
+        """
+        Return a dictionary of shared variable updates that implements contrastive divergence
+        learning by stochastic gradient descent with an annealed learning rate.
+        """
 
-    def cd_updates(self, visible_batch, params=None, rng=89234):
-        if params is not None:
-            params = self.rbm.params
-
-        grads = self.contrastive_grads(visible_batch, params)
+        grads = self.contrastive_grads()
 
         # contrastive divergence updates
         # TODO: sgd_updates is a particular optization algo (others are possible)
@@ -416,44 +539,26 @@
         # TODO: when sgd has an annealing schedule, this should
         #       go through that mechanism.
 
-        # TODO: parametrize these constants (e.g. 2000)
-
-        ups[self.iter] = self.iter + 1
         lr = TT.clip(
-                self.learn_rate * 2000 / (self.iter+1), 
+                self.learn_rate * TT.cast(self.lr_anneal_start / (self.iter+1), floatX), 
                 0.0, #min
                 self.learn_rate) #max
 
-        ups = sgd_updates(
-                    params,
+        ups = dict(sgd_updates(
+                    self.rbm.params(),
                     grads,
-                    stepsizes=[a*lr for a in learn_rate_multipliers])
+                    stepsizes=[a*lr for a in self.learn_rate_multipliers]))
+
+        ups[self.iter] = self.iter + 1
 
         # sampler updates
         ups.update(dict(self.sampler.updates()))
 
         # add trainer updates (replace CD update of U)
-        ups[self.rbm.U], ups[self.normVF] = self.normalize_U(ups[U])
+        ups[self.rbm.U], ups[self.normVF] = self.normalize_U(ups[self.rbm.U])
 
         return ups
 
-    # TODO: accept a GD algo as an argument?
-    @classmethod
-    def alloc(cls, rbm, visible_batch, batchsize, initial_lr=0.075, rng=234,
-            l1_penalty=0,
-            learn_rate_multipliers=[2, .2, .02, .1, .02]):
-        # allocates shared var for negative phase particles
-
-        return cls(
-                rbm=rbm,
-                sampler=rbm.sampler(batchsize, rng=rng),
-                normVF=shared(1.0, 'normVF'),
-                learn_rate=shared(initial_lr/batchsize, 'learn_rate'),
-                iter=shared(0, 'iter'),
-                l1_penalty=l1_penalty,
-                learn_rate_multipliers=learn_rate_multipliers)
-
-
 if __name__ == '__main__':
     import pylearn.algorithms.tests.test_mcRBM
     pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(as_unittest=True)
--- a/pylearn/algorithms/tests/test_mcRBM.py	Sat Sep 04 19:31:16 2010 -0400
+++ b/pylearn/algorithms/tests/test_mcRBM.py	Sat Sep 04 19:32:27 2010 -0400
@@ -60,10 +60,10 @@
     rbm=trainer.rbm
     smplr = trainer.sampler
 
-    grads = trainer.contrastive_grads(train_batch)
+    grads = trainer.contrastive_grads()
     learn_fn = function([batch_idx, trainer.l1_penalty], 
             outputs=[grads[0].norm(2), grads[0].norm(2), grads[1].norm(2)],
-            updates=trainer.cd_updates(train_batch))
+            updates=trainer.cd_updates())
 
     print "Learning..."
     last_epoch = -1
@@ -121,16 +121,7 @@
             print 'arate', smplr.avg_acceptance_rate
 
 
-        if 0:
-            # Continue HMC chain
-            smplr.simulate()
-
-            # Do CD update
-            l2_of_Ugrad = learn_fn(jj, 
-                    lr/max(1, jj/(20*epoch_size/batchsize)),
-                    effective_l1_penalty)
-
-        learn_fn(jj, effective_l1_penalty)
+        l2_of_Ugrad = learn_fn(jj, effective_l1_penalty)
 
         if print_jj:
             print 'l2(U_grad)', float(l2_of_Ugrad[0]),
@@ -146,12 +137,3 @@
             print "Activating L1 weight decay"
             effective_l1_penalty = 1e-3
 
-        # weird normalization technique...
-        # It constrains all the columns of the matrix to have the same length
-        if 0:
-            U = rbm.U.value
-            U_norms = np.sqrt((U*U).sum(axis=0))
-            assert len(U_norms) == n_K
-            normVF = .95 * normVF + .05 * np.mean(U_norms)
-            rbm.U.value = rbm.U.value * normVF/U_norms
-