view pylearn/algorithms/mcRBM.py @ 1000:d4a14c6c36e0

mcRBM - post code-review #1 with Guillaume
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 24 Aug 2010 19:24:54 -0400
parents c6d08a760960
children 075c193afd1b

"""
This file implements the Mean & Covariance RBM discussed in 

    Ranzato, M. and Hinton, G. E. (2010)
    Modeling pixel means and covariances using factored third-order Boltzmann machines.
    IEEE Conference on Computer Vision and Pattern Recognition.

and performs one of the experiments on CIFAR-10 discussed in that paper.  There are some minor
discrepancies between the paper and the accompanying code (train_mcRBM.py), and the
accompanying code has been taken to be correct in those cases because I couldn't get things to
work otherwise.


Math
====

Energy of "covariance RBM"

    E = -0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i C_{if} v_i )^2
      = -0.5 \sum_f (\sum_k P_{fk} h_k) ( \sum_i C_{if} v_i )^2
                    "vector element f"           "vector element f"

In some parts of the paper, the P matrix is chosen to be a diagonal matrix with non-positive
diagonal entries, so it is helpful to see this as a simpler equation.  With P = -I (so that
K = F), the energy becomes:

    E = 0.5 \sum_f h_f ( \sum_i C_{if} v_i )^2

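A quick numerical check of this simplification.  This is a NumPy sketch with hypothetical
helper names (`cov_energy_factored`, `cov_energy_diag`, not part of this module), assuming
P = -I so that the -0.5 prefactor and the -1 diagonal combine into +0.5:

```python
import numpy as np

def cov_energy_factored(v, h, C, P):
    # E = -0.5 * sum_f (sum_k P[f, k] h[k]) * (sum_i C[i, f] v[i])**2
    proj = v @ C            # shape (F,): sum_i C_{if} v_i
    weight = P @ h          # shape (F,): sum_k P_{fk} h_k
    return -0.5 * np.sum(weight * proj ** 2)

def cov_energy_diag(v, h, C):
    # Special case P = -I (K = F): E = 0.5 * sum_f h_f * (sum_i C_{if} v_i)**2
    return 0.5 * np.sum(h * (v @ C) ** 2)

rng = np.random.RandomState(0)
I, F = 6, 4
v = rng.randn(I)
C = rng.randn(I, F)
h = rng.randint(0, 2, size=F).astype(float)   # binary covariance hiddens
P = -np.eye(F)

e_full = cov_energy_factored(v, h, C, P)
e_diag = cov_energy_diag(v, h, C)
print(np.isclose(e_full, e_diag))  # True
```

Because h is binary and the diagonal of P is non-positive, this energy term is never negative.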


Version in paper
----------------

Full Energy of the Mean and Covariance RBM, with 
:math:`h_k = h_k^{(c)}`,
:math:`g_j = h_j^{(m)}`,
:math:`b_k = b_k^{(c)}`,
:math:`c_j = b_j^{(m)}`,
:math:`U_{if} = C_{if}`,

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i / (|U_{.f}| |v|) )^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j
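
The following NumPy sketch evaluates this paper energy for a single configuration.  The
function name `energy_paper` and the toy sizes are hypothetical; P is taken to be -I as an
assumption.  It also shows the consequence of normalizing by |U_{.f}| |v|: the covariance
term is invariant to rescaling v, so only the 0.5 \sum_i v_i^2 term reacts to the scale of v.

```python
import numpy as np

def energy_paper(v, h, g, U, P, W, b, c):
    # Covariance term: each filter response is divided by |U_{.f}| * |v|.
    cos = (v @ U) / (np.linalg.norm(U, axis=0) * np.linalg.norm(v))   # shape (F,)
    e_cov = -0.5 * np.sum((P @ h) * cos ** 2)     # P has shape (F, K), h has shape (K,)
    e_hbias = -np.dot(b, h)                       # - sum_k b_k h_k
    e_sq = 0.5 * np.dot(v, v)                     # + 0.5 sum_i v_i^2
    e_mean = -np.dot(v @ W, g)                    # - sum_j sum_i W_ij g_j v_i
    e_gbias = -np.dot(c, g)                       # - sum_j c_j g_j
    return e_cov + e_hbias + e_sq + e_mean + e_gbias

rng = np.random.RandomState(1)
I, F, J = 5, 3, 2
v = rng.randn(I)
U = rng.randn(I, F)
P = -np.eye(F)                  # assumed diagonal and non-positive
h = np.array([1.0, 0.0, 1.0])
g = np.array([0.0, 1.0])
b, c = np.zeros(F), np.zeros(J)
W0 = np.zeros((I, J))           # switch off the mean part to isolate the v-scaling effect

d = energy_paper(2 * v, h, g, U, P, W0, b, c) - energy_paper(v, h, g, U, P, W0, b, c)
print(np.isclose(d, 1.5 * np.dot(v, v)))  # True: 0.5 * (4 - 1) * |v|^2
```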

For the energy function to correspond to a probability distribution, P must be non-positive
(element-wise).  P is initialized to be a diagonal matrix, and in our experience it can be left
that way: even in the paper it has a very low learning rate, and in effect it is only updated
after the filters in U have been learned.

Version in published train_mcRBM code
-------------------------------------

The train_mcRBM file implements learning in a similar but technically different Energy function:

    E (v, h, g) =
        - 0.5 \sum_f \sum_k P_{fk} h_k (\sum_i U_{if} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j
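
The normalizer in the first term can be sketched in NumPy as follows.  The name
`normalize_like_train_mcRBM` is hypothetical; the computation mirrors the row normalization
in hidden_cov_units_preactivation_given_v in this file (rows of v are examples).  If
m = \sum_i v_i^2 / I, the squared length of a normalized row is I * m / (m + 0.5), which is
slightly below I when m is large compared to 0.5, and a zero row stays zero.

```python
import numpy as np

def normalize_like_train_mcRBM(v, small=0.5):
    # Divide each row of v (shape N x I) by sqrt(sum_i v_i^2 / I + small).
    scale = np.sqrt(np.mean(v ** 2, axis=1) + small)
    return v / scale[:, None]

rng = np.random.RandomState(2)
N, I = 4, 105
v = 3.0 * rng.randn(N, I)
unit_v = normalize_like_train_mcRBM(v)
sq_len = np.sum(unit_v ** 2, axis=1)
m = np.mean(v ** 2, axis=1)
print(sq_len / I)   # each entry equals m / (m + 0.5), i.e. a little below 1 here
```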

There are two differences with respect to the paper:

    - 'v' is not normalized by its length, but rather it is normalized to have length close to
      the square root of the number of its components.  The variable called 'small' that
      "avoids division by zero" is orders of magnitude larger than machine precision, and is on
      the order of the normalized sum-of-squares, so I've included it in the Energy function.

    - 'U' is also not normalized by its length.  U is initialized to have columns that are
      shorter than unit-length (approximately 0.2 with the 105 principal components in the
      train_mcRBM data).  During training, the columns of U are constrained manually to have
      equal lengths (see the use of normVF), but their common Euclidean norm is allowed to
      change.  During learning it quickly converges towards 1 and then exceeds 1.  This
      column-wise normalization of U does not seem to be justified by maximum likelihood, and I
      have no intuition for why it is used.

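The column constraint on U can be pictured with the following NumPy sketch.  The function
`equalize_column_norms` is hypothetical, and choosing the mean column norm as the common
length is an ASSUMPTION made for illustration; it is not a transcription of normVF in
train_mcRBM.py.  With the 0.02 * randn initialization used by MeanCovRBM.new_from_dims and
105 inputs, each column has length roughly 0.02 * sqrt(105), about 0.205, which matches the
"approximately 0.2" above.

```python
import numpy as np

def equalize_column_norms(U, target=None):
    # Rescale every column of U (shape I x F) to one common Euclidean length.
    # ASSUMPTION: the common length defaults to the mean column norm (hypothetical rule).
    norms = np.linalg.norm(U, axis=0)       # assumes no column is exactly zero
    if target is None:
        target = norms.mean()
    return U * (target / norms)[None, :]

rng = np.random.RandomState(3)
U = 0.02 * rng.randn(105, 64)               # same scale as MeanCovRBM.new_from_dims
U_eq = equalize_column_norms(U)
col_norms = np.linalg.norm(U_eq, axis=0)
print(col_norms.min(), col_norms.max())     # equal up to rounding, near 0.2
```

Only the lengths change; each column keeps its direction, so the filters themselves are
preserved while their common norm remains free to evolve during learning.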

Version in this code
--------------------

This file implements the same algorithm as the train_mcRBM code, except that the P matrix is
omitted for clarity and replaced analytically with a negative identity matrix (so F = K).  The
visible bias `a` used by free_energy_terms_given_v appears as the last term.

    E (v, h, g) =
        + 0.5 \sum_k h_k (\sum_i U_{ik} v_i / sqrt(\sum_i v_i^2/I + 0.5))^2 
        - \sum_k b_k h_k
        + 0.5 \sum_i v_i^2
        - \sum_j \sum_i W_{ij} g_j v_i
        - \sum_j c_j g_j
        - \sum_i a_i v_i

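For small models the closed-form free energy computed by free_energy_terms_given_v can be
checked against brute-force marginalization over every binary h and g.  This NumPy sketch
(hypothetical function names, toy sizes) uses this file's energy, including the visible-bias
term - \sum_i a_i v_i that free_energy_terms_given_v adds as t3:

```python
import itertools
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)        # log(1 + exp(x)), numerically stable

def energy(v, h, g, U, W, a, b, c, small=0.5):
    unit_v = v / np.sqrt(np.mean(v ** 2) + small)
    return (0.5 * np.dot(h, (unit_v @ U) ** 2) - np.dot(b, h)
            + 0.5 * np.dot(v, v) - np.dot(v @ W, g) - np.dot(c, g) - np.dot(a, v))

def free_energy(v, U, W, a, b, c, small=0.5):
    # Same four terms (t0 + t1 + t2 + t3) as free_energy_terms_given_v, for one vector v.
    unit_v = v / np.sqrt(np.mean(v ** 2) + small)
    t0 = -np.sum(softplus(b - 0.5 * (unit_v @ U) ** 2))
    t1 = -np.sum(softplus(c + v @ W))
    t2 = 0.5 * np.dot(v, v)
    t3 = -np.dot(v, a)
    return t0 + t1 + t2 + t3

def free_energy_brute_force(v, U, W, a, b, c):
    # -log sum_{h, g} exp(-E(v, h, g)), enumerating all binary h and g.
    K, J = U.shape[1], W.shape[1]
    neg_e = [-energy(v, np.array(h, float), np.array(g, float), U, W, a, b, c)
             for h in itertools.product([0, 1], repeat=K)
             for g in itertools.product([0, 1], repeat=J)]
    return -np.logaddexp.reduce(neg_e)

rng = np.random.RandomState(4)
I, K, J = 5, 3, 2
U, W = 0.5 * rng.randn(I, K), 0.5 * rng.randn(I, J)
a, b, c = rng.randn(I), rng.randn(K), rng.randn(J)
v = rng.randn(I)
closed = free_energy(v, U, W, a, b, c)
brute = free_energy_brute_force(v, U, W, a, b, c)
print(np.isclose(closed, brute))  # True
```

The agreement holds because the sum over (h, g) factorizes into a product of independent
sums over each binary unit, which is what turns log(sum(exp(...))) into a sum of softplus terms.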
      

Conventions in this file
========================

This file contains some global functions, as well as a class (MeanCovRBM) that makes using them a little
more convenient.


Global functions like `free_energy_given_v` work on an mcRBM as parametrized in a particular
way.  Suppose we have I input dimensions, F squared filters, J mean variables, and K covariance
variables (in this file P = -I, so F = K).  The mcRBM is parametrized by 5 variables, which the
global functions unpack in the order (U, W, a, b, c):

 - `U`, a matrix whose columns are visible covariance directions (I x F)
 - `W`, a matrix whose columns are visible mean directions (I x J)
 - `a`, a vector of visible biases (I)
 - `b`, a vector of hidden covariance biases (K)
 - `c`, a vector of hidden mean biases  (J)

Matrices are generally laid out and accessed according to a C-order convention.

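The shape conventions can be illustrated with a NumPy sketch that mirrors
expected_h_g_given_v without Theano.  The sizes are hypothetical, the parameter scales follow
MeanCovRBM.new_from_dims, and each row of `v` is one visible vector (C order):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.RandomState(5)
n_I, n_K, n_J, N = 8, 6, 4, 10     # hypothetical sizes; N rows form a minibatch
rbm = (0.02 * rng.randn(n_I, n_K),  # U: I x K (F = K because P = -I)
       0.05 * rng.randn(n_I, n_J),  # W: I x J
       np.zeros(n_I),               # a: visible biases
       2.0 * np.ones(n_K),          # b: covariance hidden biases
       -2.0 * np.ones(n_J))         # c: mean hidden biases

U, W, a, b, c = rbm                 # same unpacking order as the global functions
v = rng.randn(N, n_I)
unit_v = v / np.sqrt(np.mean(v ** 2, axis=1) + 0.5)[:, None]
h = sigmoid(b - 0.5 * (unit_v @ U) ** 2)   # covariance units, shape (N, K)
g = sigmoid(c + v @ W)                     # mean units, shape (N, J)
print(h.shape, g.shape)  # (10, 6) (10, 4)
```
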
"""

#
# WORKING NOTES
# THIS DERIVATION IS BASED ON THE ** PAPER ** ENERGY FUNCTION
# NOT THE ENERGY FUNCTION IN THE CODE!!!
#
# Free energy is the marginal energy of visible units
# Recall: 
#   Q(x) = exp(-E(x))/Z ==> -log(Q(x)) - log(Z) = E(x)
#
#
#   E (v, h, g) =
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / |U_{*f}|^2 |v|^2
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j
#       - \sum_i a_i v_i
#
#
# Derivation, in which partition functions are ignored.
#
# E(v) = -\log(Q(v))
#  = -\log( \sum_{h,g} Q(v,h,g))
#  = -\log( \sum_{h,g} exp(-E(v,h,g)))
#  = -\log( \sum_{h,g} exp(-(
#       - 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       - \sum_k b_k h_k
#       + 0.5 \sum_i v_i^2
#       - \sum_j \sum_i W_{ij} g_j v_i
#       - \sum_j c_j g_j 
#       - \sum_i a_i v_i )))
#
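# Get rid of the double negatives in exp, factor the sum over (h, g) into a product of
# sums, and pull the visible-bias term out of the log
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)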
#       + \sum_k b_k h_k
#       - 0.5 \sum_i v_i^2
#       ) * \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j))
#    - \sum_i a_i v_i 
#
# Break up log
#  = -\log(  \sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       + \sum_k b_k h_k
#       ))
#    -\log( \sum_{g} exp(
#       + \sum_j \sum_i W_{ij} g_j v_i
#       + \sum_j c_j g_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# Use the fact that g is binary to turn log(sum(exp(sum...))) into sum(log(...)):
#  = -\log(\sum_{h} exp(
#       + 0.5 \sum_f \sum_k P_{fk} h_k ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)
#       + \sum_k b_k h_k
#       ))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# Do the same for the binary h:
#
#  = - \sum_{k} \log(1 + exp(b_k + 0.5 \sum_f P_{fk} ( \sum_i U_{if} v_i )^2 / (|U_{*f}|^2 |v|^2)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 
#
# For P = -I (negative-one diagonal, so F = K) this gives:
#
#  = - \sum_{k} \log(1 + exp(b_k - 0.5 ( \sum_i U_{ik} v_i )^2 / (|U_{*k}|^2 |v|^2)))
#    - \sum_{j} \log(1 + exp(\sum_i W_{ij} v_i + c_j ))
#    + 0.5 \sum_i v_i^2
#    - \sum_i a_i v_i 

import sys, os, logging
import numpy as np
import numpy

import theano
from theano import function, shared, dot
from theano import tensor as TT
floatX = theano.config.floatX

import pylearn
#TODO: clean up the HMC_sampler code
#TODO: think of naming convention for acronyms + suffix?
from pylearn.sampling.hmc import HMC_sampler
from pylearn.io import image_tiling
from pylearn.gd.sgd import sgd_updates
import pylearn.dataset_ops.image_patches

###########################################
#
# Candidates for factoring
#
###########################################

#TODO: Document, move to pylearn's math lib
def l1(X):
    return abs(X).sum()

#TODO: Document, move to pylearn's math lib
def l2(X):
    return TT.sqrt((X**2).sum())

#TODO: Document, move to pylearn's math lib
def contrastive_cost(free_energy_fn, pos_v, neg_v):
    return (free_energy_fn(pos_v) - free_energy_fn(neg_v)).sum()

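#TODO: Typical use of contrastive_cost is to pass it to tensor.grad, but in that case we want
#      to block the gradient going through neg_v (contrastive_grad does this).
def contrastive_grad(free_energy_fn, pos_v, neg_v, params, other_cost=0):
    """Return the gradient of the contrastive cost with respect to `params`.

    :param free_energy_fn: function mapping a batch of visible vectors to their free energies
    :param pos_v: positive-phase sample of visible units
    :param neg_v: negative-phase sample of visible units (treated as a constant)
    :param params: list of parameters to differentiate with respect to
    :param other_cost: optional extra cost term (e.g. a regularizer) added before differentiating
    """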
    #block the grad through neg_v
    cost=contrastive_cost(free_energy_fn, pos_v, neg_v)
    if other_cost:
        cost = cost + other_cost
    return theano.tensor.grad(cost,
            wrt=params,
            consider_constant=[neg_v])

###########################################
#
# Expressions that are mcRBM-specific
#
###########################################

#TODO: make global function to initialize parameter

def hidden_cov_units_preactivation_given_v(rbm, v, small=0.5):
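    """Return argument to the sigmoid that would give the mean of the covariance hidden units

    Here 'adjusted' means that each row of `v` is divided by sqrt(mean(v**2) + small),
    following the train_mcRBM energy function described at the top of this file.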

    return b - 0.5 * dot(adjusted(v), U)**2
    """
    (U,W,a,b,c) = rbm
    unit_v = v / (TT.sqrt(TT.mean(v**2, axis=1)+small)).dimshuffle(0,'x') # adjust row norm
    return b - 0.5 * dot(unit_v, U)**2

def free_energy_terms_given_v(rbm, v):
    """Returns theano expression for the terms that are added to form the free energy of
    visible vector `v` in an mcRBM.

     1.  Free energy related to covariance hiddens
     2.  Free energy related to mean hiddens
     3.  Free energy related to L2-Norm of `v`
     4.  Free energy related to projection of `v` onto biases `a`
    """
    U, W, a, b, c = rbm
    t0 = -TT.sum(TT.nnet.softplus(hidden_cov_units_preactivation_given_v(rbm, v)),axis=1)
    t1 = -TT.sum(TT.nnet.softplus(c + dot(v,W)), axis=1)
    t2 =  0.5 * TT.sum(v**2, axis=1)
    t3 = -TT.dot(v, a)
    return [t0, t1, t2, t3]

def free_energy_given_v(rbm, v):
    """Returns theano expression for free energy of visible vector `v` in an mcRBM
    """
    return sum(free_energy_terms_given_v(rbm,v))

def expected_h_g_given_v(rbm, v):
    """Returns tuple (`h`, `g`) of theano expressions for conditional expectations in an mcRBM.

    `h` is the conditional expectation of the covariance units.
    `g` is the conditional expectation of the mean units.
    
    """
    (U, W, a, b, c) = rbm
    h = TT.nnet.sigmoid(hidden_cov_units_preactivation_given_v(rbm, v))
    g = TT.nnet.sigmoid(c + dot(v,W))
    return (h, g)

def n_visible_units(rbm):
    """Return the number of visible units of this RBM

    For an RBM made from shared variables, this will return an integer,
    for a purely symbolic RBM this will return a theano expression.
    
    """
    W = rbm[1]
    try:
        return W.value.shape[0]
    except AttributeError:
        return W.shape[0]

def sampler(rbm, n_particles, n_visible=None, rng=7823748):
    """Return an `HMC_sampler` that will draw samples from the distribution over visible
    units specified by this RBM.

    :param n_particles: this many parallel chains will be simulated.
    :param n_visible: number of visible units; inferred from `rbm` when None.
    :param rng: seed or numpy RandomState object used to initialize the particles and to drive
        the simulation.
    """
    if not hasattr(rng, 'randn'):
        rng = np.random.RandomState(rng)
    if n_visible is None:
        n_visible = n_visible_units(rbm)
    rval = HMC_sampler(
        positions = [shared(
            rng.randn(
                n_particles,
                n_visible).astype(floatX),
            name='particles')],
        energy_fn = lambda p : free_energy_given_v(rbm, p[0]),
        seed=int(rng.randint(2**30)))
    return rval

#############################
#
# Convenient data container
#
#############################

class MeanCovRBM(object):
    """Container for mcRBM parameters

    It provides parameter lookup by name, as well as a heuristic for initializing the
    parameters for effective learning.

    Attributes:

      - U - the covariance filters (theano shared variable)
      - W - the mean filters (theano shared variable)
      - a - the visible bias (theano shared variable)
      - b - the covariance bias (theano shared variable)
      - c - the mean bias (theano shared variable)
    
    """

    params = property(lambda s: [s.U, s.W, s.a, s.b, s.c])

    n_visible = property(lambda s: s.W.value.shape[0])

    def __init__(self, U, W, a, b, c):
        self.__dict__.update(locals())
        del self.__dict__['self']

    def __getitem__(self, idx):
        # support unpacking of this container as if it were a tuple
        return self.params[idx]

    @classmethod
    def new_from_dims(cls, n_I, n_K, n_J, rng = 8923402190):
        """
        Return a MeanCovRBM instance with randomly-initialized parameters.

        :param n_I: input dimensionality
        :param n_K: number of covariance hidden units
        :param n_J: number of mean filters (linear)
        :param rng: seed or numpy RandomState object to initialize params
        """
        if not hasattr(rng, 'randn'):
            rng = np.random.RandomState(rng)

        def shrd(X,name):
            return shared(X.astype(floatX), name=name)

        # initialization taken from train_mcRBM.py
        return cls(
                U = shrd(0.02 * rng.randn(n_I, n_K),'U'),
                W = shrd(0.05 * rng.randn(n_I, n_J),'W'),
                a = shrd(np.ones(n_I)*(0),'a'),
                b = shrd(np.ones(n_K)*2,'b'),
                c = shrd(np.ones(n_J)*(-2),'c'))

    def __getstate__(self):
        # unpack shared containers, which may have references to Theano stuff and are not a
        # long-term stable data type.
        return dict(
                U = self.U.value,
                W = self.W.value,
                a = self.a.value,
                b = self.b.value,
                c = self.c.value)

    def __setstate__(self, dct):
        d = dict(dct)
        for key in ['U', 'W', 'a', 'b', 'c']:
            d[key] = shared(d[key], name=key)
        self.__init__(**d)


#TODO: put the normalization of U as a global function


#TODO: put the learning loop as a global function or class, so that someone could load and *TRAIN* an mcRBM!!!

if __name__ == '__main__':
    import pylearn.algorithms.tests.test_mcRBM
    pylearn.algorithms.tests.test_mcRBM.test_reproduce_ranzato_hinton_2010(as_unittest=True)