view pylearn/algorithms/mcRBM.py @ 1270:d38cb039c662

debugging mcRBM
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 03 Sep 2010 15:05:31 -0400
parents 075c193afd1b
children ba25c6e4f55d

"""
This file implements the Mean & Covariance RBM discussed in 

    Ranzato, M. and Hinton, G. E. (2010)
    Modeling pixel means and covariances using factored third-order Boltzmann machines.
    IEEE Conference on Computer Vision and Pattern Recognition.

and performs one of the experiments on CIFAR-10 discussed in that paper.  There are some minor
discrepancies between the paper and the accompanying code (train_mcRBM.py), and the
accompanying code has been taken to be correct in those cases because I couldn't get things to
work otherwise.


Math
====

Energy of "covariance RBM"

    E = -0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i C_{if} v_i )^2
      = -0.5 \sum_f (\sum_k P_{fk} h_k) ( \sum_i C_{if} v_i )^2
                    "vector element f"           "vector element f"

In some parts of the paper, the P matrix is chosen to be a diagonal matrix with non-positive
diagonal entries.  Taking P = -I (the choice made in this file, so that F = K) gives a simpler
equation:

    E = 0.5 \sum_f h_f ( \sum_i C_{if} v_i )^2
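
A small numpy check can confirm that the factored covariance energy reduces to the diagonal
form when P = -I.  This is a sketch outside the module's API: `cov_energy` is a hypothetical
helper, and the sizes and random values are illustrative only.

```python
import numpy as np


def cov_energy(v, h, C, P):
    # covariance-RBM energy: -0.5 * sum_f (sum_k P_fk h_k) * (sum_i C_if v_i)**2
    proj = v @ C        # (F,): sum_i C_if v_i for every filter f
    weight = P @ h      # (F,): sum_k P_fk h_k for every filter f
    return -0.5 * np.sum(weight * proj ** 2)


rng = np.random.RandomState(0)
n_I, n_F = 6, 4
v = rng.randn(n_I)
C = rng.randn(n_I, n_F)
h = (rng.rand(n_F) > 0.5).astype(float)

# P = -identity turns the factored energy into 0.5 * sum_f h_f * (C_f . v)**2
E_general = cov_energy(v, h, C, -np.eye(n_F))
E_diagonal = 0.5 * np.sum(h * (v @ C) ** 2)
print(np.isclose(E_general, E_diagonal))   # prints True
```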



Version in paper
----------------

Full Energy of the Mean and Covariance RBM, with 
:math:`h_k = h_k^{(c)}`,
:math:`g_j = h_j^{(m)}`,
:math:`b_k = b_k^{(c)}`,
:math:`c_j = b_j^{(m)}`,
:math:`U_{if} = C_{if}`,

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i / (|U_{.f}| |v|) )^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j

For the energy function to correspond to a probability distribution, P must be non-positive.  P
is initialized to be diagonal, and in our experience it can be left as such: even in the paper
it has a very low learning rate, and (in effect) it is only allowed to be updated after the
filters in U have been learned.
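
Because both U_{.f} and v are divided by their lengths, the paper's covariance term depends
only on the direction of v.  A minimal numpy sketch (the helper `paper_cov_term` is
hypothetical and the random values are illustrative) shows this scale invariance, and that
with a non-positive diagonal P and non-negative h the term is non-negative:

```python
import numpy as np


def paper_cov_term(v, h, U, P):
    # -0.5 * sum_f sum_k P_fk h_k * (U_f . v / (|U_f| |v|))**2
    cosines = (v @ U) / (np.linalg.norm(U, axis=0) * np.linalg.norm(v))
    return -0.5 * np.sum((P @ h) * cosines ** 2)


rng = np.random.RandomState(1)
n_I, n_F = 8, 5
v = rng.randn(n_I)
U = rng.randn(n_I, n_F)
h = rng.rand(n_F)                       # any non-negative hidden activities
P = -np.diag(rng.rand(n_F))             # non-positive diagonal, as in the paper

e1 = paper_cov_term(v, h, U, P)
e2 = paper_cov_term(3.7 * v, h, U, P)   # rescaling v leaves the term unchanged
```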

Version in published train_mcRBM code
-------------------------------------

The train_mcRBM file implements learning in a similar but technically different Energy function:

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k (\sum_i U_{if} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j

There are two differences with respect to the paper:

    - 'v' is not normalized by its length, but rather it is normalized to have length close to
      the square root of the number of its components.  The variable called 'small' that
      "avoids division by zero" is orders of magnitude larger than machine precision, and is on
      the order of the normalized sum-of-squares, so I've included it in the Energy function.

    - 'U' is also not normalized by its length.  U is initialized to have columns that are
      shorter than unit-length (approximately 0.2 with the 105 principal components in the
      train_mcRBM data).  During training, the columns of U are constrained manually to have
      equal lengths (see the use of normVF), but that common Euclidean length is allowed to
      change: during learning it quickly converges towards 1 and then exceeds 1.  This
      column-wise normalization of U does not seem to be justified by maximum likelihood, and I
      have no intuition for why it is used.
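
Both normalizations can be sketched in numpy.  `normalize_rows` and `equalize_columns` are
hypothetical names; they mirror `mcRBM.hidden_cov_units_preactivation_given_v` and
`mcRBMTrainer.normalize_U` in this file (including the 0.95 / 0.05 running-average constants),
using plain arrays instead of theano expressions:

```python
import numpy as np

SMALL = 0.5  # the 'small' constant of train_mcRBM, folded into the energy


def normalize_rows(V, small=SMALL):
    # divide each row v by sqrt(sum_i v_i**2 / I + small)
    return V / np.sqrt(np.mean(V ** 2, axis=1, keepdims=True) + small)


def equalize_columns(U, normVF):
    # move the running target length towards the mean column length,
    # then rescale every column of U to exactly that length
    col_norms = np.sqrt((U ** 2).sum(axis=0))
    new_normVF = 0.95 * normVF + 0.05 * col_norms.mean()
    return U * new_normVF / col_norms, new_normVF


rng = np.random.RandomState(4)
V = 10.0 * rng.randn(5, 105)          # mean of v_i**2 is about 100, much larger than SMALL
row_lengths = np.linalg.norm(normalize_rows(V), axis=1)   # each close to sqrt(105)

U = 0.02 * rng.randn(105, 16)         # columns of length roughly 0.2, as in train_mcRBM
U_eq, normVF = equalize_columns(U, normVF=1.0)
```

When the mean square of v is small compared to SMALL, the rows come out shorter than
sqrt(I), which is why 'small' acts as part of the energy rather than a mere safeguard.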


Version in this code
--------------------

This file implements the same algorithm as the train_mcRBM code, except that the P matrix is
omitted for clarity and replaced analytically with a negative identity matrix (so F = K).  The
energy below also includes the visible-bias term for `a`, which this file's free energy computes:

    E (v, h, g) =
        + 0.5 \sum_k h_k (\sum_i U_{ik} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j
        - \sum_i a_i v_i
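
Summing exp(-E) over binary h and g factorizes into softplus terms, which is what
`mcRBM.free_energy_terms_given_v` computes.  The numpy sketch below (all helper names are
hypothetical; sizes are kept tiny so that every h and g can be enumerated) checks the closed
form against brute-force marginalization of this energy, including the visible-bias term:

```python
import itertools

import numpy as np


def softplus(x):
    # log(1 + exp(x)), computed stably
    return np.logaddexp(0.0, x)


def unit(v, small=0.5):
    # train_mcRBM-style normalization: v / sqrt(sum_i v_i**2 / I + small)
    return v / np.sqrt(np.mean(v ** 2) + small)


def energy(v, h, g, U, W, a, b, c):
    return (0.5 * np.sum(h * (unit(v) @ U) ** 2) - b @ h
            + 0.5 * np.sum(v ** 2) - v @ W @ g - c @ g - a @ v)


def free_energy(v, U, W, a, b, c):
    # closed form, term by term as in mcRBM.free_energy_terms_given_v
    return (-np.sum(softplus(b - 0.5 * (unit(v) @ U) ** 2))
            - np.sum(softplus(c + v @ W))
            + 0.5 * np.sum(v ** 2)
            - v @ a)


def brute_force_free_energy(v, U, W, a, b, c):
    # -log sum_{h,g} exp(-E(v,h,g)) by enumerating every binary h and g
    total = 0.0
    for h in itertools.product([0.0, 1.0], repeat=U.shape[1]):
        for g in itertools.product([0.0, 1.0], repeat=W.shape[1]):
            total += np.exp(-energy(v, np.array(h), np.array(g), U, W, a, b, c))
    return -np.log(total)


rng = np.random.RandomState(2)
n_I, n_K, n_J = 5, 3, 2
v = rng.randn(n_I)
U, W = rng.randn(n_I, n_K), rng.randn(n_I, n_J)
a, b, c = rng.randn(n_I), rng.randn(n_K), rng.randn(n_J)
```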

      

Conventions in this file
========================

This file contains some global functions, as well as a class (`mcRBM`) that makes using them a
little more convenient.


The methods of `mcRBM` (e.g. `free_energy_given_v`) work on an mcRBM parametrized in a
particular way.  Suppose we have I input dimensions, F squared filters, J mean variables, and K
covariance variables (with P fixed to -I, as in this file, F = K).
The mcRBM is parametrized by 5 variables:

 - `U`, a matrix whose columns are visible covariance directions (I x F)
 - `W`, a matrix whose columns are visible mean directions (I x J)
 - `a`, a vector of visible biases (I)
 - `b`, a vector of hidden covariance biases (K)
 - `c`, a vector of hidden mean biases (J)

Matrices are generally laid out and accessed according to a C-order convention.
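
The shapes above can be sketched in numpy.  `alloc_params` and `expected_h_g` are
hypothetical stand-ins for `mcRBM.alloc` and `mcRBM.expected_h_g_given_v`, using plain arrays
instead of theano shared variables; each row of `V` is one visible vector:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def alloc_params(n_I, n_K, n_J, seed=0):
    # same initial scales as mcRBM.alloc (taken from train_mcRBM.py)
    rng = np.random.RandomState(seed)
    return dict(U=0.02 * rng.randn(n_I, n_K),   # column k = covariance filter k
                W=0.05 * rng.randn(n_I, n_J),   # column j = mean filter j
                a=np.zeros(n_I),
                b=2.0 * np.ones(n_K),
                c=-2.0 * np.ones(n_J))


def expected_h_g(V, p, small=0.5):
    # V has one example per row, shape (N, I), in C order
    unit_V = V / np.sqrt(np.mean(V ** 2, axis=1, keepdims=True) + small)
    h = sigmoid(p['b'] - 0.5 * (unit_V @ p['U']) ** 2)   # (N, K)
    g = sigmoid(p['c'] + V @ p['W'])                     # (N, J)
    return h, g


params = alloc_params(n_I=105, n_K=64, n_J=16)
V = np.random.RandomState(3).randn(10, 105)
h, g = expected_h_g(V, params)
```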

"""

#
# WORKING NOTES
# THIS DERIVATION IS BASED ON THE ** PAPER ** ENERGY FUNCTION
# NOT THE ENERGY FUNCTION IN THE CODE!!!
#
# Free energy is the marginal energy of visible units
# Recall: 
#   Q(x) = exp(-E(x))/Z ==> -log(Q(x)) - log(Z) = E(x)
#
#
#   E (v, h, g) =
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / |U_{*f}|^2 |v|^2
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j
#       - \sum_i a_i v_i
#
#
# Derivation, in which partition functions are ignored.
#
# E(v) = -\log(Q(v))
#  = -\log( \sum_{h,g} Q(v,h,g))
#  = -\log( \sum_{h,g} exp(-E(v,h,g)))
#  = -\log( \sum_{h,g} exp(-
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j 
#       - \sum_i a_i v_i ))
#
# Get rid of double negs  in exp
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       + \sum_k b_k h_k
#       - 0.5 \sum_i v_i^2
#       ) * \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j))
#    - \sum_i a_i v_i 
#
# Break up log
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       + \sum_k b_k h_k
#       ))
#    -\log( \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# Use the fact that g and h are binary to turn log(sum(exp(sum(...)))) into sum(log(...)),
# first over g, then over h:
#  = -\log(\sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       + \sum_k b_k h_k
#       ))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
#  = - \sum_{k} \log(1 + exp(b_k + 0.5 \sum_f P_{fk}( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# For negative-one-diagonal P this gives:
#
#  = - \sum_{k} \log(1 + exp(b_k - 0.5 ( \sum_i U_{ik} v_i )^2 / (|U_{*k}|^2 |v|^2)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 

import sys, os, logging
import numpy as np
import numpy

import theano
from theano import function, shared, dot
from theano import tensor as TT
floatX = theano.config.floatX

import pylearn
#TODO: clean up the HMC_sampler code
#TODO: think of naming convention for acronyms + suffix?
from pylearn.sampling.hmc import HMC_sampler
from pylearn.io import image_tiling
from pylearn.gd.sgd import sgd_updates
import pylearn.dataset_ops.image_patches

###########################################
#
# Candidates for factoring
#
###########################################

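#TODO: move to pylearn's math lib
def l1(X):
    """Return the L1 norm of `X`: the sum of the absolute values of its elements."""
    return abs(X).sum()

#TODO: move to pylearn's math lib
def l2(X):
    """Return the L2 (Euclidean / Frobenius) norm of `X` as a theano expression."""
    return TT.sqrt((X**2).sum())

#TODO: move to pylearn's math lib
def contrastive_cost(free_energy_fn, pos_v, neg_v):
    """Return the summed free energy of `pos_v` minus the summed free energy of `neg_v`."""
    return (free_energy_fn(pos_v) - free_energy_fn(neg_v)).sum()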

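#TODO: Typical use of contrastive_cost is to later use tensor.grad, but in that case we want to
#      block the gradient going through neg_v
def contrastive_grad(free_energy_fn, pos_v, neg_v, params, other_cost=0):
    """Return the gradient of the contrastive cost (plus `other_cost`) with respect to `params`.

    :param free_energy_fn: function mapping a batch of visible vectors to their free energies
    :param pos_v: positive-phase sample of visible units
    :param neg_v: negative-phase sample of visible units (treated as a constant)
    :param params: list of variables to differentiate with respect to
    :param other_cost: optional additional cost term (e.g. a regularizer)
    """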
    #block the grad through neg_v
    cost=contrastive_cost(free_energy_fn, pos_v, neg_v)
    if other_cost:
        cost = cost + other_cost
    return theano.tensor.grad(cost,
            wrt=params,
            consider_constant=[neg_v])

###########################################
#
# Expressions that are mcRBM-specific
#
###########################################

class mcRBM(object):
    """Light-weight class that provides the math related to inference

    Attributes:

      - U - the covariance filters (theano shared variable)
      - W - the mean filters (theano shared variable)
      - a - the visible bias (theano shared variable)
      - b - the covariance bias (theano shared variable)
      - c - the mean bias (theano shared variable)
    """
    def __init__(self, U, W, a, b, c):
        self.U = U
        self.W = W
        self.a = a
        self.b = b
        self.c = c

    def hidden_cov_units_preactivation_given_v(self, v, small=0.5):
        """Return argument to the sigmoid that would give mean of covariance hid units

        Here adjusted(v) divides each row of `v` by sqrt(mean(v**2) + small), which is the
        normalization used in the train_mcRBM energy described at the top of this file.

        return b - 0.5 * dot(adjusted(v), U)**2
        """
        unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm
        return self.b - 0.5 * dot(unit_v, self.U)**2

    def free_energy_terms_given_v(self, v):
        """Returns theano expression for the terms that are added to form the free energy of
        visible vector `v` in an mcRBM.

         1.  Free energy related to covariance hiddens
         2.  Free energy related to mean hiddens
         3.  Free energy related to L2-Norm of `v`
         4.  Free energy related to projection of `v` onto biases `a`
        """
        t0 = -TT.sum(TT.nnet.softplus(self.hidden_cov_units_preactivation_given_v(v)),axis=1)
        t1 = -TT.sum(TT.nnet.softplus(self.c + dot(v,self.W)), axis=1)
        t2 =  0.5 * TT.sum(v**2, axis=1)
        t3 = -TT.dot(v, self.a)
        return [t0, t1, t2, t3]

    def free_energy_given_v(self, v):
        """Returns theano expression for free energy of visible vector `v` in an mcRBM
        """
        return TT.add(*self.free_energy_terms_given_v(v))

    def expected_h_g_given_v(self, v):
        """Returns tuple (`h`, `g`) of theano expressions for conditional expectations in an mcRBM.

        `h` is the conditional expectation of the covariance units.
        `g` is the conditional expectation of the mean units.

        """
        h = TT.nnet.sigmoid(self.hidden_cov_units_preactivation_given_v(v))
        g = TT.nnet.sigmoid(self.c + dot(v,self.W))
        return (h, g)

    def n_visible_units(self):
        """Return the number of visible units of this RBM

        For an RBM made from shared variables, this will return an integer,
        for a purely symbolic RBM this will return a theano expression.
        
        """
        try:
            return self.W.value.shape[0]
        except AttributeError:
            return self.W.shape[0]

    def sampler(self, n_particles, n_visible=None, rng=7823748):
        """Return an `HMC_sampler` that will draw samples from the distribution over visible
        units specified by this RBM.

        :param n_particles: this many parallel chains will be simulated.
        :param n_visible: number of visible units; defaults to `self.n_visible_units()`.
        :param rng: seed or numpy RandomState object to initialize particles, and to drive the simulation.
        """
        if not hasattr(rng, 'randn'):
            rng = np.random.RandomState(rng)
        if n_visible is None:
            n_visible = self.n_visible_units()
        rval = HMC_sampler.new_from_shared_positions(
            shared_positions = shared(
                rng.randn(
                    n_particles,
                    n_visible).astype(floatX),
                name='particles'),
            energy_fn=self.free_energy_given_v,
            seed=int(rng.randint(2**30)))
        return rval

    def as_feedforward_layer(self, v):
        return dict(
                outputs = self.expected_h_g_given_v(v),
                params = [self.U, self.W, self.b, self.c],
                )

    @classmethod
    def alloc(cls, n_I, n_K, n_J, rng = 8923402190):
        """
        Return an instance of this class (`mcRBM`) with randomly-initialized parameters.

        :param n_I: input dimensionality
        :param n_K: number of covariance hidden units
        :param n_J: number of mean filters (linear)
        :param rng: seed or numpy RandomState object to initialize params
        """
        if not hasattr(rng, 'randn'):
            rng = np.random.RandomState(rng)

        def shrd(X,name):
            return shared(X.astype(floatX), name=name)

        # initialization taken from train_mcRBM.py
        rval =  cls(
                U = shrd(0.02 * rng.randn(n_I, n_K),'U'),
                W = shrd(0.05 * rng.randn(n_I, n_J),'W'),
                a = shrd(np.ones(n_I)*(0),'a'),
                b = shrd(np.ones(n_K)*2,'b'),
                c = shrd(np.ones(n_J)*(-2),'c'))

        rval.params = [rval.U, rval.W, rval.a, rval.b, rval.c]
        return rval

class mcRBMTrainer(object):
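    """Holds the state needed to train an `mcRBM` with contrastive divergence.

    Attributes:
      - rbm - the `mcRBM` being trained
      - sampler - `HMC_sampler` providing the negative-phase particles
      - normVF - shared scalar: running target length for the columns of U
      - learn_rate - shared scalar: base learning rate
      - learn_rate_multipliers - per-parameter multipliers of the learning rate
      - iter - shared integer count of updates performed
      - l1_penalty - weight of the L1 penalty on U and W

    """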
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

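    def normalize_U(self, new_U):
        """Rescale the columns of a proposed `new_U` to a common length.

        :param new_U: a proposed new value for rbm.U (e.g. from a gradient step)
        :returns: a pair (corrected new value for U, new value for self.normVF)

        The common length self.normVF is a running average (decay 0.95) of the mean column
        norm, mirroring the normVF normalization in train_mcRBM.
        """
        U_norms = TT.sqrt((new_U**2).sum(axis=0))
        new_normVF = .95 * self.normVF + .05 * TT.mean(U_norms)
        return (new_U * new_normVF / U_norms), new_normVF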

    def contrastive_grads(self, visible_batch, params=None):
        if params is None:
            params = self.rbm.params
        return contrastive_grad(
                free_energy_fn=self.rbm.free_energy_given_v,
                pos_v=visible_batch, 
                neg_v=self.sampler.positions,
                params=params,
                other_cost=(l1(self.rbm.U)+l1(self.rbm.W)) * self.l1_penalty)


    def cd_updates(self, visible_batch, params=None, rng=89234):
        if params is None:
            params = self.rbm.params

        grads = self.contrastive_grads(visible_batch, params)

        # contrastive divergence updates
        # TODO: sgd_updates is a particular optimization algo (others are possible)
        #       parametrize so that algo is plugin
        #       the normalization normVF might be sgd-specific though...

        # TODO: when sgd has an annealing schedule, this should
        #       go through that mechanism.

        # TODO: parametrize these constants (e.g. 2000)

        # learning rate: constant for the first ~2000 updates, then decays as 1/iter
        lr = TT.clip(
                self.learn_rate * 2000 / (self.iter+1), 
                0.0, #min
                self.learn_rate) #max

        # wrap in dict() so that more updates can be added by key below
        ups = dict(sgd_updates(
                    params,
                    grads,
                    stepsizes=[a*lr for a in self.learn_rate_multipliers]))

        # advance the iteration counter (requires `ups` to exist)
        ups[self.iter] = self.iter + 1

        # sampler updates
        ups.update(dict(self.sampler.updates()))

        # add trainer updates (replace CD update of U)
        ups[self.rbm.U], ups[self.normVF] = self.normalize_U(ups[self.rbm.U])

        return ups

    # TODO: accept a GD algo as an argument?
    @classmethod
    def alloc(cls, rbm, visible_batch, batchsize, initial_lr=0.075, rng=234,
            l1_penalty=0,
            learn_rate_multipliers=[2, .2, .02, .1, .02]):
        # allocates shared var for negative phase particles

        return cls(
                rbm=rbm,
                sampler=rbm.sampler(batchsize, rng=rng),
                normVF=shared(1.0, 'normVF'),
                learn_rate=shared(initial_lr/batchsize, 'learn_rate'),
                iter=shared(0, 'iter'),
                l1_penalty=l1_penalty,
                learn_rate_multipliers=learn_rate_multipliers)


if __name__ == '__main__':
    import pylearn.algorithms.tests.test_mcRBM
    pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(as_unittest=True)