# pylearn/shared/layers/rust2005.py

""" Provides Rust2005 layer

Paper: 

This layer implements a model of simple and complex cell firing rate responses.


:TODO: implement the full model with variable exponents.  The current implementation fixes
the internal exponents to 2 and the external exponent to 1/2.

:TODO: add weights on the quadratic filter responses.  E and S are supposed to be the
"square root of a *weighted* sum of squares".  The simplification here is probably useful,
but it should be optional, or live in a separate class.
- The current implementation can be interpreted as folding the weights into the filters
  themselves; the filters are consequently not constrained to have unit norm, for example.
"""

import numpy
try:
    from PIL import Image
except ImportError:
    # PIL is optional; it is only needed by Rust2005.img_from_weights()
    pass

import theano
import theano.tensor
import theano.tensor.nnet
from theano.compile.sandbox import shared
from theano.tensor.nnet import softplus
from theano.sandbox.conv import ConvOp

from pylearn.shared.layers.util import update_locals, add_logging

def rust2005_act_from_filters(linpart, E_quad, S_quad, eps):
    """Return rust2005 activation from linear filter responses, as well as E and S terms

    :param linpart: a single tensor of linear filter responses
    :param E_quad: a list of tensors of filter responses for the excitatory quadratic terms
    :param S_quad: a list of tensors of filter responses for the suppressive quadratic terms
    :param eps: a scalar added to each sum of squares before the sqrt

    """
    if isinstance(E_quad, theano.Variable):
        raise TypeError('E_quad should be a list of Variables, not a Variable itself')
    if isinstance(S_quad, theano.Variable):
        raise TypeError('S_quad should be a list of Variables, not a Variable itself')
    sqrt = theano.tensor.sqrt
    softlin = theano.tensor.nnet.softplus(linpart)
    E = sqrt(sum([E_quad_i**2 for E_quad_i in E_quad] + [softlin**2], eps))
    S = sqrt(sum([S_quad_i**2 for S_quad_i in S_quad], eps))
    return (E-S) / (1+E+S), E, S
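
# A pure-numpy reference for the same activation (an illustrative sketch, not
# part of the original source), handy for checking the symbolic expression
# above on concrete arrays.  The exponents are fixed to 2 and 1/2 as in the
# module docstring.
def rust2005_act_numpy(linpart, E_quad, S_quad, eps):
    """Numpy analogue of rust2005_act_from_filters."""
    softlin = numpy.log1p(numpy.exp(linpart))  # softplus; may overflow for large inputs
    E = numpy.sqrt(eps + softlin**2 + sum(e**2 for e in E_quad))
    S = numpy.sqrt(eps + sum(s**2 for s in S_quad))
    return (E - S) / (1 + E + S)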

class Rust2005(object):
    """Energy-like complex cell activation function described in Rust et al. 2005 """
    #logging methods come from the add_logging() call below
    # _info, _debug, _warn, _error, _fatal

    def __init__(self, input, w, b, n_out, n_E_quadratic, n_S_quadratic,
            epsilon, filter_shape, params):
        """
        w should be a matrix with input.shape[1] rows, and n_out *
        (1+n_E_quadratic+n_S_quadratic) columns.

        Every successive block of (1+n_E_quadratic+n_S_quadratic) adjacent columns contributes
        to the computation of one output features.  The first column in the block is the filter
        for the linear term.  The following n_E_quadratic columns are used to compute the
        exciting quadratic part.  The following n_S_quadratic columns are used to compute the
        inhibitory part.
        """
        if w.dtype != input.dtype:
            self._warn('WARNING w type mismatch', input.dtype, w.dtype, b.dtype)
        if b.dtype != input.dtype:
            self._warn('WARNING b type mismatch', input.dtype, w.dtype, b.dtype)
        # filter_shape gives the 2-D image shape to which each column of w
        # corresponds; it is only used for rendering the weights as tiled
        # images (see img_from_weights)

        filter_responses = theano.dot(input, w).reshape((
                input.shape[0],
                n_out, 
                1 + n_E_quadratic + n_S_quadratic))

        assert filter_responses.dtype == input.dtype
        Lf = filter_responses[:, :, 0]
        Ef = filter_responses[:, :, 1:1 + n_E_quadratic]
        Sf = filter_responses[:, :, 1 + n_E_quadratic:]
        assert Lf.dtype == input.dtype

        sqrt = theano.tensor.sqrt
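        # E pools the squared excitatory quadratic responses together with the
        # squared softplus of the biased linear response; S pools the squared
        # suppressive responses.  epsilon keeps sqrt differentiable at zero.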
        E = sqrt((Ef**2).sum(axis=2) + epsilon + softplus(Lf+b)**2)
        S = sqrt((Sf**2).sum(axis=2) + epsilon)

        output = (E-S) / (1+E+S)
        assert output.dtype == input.dtype
        # Ef, Sf, E, and S are deliberately bound as locals so that
        # update_locals() below attaches them to the instance.

        # L1 and squared-L2 norms of the filter matrix, exposed for use as
        # regularization penalties by training code.
        l1 = abs(w).sum()
        l2_sqr = (w**2).sum()

        update_locals(self, locals())

    @classmethod
    def new(cls, input, n_in, n_out, n_E, n_S, rng, eps=1.0e-6, filter_shape=None, dtype=None):
        """Allocate parameters and initialize them randomly.
        """
        if dtype is None:
            dtype = input.dtype
        epsilon = numpy.asarray(eps, dtype=dtype)
        w = shared(numpy.asarray(
                rng.randn(n_in, n_out*(1 + n_E + n_S))*.3 / numpy.sqrt(n_in),
                dtype=dtype))
        b = shared(numpy.zeros((n_out,), dtype=dtype))
        return cls(input, w, b, n_out, n_E, n_S, epsilon, filter_shape, [w,b])

    def img_from_weights(self, rows=12, cols=25, row_gap=1, col_gap=1, eps=1e-4, triplegap=0):
        """ Return an image that visualizes all the weights in the layer.

        The current implementation returns a tiling in which every triple of columns is a
        logical group.  The first column in a triple shows the linear weights, the second
        the excitatory quadratic weights, and the third the suppressive quadratic weights.

        """
        if cols % 3: #because there are three kinds of filters: linear, excitatory, inhibitory
            raise ValueError("cols must be multiple of 3")

        n_triples = cols // 3

        filter_shape = self.filter_shape
        height = rows * (row_gap + filter_shape[0]) - row_gap
        width = cols * (col_gap + filter_shape[1]) - col_gap + (n_triples-1) * triplegap

        out_array = numpy.zeros((height, width, 3), dtype='uint8')

        w = self.w.value
        w_col = 0
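        # rescale x linearly onto [0, 255]; eps guards against division by
        # zero when a tile is constant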
        def pixel_range(x):
            return 255 * (x - x.min()) / (x.max() - x.min() + eps)
        for r in xrange(rows):
            out_r_low = r*(row_gap + filter_shape[0])
            out_r_high = out_r_low + filter_shape[0]
            extra_col_gap = 0 # a counter we'll use for the triplegap
            for c in xrange(cols):
                if c and (c%3==0):
                    extra_col_gap += triplegap
                out_c_low = c*(col_gap + filter_shape[1]) + extra_col_gap
                out_c_high = out_c_low + filter_shape[1]
                out_tile = out_array[out_r_low:out_r_high, out_c_low:out_c_high,:]
                assert out_tile.shape[1] == filter_shape[1]

                if c % 3 == 0: # linear filter
                    if w_col < w.shape[1]:
                        out_tile[...] = pixel_range(w[:,w_col]).reshape(filter_shape+(1,))
                        w_col += 1
                if c % 3 == 1: # E filters
                    if w_col < w.shape[1]:
                        #filters after the 3rd do not get rendered, but are skipped over.
                        #  there are only 3 colour channels.
                        for i in xrange(min(self.n_E_quadratic,3)):
                            out_tile[:,:,i] = pixel_range(w[:,w_col+i]).reshape(filter_shape)
                        w_col += self.n_E_quadratic
                if c % 3 == 2: # S filters
                    if w_col < w.shape[1]:
                        #filters after the 3rd do not get rendered, but are skipped over.
                        #  there are only 3 colour channels.
                        for i in xrange(min(self.n_S_quadratic,3)):
                            out_tile[:,:,2-i] = pixel_range(w[:,w_col+i]).reshape(filter_shape)
                        w_col += self.n_S_quadratic
        return Image.fromarray(out_array, 'RGB')
add_logging(Rust2005)
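
# Illustrative usage sketch (not part of the original module).  The sizes below
# are arbitrary, and the code assumes theano.function is available alongside
# the Rust2005.new() signature defined above.
def _demo_rust2005():
    """Build and run a small Rust2005 layer on random data."""
    rng = numpy.random.RandomState(23)
    x = theano.tensor.matrix('x')  # (n_examples, n_in)
    layer = Rust2005.new(x, n_in=16, n_out=4, n_E=2, n_S=2, rng=rng,
            filter_shape=(4, 4))  # each 16-element column reshapes to a 4x4 filter
    f = theano.function([x], layer.output)
    return f(rng.randn(8, 16).astype(x.dtype))  # (8, 4) array of activations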

class Rust2005Conv(object):
    """Convolutional version of `Rust2005`

    :note:
    The layer does not include a downsampling option.  It makes sense to downsample the
    output with DownsampleMaxPool, but downsampling is orthogonal to this layer's
    behaviour, so it is left out.
    
    """

    l1 = 0.0
    l2_sqr = 0.0

    eps = 1e-6  # default epsilon to prevent sqrt(0) in the quadratic filter terms

    image_shape = None
    filter_shape = None
    output_shape = None
    output_channels = None
    output_examples = None

    def __init__(self, linpart, Es, Ss, params, eps=eps):
        """
        :param linpart: symbolic responses of the linear filters (bias included)
        :param Es: list of symbolic responses of the excitatory quadratic filters
        :param Ss: list of symbolic responses of the suppressive quadratic filters
        :param params: list of the layer's learnable shared variables
        :param eps: scalar added to each sum of squares before the sqrt
        """
        eps = numpy.asarray(eps, dtype=linpart.dtype)
        output, E, S = rust2005_act_from_filters(linpart, Es, Ss, eps)

        update_locals(self, locals())


    @classmethod
    def new(cls, rng, input, image_shape, filter_shape, n_examples, n_filters, n_E, n_S,
            n_channels=1,
            eps=1.0e-6, dtype=None, conv_mode='valid',
            w_range=None,
            q_range=None
            ):
        """ Return Rust2005Conv layer

        layer.output will be a 4D tensor with shape (n_examples, n_filters, R, C), where R
        and C depend on image_shape, filter_shape, and the convolution mode.


        :param rng: generator for randomized initial filters
        :param input: symbolic input (4D tensor)
        :type input: 4D tensor with shape (n_examples, n_channels, image_shape[0], image_shape[1])

        :param image_shape:  rows, cols of every channel of every image
        :param filter_shape: rows, cols of every filter
        :param n_examples: number of images in each batch
        :param n_filters: number of filters (output will have this many channels)
        :param n_channels: number of channels in each image and filter
        :param n_E: number of squared exciting terms
        :param n_S: number of squared inhibition terms
        :param eps: epsilon to add to sum-of-squares in sqrt
        :param dtype: dtype to use for new variables (Default: input.dtype)
        :param conv_mode: convolution mode
        :param w_range: linear weights will be drawn uniformly from this range
        :type w_range: pair (lower_bound, upper_bound)
        :param q_range: quadratic weights will be drawn uniformly from this range
        :type q_range: pair (lower_bound, upper_bound)
        """
        if dtype is None:
            dtype = input.dtype

        irows, icols = image_shape
        krows, kcols = filter_shape

        conv = ConvOp((n_channels,irows, icols), (krows, kcols), n_filters, n_examples,
                dx=1, dy=1, output_mode=conv_mode)

        w_shp = (n_filters, n_channels, krows, kcols)
        b_shp = (n_filters,)

        if w_range is None:
            w_low = -2.0/numpy.sqrt(image_shape[0] * image_shape[1] * n_channels)
            w_high = 2.0/numpy.sqrt(image_shape[0] * image_shape[1] * n_channels)
        else:
            w_low, w_high = w_range

        if q_range is None:
            q_low, q_high = w_low, w_high
        else:
            q_low, q_high = q_range

        w = shared(numpy.asarray(rng.uniform(low=w_low, high=w_high, size=w_shp), dtype=dtype))
        b = shared(numpy.asarray(rng.uniform(low=w_low, high=w_high, size=b_shp), dtype=dtype))

        E_w = [
                shared(numpy.asarray(rng.uniform(low=q_low, high=q_high, size=w_shp), dtype=dtype))
                for i in xrange(n_E)
                ]
        S_w = [
                shared(numpy.asarray(rng.uniform(low=q_low, high=q_high, size=w_shp), dtype=dtype))
                for i in xrange(n_S)
                ]

        rval = cls(
                linpart=conv(input, w) + b.dimshuffle(0,'x','x'),
                Es=[conv(input, e) for e in E_w],
                Ss=[conv(input, s) for s in S_w],
                params=[w,b]+E_w + S_w,
                eps=eps)

        # ignore the bias in l1 and l2_sqr (Yoshua's habit)
        rval.l1 = sum(abs(p).sum() for p in ([w] + E_w + S_w))
        rval.l2_sqr = sum((p**2).sum() for p in ([w] + E_w + S_w))
        rval.image_shape = image_shape
        rval.filter_shape = filter_shape
        rval.output_shape = conv.outshp
        rval.output_channels = n_filters # how many channels of *output*
        rval.output_examples = n_examples

        return rval
add_logging(Rust2005Conv) # _debug, _info, _warn, _error, _fatal
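
# Illustrative usage sketch for the convolutional variant (not part of the
# original module); the shapes below are arbitrary, with one input channel,
# and the code assumes theano.function is available.
def _demo_rust2005_conv():
    """Build and run a small Rust2005Conv layer on random images."""
    rng = numpy.random.RandomState(23)
    x = theano.tensor.tensor4('x')  # (n_examples, n_channels, rows, cols)
    layer = Rust2005Conv.new(rng, x, image_shape=(20, 20), filter_shape=(5, 5),
            n_examples=8, n_filters=4, n_E=2, n_S=2)
    f = theano.function([x], layer.output)
    return f(rng.randn(8, 1, 20, 20).astype(x.dtype))  # (8, 4, 16, 16) in 'valid' mode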