changeset 834:580087712f69

added shared.layers
author James Bergstra <bergstrj@iro.umontreal.ca>
date Fri, 16 Oct 2009 12:14:43 -0400
parents 039e93a95c20
children 79c482ec4ccf
files pylearn/shared/README.txt pylearn/shared/__init__.py pylearn/shared/layers/README.txt pylearn/shared/layers/__init__.py pylearn/shared/layers/exponential_mean.py pylearn/shared/layers/kording2004.py pylearn/shared/layers/kouh2008.py pylearn/shared/layers/lecun1998.py pylearn/shared/layers/logreg.py pylearn/shared/layers/rust2005.py pylearn/shared/layers/sandbox/__init__.py pylearn/shared/layers/sandbox/adelsonbergen87.py pylearn/shared/layers/sandbox/linsvm.py pylearn/shared/layers/sgd.py pylearn/shared/layers/sigmoidal_layer.py pylearn/shared/layers/squash.py pylearn/shared/layers/tests/test_kording2004.py pylearn/shared/layers/tests/test_kouh2008.py pylearn/shared/layers/tests/test_lecun1998.py pylearn/shared/layers/tests/test_sigmoidal_layer.py pylearn/shared/layers/util.py
diffstat 19 files changed, 1254 insertions(+), 0 deletions(-) [+]
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/README.txt	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,2 @@
+The shared folder is for code taking advantage of Theano's shared-variable feature.
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/README.txt	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,29 @@
+
+Layers are the building blocks of neural networks.
+They are often parametric, but need not be.
+
+This directory is meant to be a library of layers and, where applicable, of the
+algorithms used to fit them to data.
+
+
+.. code-block:: python
+
+    class Layer(object):
+
+        """ Base class for Layer, documenting interface conventions
+
+        WRITEME
+        """
+
+        input = None
+
+        output = None
+
+        l1 = 0
+
+        l2_sqr = 0
+
+        params = []
+
+        updates = []
+
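A hypothetical concrete layer following these conventions might look like the sketch below
(the class name and scaling behaviour are made up for illustration; they are not part of the
library):

.. code-block:: python

    import numpy
    from theano import tensor
    from theano.compile.sandbox import shared

    class ScaleLayer(object):
        """Toy layer: multiply the input by a learned per-feature scale."""
        def __init__(self, input, scale):
            self.input = input
            self.output = input * scale
            self.l1 = abs(scale).sum()
            self.l2_sqr = (scale ** 2).sum()
            self.params = [scale]
            self.updates = []

    x = tensor.dmatrix()
    s = shared(numpy.ones(5))
    layer = ScaleLayer(x, s)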
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/__init__.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,23 @@
+# logreg.py
+from .logreg import LogisticRegression
+
+# sigmoidal_layer.py
+from .sigmoidal_layer import SigmoidalLayer
+
+# exponential_mean.py
+from .exponential_mean import ExponentialMean
+
+# sgd.py
+from .sgd import StochasticGradientDescent, HalflifeStopper
+
+# kording2004.py
+from .kording2004 import Kording2004
+
+# rust2005.py
+from .rust2005 import Rust2005
+
+# lecun1998.py
+from .lecun1998 import LeNetConvPool
+
+# kouh2008.py
+from .kouh2008 import Kouh2008
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/exponential_mean.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,84 @@
+
+"""Modules for maintaining statistics based on exponential decay"""
+__docformat__ = "restructuredtext en"
+
+import copy
+import numpy
+import theano
+import theano.tensor
+from theano.compile.sandbox import shared
+
+class ExponentialMean(object):
+    """Maintain an exponentially-decaying estimate of the mean
+
+    This module computes the exact mean of the first `max_denom` values of `x`.
+    After the first `max_denom` values, it tracks the mean using the formula:
+
+    :math:`running \leftarrow (1 - 1/max\_denom) \cdot running + (1/max\_denom) \cdot x`
+
+    """
+
+    max_denom = None
+    """The average will be updated as if the current estimated average was estimated from at
+    most `max_denom-1` values."""
+
+    running = None
+    """Shared: The running mean statistic from which the output is computed."""
+
+    denom = None
+    """Shared: The number of examples we've updated from so far
+    """
+
+    def __init__(self, input, max_denom, ival):
+        """
+        :param input: track the mean of this Variable
+
+        :param max_denom: see `self.max_denom`
+    
+        :param ival: an array of zeros whose shape matches the runtime shape of `input`'s
+        value.
+
+        """
+        dtype=ival.dtype #dtype is an actual numpy dtype object, not a string
+        self.max_denom = max_denom
+
+        # only scalar, vector and matrix inputs are supported so far
+        if len(ival.shape) > 2:
+            #TODO: build the type via theano.tensor.TensorType(...)
+            raise NotImplementedError()
+
+        self.running = shared(numpy.array(ival, copy=True))
+        # TODO: making this an lscalar caused different optimizations, followed by integer
+        # division somewhere where I wanted float division.... and the wrong answer.
+        self.denom = shared(numpy.asarray(1, dtype=dtype))
+
+        alpha = 1.0 / self.denom
+        self.output = (1.0 - alpha) * self.running + theano.tensor.cast(alpha * input, str(dtype))
+
+        self.updates = [
+                (self.running, self.output),
+                (self.denom, theano.tensor.smallest(self.denom + 1, self.max_denom)),
+                ]
+
+        assert self.output.type.dtype == dtype
+
+    @classmethod
+    def new(cls, x, x_shape, max_denom, dtype='float64'):
+        """Return an `ExponentialMean` to track a Variable `x` with given shape
+        
+        :type x: Variable
+        :type x_shape: tuple
+        :type max_denom: int
+        :type dtype: string
+        :param dtype: the running average will be computed at this precision
+
+        :rtype: ExponentialMean instance
+        """
+        return cls(x, 
+                max_denom=max_denom,
+                ival=numpy.zeros(x_shape, dtype=dtype))
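The recurrence implemented above can be sketched in plain numpy (a reference implementation for
intuition only; the class builds the equivalent Theano graph):

.. code-block:: python

    import numpy

    def exponential_mean_reference(xs, max_denom):
        """Exact mean of the first `max_denom` values, exponential decay afterwards."""
        running = numpy.zeros(xs[0].shape)
        denom = 1
        for x in xs:
            alpha = 1.0 / denom
            running = (1.0 - alpha) * running + alpha * x
            denom = min(denom + 1, max_denom)
        return running

    xs = numpy.random.RandomState(0).randn(10, 4)
    print exponential_mean_reference(xs, max_denom=500)  # equals xs.mean(axis=0) here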
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/kording2004.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,127 @@
+import numpy
+import theano.tensor
+from hpu.theano_outgoing import mean, var, cov
+
+from .exponential_mean import ExponentialMean # exponential_mean.py
+
+import logging
+_logger = logging.getLogger('kording2004')
+def debug(*msg): _logger.debug(' '.join(str(m) for m in msg))
+def info(*msg): _logger.info(' '.join(str(m) for m in msg))
+def warn(*msg): _logger.warn(' '.join(str(m) for m in msg))
+def warning(*msg): _logger.warning(' '.join(str(m) for m in msg))
+def error(*msg): _logger.error(' '.join(str(m) for m in msg))
+
+def cov_sum_of_squares(z, hint='tall', bias=0):
+    """Return the sum of the squares of all terms in the covariance of [normalized-and-centered] z
+
+    :param hint: either 'tall' or 'fat', indicating whether the computation should be carried
+    out on the covariance-shaped matrix (z.T z, for tall z) or on the Gram matrix (z z.T, for
+    fat z); both give the same result.
+
+    :param bias: if nonzero, normalize by N rather than by N-1 (as in `numpy.cov`).
+
+    :note: This is computed using either the inner or the outer product, depending on `hint`.
+    """
+    denom = theano.tensor.cast(z.shape[0] if bias else (z.shape[0]-1), z.dtype)
+    if hint == 'fat':
+        return theano.tensor.sum(theano.tensor.dot(z, z.T)**2) / denom**2
+    elif hint == 'tall':
+        return theano.tensor.sum(theano.tensor.dot(z.T, z)**2) / denom**2
+    else:
+        raise ValueError(hint)
+
+def var_sum_of_squares(z, bias=0):
+    """Return the sum of squared variances in the columns of centered variable z
+    """
+    denom = theano.tensor.cast((z.shape[0] if bias else (z.shape[0]-1)), z.dtype)
+    return theano.tensor.sum(theano.tensor.sum(z**2, axis=0)**2) / denom**2
+
+def kording2004_normalized_decorrelation3(z, hint='fat'):
+    """Return the sum of the squares of the off-diagonal terms of an uncentered covariance
+    matrix
+
+    :param z: a 3-tensor of feature responses, indexed [sequence][frame][feature].
+
+    These features must have marginal mean 0 and variance 1 for this cost to make sense as a
+    training criterion.
+
+    :note: This is computed using the gram matrix, not the covariance matrix
+    """
+    assert z.ndim == 3
+    zshape = z.shape
+    z2 = theano.tensor.reshape(z, [zshape[0]*zshape[1], zshape[2]])
+    return cov_sum_of_squares(z2, hint=hint) - var_sum_of_squares(z2)
+
+def kording2004_normalized_slowness3(z, slowness_type='l2'):
+    """Return the average squared difference between each feature response and its previous
+    response.
+
+    :param z: a 3-tensor of feature responses.  Indexed [sequence][frame][feature]
+
+    These features must have marginal mean 0 and variance 1 for this cost to make sense as a
+    training criterion.
+    """
+    assert z.ndim == 3
+    diff = (z[:,1:,:] - z[:,0:-1,:]) #the diff is taken over axis 1
+    if slowness_type=='l2':
+        cost = diff**2
+    elif slowness_type=='l1':
+        cost = abs(diff)
+    else:
+        raise ValueError(slowness_type)
+    rval = theano.tensor.mean(cost)
+    assert rval.ndim == 0
+    return rval
+
+class Kording2004(object):
+    """This implements the Kording2004 cost using a dynamicly tracked mean, but not a
+    dynamically tracked variance.
+
+    It is designed to accept 3-tensors, indexed like this: [movie_idx, frame_idx, feature_idx]
+    The variance in each feature will be computed over the outer two dimensions.
+    The speed of each feature will be computed as the first derivative over frame_idx, and the
+    mean over movie_idx.
+
+    """
+
+    def __init__(self, input, (n_movies, n_frames, n_hid), slowness_multiplier, 
+            slowness_type='l2',
+            eps=None, dtype='float64'):
+        info('Using Kording2004')
+        if eps is None:
+            if input.dtype == 'float64':
+                eps = numpy.asarray(1e-8, dtype=input.dtype)
+            else:
+                eps = numpy.asarray(1e-5, dtype=input.dtype)
+        assert input.ndim == 3
+        self.input = input
+        self.n_hid = n_hid
+        self.n_movies = n_movies
+        self.n_frames = n_frames
+        self.slowness_multiplier = slowness_multiplier
+        cur_mean_input = mean(input, axis=[0,1])
+        assert cur_mean_input.ndim == 1
+        self.mean_input = ExponentialMean.new(cur_mean_input, x_shape=(n_hid,), max_denom=500, dtype=dtype)
+
+        assert self.mean_input.output.dtype == dtype
+
+        centered_input = self.input - self.mean_input.output #broadcasting over first 2 of 3 dims
+        var_input = theano.tensor.mean(centered_input**2, axis=0)
+        assert var_input.dtype == dtype
+
+        z = centered_input / theano.tensor.sqrt(var_input + eps)
+
+        assert z.dtype == dtype
+        self.z = z
+
+        self.corr = kording2004_normalized_decorrelation3(z)
+        assert self.corr.dtype == dtype
+        self.slow = kording2004_normalized_slowness3(z, slowness_type=slowness_type)
+        assert self.slow.dtype == dtype
+
+        debug(slowness_multiplier, type(slowness_multiplier), slowness_multiplier.dtype)
+        assert self.slowness_multiplier.dtype == dtype
+
+        self.output = self.slowness_multiplier * self.slow + self.corr
+
+        self.params = []
+        self.updates = list(self.mean_input.updates)
+
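For intuition, the two terms of the cost can be sketched in numpy for responses that are already
centered and unit-variance (a rough reference only; the class computes the same quantities
symbolically, with a running estimate of the mean):

.. code-block:: python

    import numpy

    def kording2004_cost_reference(z, slowness_multiplier):
        """z is indexed [movie, frame, feature] and assumed centered and normalized."""
        z2 = z.reshape((z.shape[0] * z.shape[1], z.shape[2]))
        c = numpy.cov(z2, rowvar=0)                          # feature-by-feature covariance
        corr = (c ** 2).sum() - (numpy.diag(c) ** 2).sum()   # off-diagonal sum of squares
        diff = z[:, 1:, :] - z[:, :-1, :]                    # first derivative over frames
        slow = (diff ** 2).mean()                            # 'l2' slowness
        return slowness_multiplier * slow + corr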
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/kouh2008.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,119 @@
+""" 
+Paper: 
+
+This layer implements a model of complex cell firing rate responses.
+
+Canonical neural circuit (Kouh and Poggio, 2008)
+
+This layer is in a sense a 2-layer neural network, with a strange activation function
+in the middle.   It is introduced in "A Canonical Neural Circuit for Cortical Nonlinear
+Operations", NECO 2008.  It includes various complex-cell models and approximates neural
+network activation functions as special cases.
+
+"""
+
+## optimizing this model may be difficult -- the paper talks about using exponents p and q
+# in the range 1-3, but gradient descent may overstep that range.
+
+# TODO: Use updates() to clamp exponents p and q to sensible range
+
+import numpy
+import theano
+from theano import tensor
+from theano.tensor.nnet import softplus
+from theano.compile.sandbox import shared
+from .util import add_logging, update_locals
+
+def _shared_uniform(rng, low, high, size, dtype, name=None):
+    return shared(
+            numpy.asarray(
+                rng.uniform(low=low, high=high, size=size),
+                dtype=dtype), name)
+
+class Kouh2008(object):
+    """WRITEME
+
+    :param x: a list of N non-negative tensors of shape (n_examples, n_out)
+    :param w: a list of N output weights of shape (n_out, )
+    :param p: a tensor of exponents of shape (n_out,)
+    :param q: a tensor of exponents of shape (n_out,)
+    :param k: a tensor of biases of shape (n_out,)
+
+    output - a tensor of activations of shape (n_examples, n_out)
+    """
+
+    def __init__(self, w_list, x_list, p, q, r, k, params, updates):
+        """Transcription of equation 2.1 from paper that appears on page 1434.
+        """
+        if len(w_list) != len(x_list):
+            raise ValueError('w_list must have same len as x_list')
+        output = (sum(w * tensor.pow(x, p) for (w,x) in zip(w_list, x_list)))\
+                / (k + tensor.pow(sum(tensor.pow(x, q) for x in x_list), r))
+
+        assert output.type.ndim == 2
+        update_locals(self, locals())
+
+    @classmethod
+    def new(cls, rng, x_list, n_out, dtype=None, params=[], updates=[]):
+        """
+        """
+        if dtype is None:
+            dtype = x_list[0].dtype
+        n_terms = len(x_list)
+
+        def shared_uniform(low, high, size, name): 
+            return _shared_uniform(rng, low, high, size, dtype, name)
+
+        w_list = [shared_uniform(low=-2.0/n_terms, high=2.0/n_terms, size=(n_out,), name='w_%i'%i)
+                for i in xrange(n_terms)]
+        p = shared_uniform(low=1.0, high=3.0, size=(n_out,), name='p')
+        q = shared_uniform(low=1.0, high=3.0, size=(n_out,), name='q')
+        r = shared_uniform(low=0.3, high=0.8, size=(n_out,), name='r')
+        k = shared_uniform(low=-0.3, high=0.3, size=(n_out,), name='k')
+        return cls(w_list, x_list, p, q, r, k,
+                params = [p, q, r, k] + w_list + params,
+                updates=updates)
+
+    @classmethod
+    def new_filters(cls, rng, input, n_in, n_out, n_terms, dtype=None):
+        """Return a KouhLayer instance with random parameters
+
+        The parameters are drawn on a range [typically] suitable for fine-tuning by gradient
+        descent. 
+
+
+        :param input: a tensor of shape (n_examples, n_in)
+
+        :type n_in: positive int
+        :param n_in: number of input dimensions
+
+        :type n_out: positive int
+        :param n_out: number of dimensions in rval.output
+
+        :param n_terms: each (of n_out) complex-cell firing rate will be determined from this
+        many 'simple cell' responses.
+
+        :returns: Kouh2008 instance with freshly-allocated random weights.
+
+        """
+        if input.type.ndim != 2:
+            raise TypeError('matrix expected for input')
+
+        if dtype is None:
+            dtype = input.dtype
+
+        def shared_uniform(low, high, size, name): 
+            return _shared_uniform(rng, low, high, size, dtype, name)
+
+        f_list = [shared_uniform(low=-2.0/n_in, high=2.0/n_in, size=(n_in, n_out), name='f_%i'%i)
+                for i in xrange(n_terms)]
+
+        x_list = [softplus(tensor.dot(input, f_list[i])) for i in xrange(n_terms)]
+
+        rval = cls.new(rng, x_list, n_out, dtype=dtype, params=f_list)
+        rval.input = input #add the input to the returned object
+        rval.l1 = sum(abs(fi).sum() for fi in f_list)
+        rval.l2_sqr = sum((fi**2).sum() for fi in f_list)
+        return rval
+
+
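The response computed in `__init__` (equation 2.1 of the paper) can be written in numpy as a
quick reference (a sketch only; the layer builds the same expression symbolically):

.. code-block:: python

    import numpy

    def kouh2008_response_reference(x_list, w_list, p, q, r, k):
        """Each x in x_list is non-negative with shape (n_examples, n_out);
        each w, and p, q, r, k, has shape (n_out,)."""
        numer = sum(w * x ** p for w, x in zip(w_list, x_list))
        denom = k + sum(x ** q for x in x_list) ** r
        return numer / denom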
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/lecun1998.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,74 @@
+""" Provide the convolution and pooling layers described in LeCun 98
+
+"""
+
+import numpy
+
+import theano
+from theano import tensor
+from theano.compile.sandbox import shared, pfunc
+
+from theano.sandbox.conv import ConvOp
+from theano.sandbox.downsample import DownsampleFactorMax
+
+from .util import update_locals
+from .squash import squash
+
+class LeNetConvPool(object):
+    """
+    """
+
+    #TODO: implement biases & scales properly. There are supposed to be more parameters.
+    #    - one bias & scale per filter
+    #    - one bias & scale per downsample feature location (a 2d bias)
+    #    - more?
+
+    def __init__(self, input, w, b, conv_op, ds_op, squash_op, params):
+        if input.ndim != 4:
+            raise TypeError(input)
+        if w.ndim != 4:
+            raise TypeError(w)
+        if b.ndim != 1:
+            raise TypeError(b)
+
+        conv_out = conv_op(input, w)
+        output = squash_op(ds_op(conv_out) + b.dimshuffle('x', 0, 'x', 'x'))
+        update_locals(self, locals())
+
+    @classmethod
+    def new(cls, rng, input, n_examples, n_imgs, img_shape, n_filters, filter_shape, poolsize,
+            ignore_border=True, conv_subsample=(1,1), dtype=None, conv_mode='valid',
+            pool_type='max', squash_fn=tensor.tanh):
+        """
+        """
+        if pool_type != 'max':
+            # LeNet5 actually used averaging filters. Consider implementing 'mean'
+            # consider 'min' pooling?
+            # consider 'prod' pooling or some kind of geometric mean 'gmean'??
+            raise NotImplementedError()
+
+        if conv_subsample != (1,1):
+            # we need to adjust our calculation of the bias size
+            raise NotImplementedError()
+
+        if dtype is None:
+            dtype = input.dtype
+
+        if len(filter_shape) != 2:
+            raise TypeError(filter_shape)
+
+        conv_op = ConvOp((n_imgs,)+img_shape, filter_shape, n_filters, n_examples,
+                dx=conv_subsample[0], dy=conv_subsample[1], output_mode=conv_mode)
+        ds_op = DownsampleFactorMax(poolsize, ignore_border=ignore_border)
+
+        w_shp = (n_filters, n_imgs) + filter_shape
+        b_shp = (n_filters,)
+
+        w = shared(numpy.asarray(rng.uniform(low=-.05, high=.05, size=w_shp), dtype=dtype))
+        b = shared(numpy.asarray(rng.uniform(low=-.05, high=.05, size=b_shp), dtype=dtype))
+
+        if isinstance(squash_fn, str):
+            squash_fn = squash(squash_fn)
+
+        return cls(input, w, b, conv_op, ds_op, squash_fn, [w,b])
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/logreg.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,47 @@
+"""Provides LogisticRegression
+"""
+import numpy
+import theano
+from theano.compile.sandbox import shared
+from theano.tensor import nnet
+from .util import update_locals, add_logging
+
+class LogisticRegression(object):
+    def __init__(self, input, w, b, params=[]):
+        output=nnet.softmax(theano.dot(input, w)+b)
+        l1=abs(w).sum()
+        l2_sqr = (w**2).sum()
+        argmax=theano.tensor.argmax(theano.dot(input, w)+b, axis=input.ndim-1)
+        update_locals(self, locals())
+
+    @classmethod
+    def new(cls, input, n_in, n_out, dtype=None):
+        if dtype is None:
+            dtype = input.dtype
+        cls._debug('allocating params w, b', n_in, n_out, dtype)
+        w = shared(numpy.zeros((n_in, n_out), dtype=dtype))
+        b = shared(numpy.zeros((n_out,), dtype=dtype))
+        return cls(input, w, b, params=[w,b])
+
+
+    def nll(self, target):
+        """Return the negative log-likelihood of the prediction of this model under a given
+        target distribution.  Passing symbolic integers here means 1-hot.
+        WRITEME
+        """
+        return nnet.categorical_crossentropy(self.output, target)
+
+    def errors(self, target):
+        """Return a vector of 0s and 1s, with 1s on every line that was mis-classified.
+        """
+        if target.ndim != self.argmax.ndim:
+            raise TypeError('target should have the same shape as self.argmax', ('target', target.type,
+                'argmax', self.argmax.type))
+        if target.dtype.startswith('int'):
+            return theano.tensor.neq(self.argmax, target)
+        else:
+            raise NotImplementedError()
+
+add_logging(LogisticRegression)
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/rust2005.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,150 @@
+""" Provides Rust2005 layer
+
+Paper: 
+
+This layer implements a model of simple and complex cell firing rate responses.
+
+"""
+
+import numpy
+try:
+    from PIL import Image
+except:
+    pass
+
+import theano
+import theano.tensor
+import theano.tensor.nnet
+from theano.compile.sandbox import shared
+from theano.sandbox.softsign import softsign
+from theano.tensor.nnet import softplus
+
+from .util import update_locals, add_logging
+
+def rust2005_act_from_filters(linpart, E_quad, S_quad):
+    sqrt = theano.tensor.sqrt
+    softlin = theano.tensor.nnet.softplus(linpart)
+    E = sqrt(sum([E_quad_i**2 for E_quad_i in E_quad] + [1e-8, softlin**2]))
+    S = sqrt(sum([S_quad_i**2 for S_quad_i in S_quad] + [1e-8]))
+    return (E-S) / (1+E+S)
+
+class Rust2005(object):
+    """
+    shared variable version.
+
+    w is 3-tensor n_in x n_out x (1+n_E_quadratic + n_S_quadratic)
+
+    w 
+
+    """
+    #logging methods come from the add_logging() call below
+    # _info, _debug, _warn, _error, _fatal
+
+    def __init__(self, input, w, b, n_out, n_E_quadratic, n_S_quadratic,
+            epsilon, filter_shape, params):
+        """
+        w should be a matrix with input.shape[1] rows, and n_out *
+        (1+n_E_quadratic+n_S_quadratic) columns.
+
+        Each successive block of (1+n_E_quadratic+n_S_quadratic) adjacent columns contributes
+        to the computation of one output feature.  The first column in the block is the filter
+        for the linear term.  The following n_E_quadratic columns are used to compute the
+        excitatory quadratic part.  The remaining n_S_quadratic columns are used to compute the
+        suppressive quadratic part.
+        """
+        if w.dtype != input.dtype:
+            self._warn('WARNING w type mismatch', input.dtype, w.dtype, b.dtype)
+        if b.dtype != input.dtype:
+            self._warn( 'WARNING b type mismatch', input.dtype, w.dtype, b.dtype)
+        # filter_shape: when each column of w corresponds to a flattened image, its 2d shape
+        # goes here; it is used for rendering the weights as tiled images.
+
+        filter_responses = theano.dot(input, w).reshape((
+                input.shape[0],
+                n_out, 
+                1 + n_E_quadratic + n_S_quadratic))
+
+        assert filter_responses.dtype == input.dtype
+        Lf = filter_responses[:, :, 0]
+        Ef = filter_responses[:,:, 1:1+n_E_quadratic]
+        Sf = filter_responses[:,:, 1+n_E_quadratic:]
+        assert Lf.dtype == input.dtype
+
+        sqrt = theano.tensor.sqrt
+        E = sqrt((Ef**2).sum(axis=2) + epsilon + softplus(Lf+b)**2)
+        S = sqrt((Sf**2).sum(axis=2) + epsilon)
+
+        output = (E-S) / (1+E+S)
+        assert output.dtype == input.dtype
+        # Lf, Ef, Sf, E and S are exposed as attributes by update_locals() below
+
+        l1 = abs(w).sum()
+        l2_sqr = (w**2).sum()
+
+        update_locals(self, locals())
+
+    @classmethod
+    def new(cls, input, n_in, n_out, n_E, n_S, rng, eps=1.0e-6, filter_shape=None, dtype=None):
+        """Allocate parameters and initialize them randomly.
+        """
+        if dtype is None:
+            dtype = input.dtype
+        epsilon = numpy.asarray(eps, dtype=dtype)
+        w = shared(numpy.asarray(
+                rng.randn(n_in, n_out*(1 + n_E + n_S))*.3 / numpy.sqrt(n_in),
+                dtype=dtype))
+        b = shared(numpy.zeros((n_out,), dtype=dtype))
+        return cls(input, w, b, n_out, n_E, n_S, epsilon, filter_shape, [w,b])
+
+    def img_from_weights(self, rows=12, cols=24, row_gap=1, col_gap=1, eps=1e-4):
+        """ Return an image that visualizes all the weights in the layer.
+
+        The current implementation returns a tiling in which every triple of columns is a logical
+        group.  The first column in a triple has images of the linear weights.  The second
+        column in a triple has images of the excitatory quadratic weights.  The third column in a
+        triple has images of the suppressive quadratic weights.
+
+        """
+        if cols % 3: #because there are three kinds of filters: linear, excitatory, inhibitory
+            raise ValueError("cols must be multiple of 3")
+        filter_shape = self.filter_shape
+        height = rows * (row_gap + filter_shape[0]) - row_gap
+        width = cols * (col_gap + filter_shape[1]) - col_gap
+
+        out_array = numpy.zeros((height, width, 3), dtype='uint8')
+
+        w = self.w.value
+        w_col = 0
+        def pixel_range(x):
+            return 255 * (x - x.min()) / (x.max() - x.min() + eps)
+        for r in xrange(rows):
+            out_r_low = r*(row_gap + filter_shape[0])
+            out_r_high = out_r_low + filter_shape[0]
+            for c in xrange(cols):
+                out_c_low = c*(col_gap + filter_shape[1])
+                out_c_high = out_c_low + filter_shape[1]
+                out_tile = out_array[out_r_low:out_r_high, out_c_low:out_c_high,:]
+
+                if c % 3 == 0: # linear filter
+                    if w_col < w.shape[1]:
+                        out_tile[...] = pixel_range(w[:,w_col]).reshape(filter_shape+(1,))
+                        w_col += 1
+                if c % 3 == 1: # E filters
+                    if w_col < w.shape[1]:
+                        #filters after the 3rd do not get rendered, but are skipped over.
+                        #  there are only 3 colour channels.
+                        for i in xrange(min(self.n_E_quadratic,3)):
+                            out_tile[:,:,i] = pixel_range(w[:,w_col+i]).reshape(filter_shape)
+                        w_col += self.n_E_quadratic
+                if c % 3 == 2: # S filters
+                    if w_col < w.shape[1]:
+                        #filters after the 3rd do not get rendered, but are skipped over.
+                        #  there are only 3 colour channels.
+                        for i in xrange(min(self.n_S_quadratic,3)):
+                            out_tile[:,:,2-i] = pixel_range(w[:,w_col+i]).reshape(filter_shape)
+                        w_col += self.n_S_quadratic
+        return Image.fromarray(out_array, 'RGB')
+add_logging(Rust2005)
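The activation used above can be sketched in numpy (a reference for intuition; the bias `b` is
taken to be already folded into the linear response):

.. code-block:: python

    import numpy

    def rust2005_response_reference(linpart, E_quad, S_quad, epsilon=1e-6):
        """linpart: linear filter response; E_quad / S_quad: lists of excitatory /
        suppressive quadratic filter responses (all arrays of the same shape)."""
        softplus = lambda v: numpy.log1p(numpy.exp(v))
        E = numpy.sqrt(sum(e ** 2 for e in E_quad) + softplus(linpart) ** 2 + epsilon)
        S = numpy.sqrt(sum(s ** 2 for s in S_quad) + epsilon)
        return (E - S) / (1 + E + S)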
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/sandbox/adelsonbergen87.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,80 @@
+""" WRITEME
+
+Paper: 
+
+This is the so-called "Energy Model" of complex cell response.
+
+"""
+import theano
+import theano.tensor
+import theano.tensor.nnet
+from theano.sandbox.softsign import softsign
+
+try:
+    # for render_img
+    from pylearn.io.image_tiling import tile_raster_images
+    from PIL import Image
+except:
+    pass
+
+def adelson_bergen_87(filter0, filter1, limit=0):
+    h = theano.tensor.sqrt(filter0**2 + filter1**2 + 1.0e-8)
+    if limit:
+        return limit * softsign((1.0/limit) * h)
+    else:
+        return h
+
+class AdelsonBergenLayer(theano.Module):
+    def __init__(self, x,
+            w=None, u=None,
+            w_val=None, u_val=None, limit=False):
+        super(AdelsonBergenLayer, self).__init__()
+
+        self.w = theano.tensor.dmatrix() if w is None else w
+        self.u = theano.tensor.dmatrix() if u is None else u
+
+        self._params = [self.w, self.u]
+
+        self.w_val = w_val
+        self.u_val = u_val
+        self.limit = limit
+
+        self.output = adelson_bergen_87(theano.dot(x, self.w), theano.dot(x, self.u), limit=self.limit)
+
+    def _instance_initialize(self, obj):
+        obj.w = self.w_val.copy()
+        obj.u = self.u_val.copy()
+
+    def l1(self):
+        return abs(self.w).sum() + abs(self.u).sum()
+
+    def l2(self):
+        return theano.tensor.sqrt((self.w**2).sum() + (self.u**2).sum())
+
+    def params(self):
+        return list(self._params)
+
+    def _instance_save_img(self, obj, filename, **kwargs):
+        obj.render_img(**kwargs).save(filename)
+
+    def _instance_render_img(self, obj, img_shape, 
+            tile_shape=(12,25), tile_spacing=(1,1)):
+        """ Render the weights of this module an image.
+        :param filename: save the image to this file
+        :param img_shape: interpret the columns of weight matrices as images of this shape
+        :param tile_shape: see pylearn.io.tile_raster_images
+        :param tile_spacing: see pylearn.io.tile_raster_images
+        """
+        if (img_shape[0] * img_shape[1]) != obj.w.shape[0]:
+            raise ValueError("Image shape doesn't match filter column length")
+        return Image.fromarray(
+                tile_raster_images((
+                        obj.w.T, #RED
+                        None,  #GREEN 
+                        obj.u.T, #BLUE
+                        None),  #ALPHA
+                    img_shape=img_shape,
+                    tile_shape=tile_shape,
+                    tile_spacing=tile_spacing),
+                'RGBA')
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/sandbox/linsvm.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,40 @@
+import numpy
+import theano
+from theano.compile.sandbox import shared
+from theano.tensor import nnet
+from .util import update_locals
+
+class LinearSVM(object):
+    def __init__(self, input, w, b, params=[]):
+        output=nnet.softmax(theano.dot(input, w)+b)
+        l1=abs(w).sum()
+        l2 = (w**2).sum()
+        argmax=theano.tensor.argmax(theano.dot(input, w)+b, axis=input.ndim-1)
+        update_locals(self, locals())
+
+    @classmethod
+    def new(cls, input, n_in, n_out):
+        w = shared(numpy.zeros((n_in, n_out), dtype=input.dtype))
+        b = shared(numpy.zeros((n_out,), dtype=input.dtype))
+        return cls(input, w, b, params=[w,b])
+
+
+    def margin(self, target):
+        """Return the negative log-likelihood of the prediction of this model under a given
+        target distribution.  Passing symbolic integers here means 1-hot.
+        WRITEME
+        """
+        raise NotImplementedError()
+
+    def errors(self, target):
+        """Return a vector of 0s and 1s, with 1s on every line that was mis-classified.
+        """
+        if target.ndim != self.argmax.ndim:
+            raise TypeError('target should have the same shape as self.argmax', ('target', target.type,
+                'argmax', self.argmax.type))
+        if target.dtype.startswith('int'):
+            return theano.tensor.neq(self.argmax, target)
+        else:
+            raise NotImplementedError()
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/sgd.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,146 @@
+"""
+Provides StochasticGradientDescent, HalflifeStopper
+"""
+import numpy
+import theano
+from theano import tensor
+from theano.compile.sandbox import shared
+
+class StochasticGradientDescent(object):
+    """Fixed stepsize gradient descent
+
+    For given inputs, the outputs of this object are the new values that the inputs should take
+    in order to perform stochastic gradient descent. 
+
+    The updates attribute is a list of (p, new_p) pairs for all inputs `p` that are
+    SharedVariables. 
+
+    """
+    def __init__(self, inputs, cost, stepsize, gradients, params):
+        """
+        :param stepsize: the step to take in (negative) gradient direction
+        :type stepsize: None, scalar value, or scalar TensorVariable
+
+        :param updates: extra symbolic updates to make when evating either step or step_cost
+        (these override the gradients if necessary)
+        :type updates: dict Variable -> Variable
+        :param auxout: auxiliary outputs, list containing output symbols to 
+                      compute at the same time as cost (for efficiency)
+        :param methods: Should this module define the step and step_cost methods?
+        """
+        if len(inputs) != len(gradients):
+            raise ValueError('inputs list and gradients list must have same len')
+
+        self.inputs = inputs
+        self.params = params
+        self.updates = updates = []
+        self.outputs = outputs = []
+
+        for i, g in zip(inputs, gradients):
+            o = i - stepsize * g
+            outputs.append(o)
+            if hasattr(i, 'value'): # this is true for shared variables, false for most things.
+                updates.append((i, o))
+
+    @classmethod
+    def new(cls, inputs, cost, stepsize, dtype=None):
+        if dtype is None:
+            dtype = cost.dtype
+
+        ginputs = tensor.grad(cost, inputs)
+
+        if isinstance(stepsize, theano.Variable):
+            _stepsize = stepsize
+            params = []
+        else:
+            _stepsize = shared(numpy.asarray(stepsize, dtype=dtype))
+            params = [_stepsize]
+
+        if _stepsize.type.ndim != 0:
+            raise TypeError('stepsize must be a scalar', stepsize)
+
+        rval = cls(inputs, cost, _stepsize, ginputs, params)
+
+        # if we allocated a shared variable for the stepsize, 
+        # put it into the stepsize attribute.
+        if params:
+            rval.stepsize = _stepsize
+
+        return rval
+
+
+class HalflifeStopper(object):
+    """An early-stopping crition.
+
+    This object will track the progress of a dynamic quantity along some noisy U-shaped
+    trajectory.
+
+    The heuristic used is to first iterate at least `initial_wait` times, while looking at the
+    score.  If at any point thereafter, the score hasn't made a *significant* improvement in the
+    second half  of the entire run, the run is declared *not*-`promising`.
+
+    Significant improvement in the second half of a run is defined as achieving
+    `progresh_thresh` proportion of the best score from the first half of the run.
+
+    .. code-block:: python
+
+        stopper = HalflifeStopper()
+        ...
+        while (...):
+            stopper.step(score)
+            if stopper.best_updated:
+                pass  # this is the best score we've seen yet
+            if not stopper.promising:
+                # we haven't seen a good score in a long time,
+                # and the stopper recommends giving up.
+                break
+
+    """
+    def __init__(self, 
+            initial_wait=20,
+            patience_factor=2.0,
+            progress_thresh=0.99 ):
+        """
+        :param method:
+        :param method_output_idx:
+        :param initial_wait:
+        :param patience_factor:
+        :param progress_thresh:
+        """
+        #constants
+        self.progress_thresh = progress_thresh
+        self.patience_factor = patience_factor
+        self.initial_wait = initial_wait
+
+        #dynamic variables
+        self.iter = 0
+        self.promising = True
+
+        self.halflife_iter = -1
+        self.halflife_value = float('inf')
+        self.halflife_updated = False
+
+        self.best_iter = -1
+        self.best_value = float('inf')
+        self.best_updated = False
+
+
+    def step(self, value):
+        if value < (self.halflife_value * self.progress_thresh):
+            self.halflife_updated = True
+            self.halflife_value = value
+            self.halflife_iter = self.iter
+        else:
+            self.halflife_updated = False
+
+        if value < self.best_value:
+            self.best_updated = True
+            self.best_value = value
+            self.best_iter = self.iter
+        else:
+            self.best_updated = False
+
+        self.promising = (self.iter < self.initial_wait) \
+                or (self.iter < (self.halflife_iter * self.patience_factor))
+        self.iter += 1
+
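A minimal usage sketch for the fixed-stepsize optimizer (a contrived one-parameter cost, for
illustration only):

.. code-block:: python

    import numpy
    from theano.compile.sandbox import shared, pfunc
    from pylearn.shared.layers.sgd import StochasticGradientDescent

    w = shared(numpy.asarray(0.0))
    cost = ((w - 3.0) ** 2).sum()
    sgd = StochasticGradientDescent.new([w], cost, stepsize=0.1)
    step = pfunc([], cost, updates=sgd.updates)
    for i in xrange(50):
        step()
    print w.value  # approaches 3.0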
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/sigmoidal_layer.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,48 @@
+""" Provide the "normal" sigmoidal layers for making multi-layer perceptrons / neural nets
+
+"""
+import logging
+import numpy
+
+import theano
+from theano import tensor
+from theano.compile.sandbox import shared, pfunc
+from .util import update_locals, add_logging
+from .squash import squash
+
+
+class SigmoidalLayer(object):
+    def __init__(self, input, w, b, squash_fn, params):
+        """
+        :param input: a symbolic tensor of shape (n_examples, n_in)
+        :param w: a symbolic weight matrix of shape (n_in, n_out)
+        :param b: symbolic bias terms of shape (n_out,)
+        :param squash_fn: a squashing function (e.g. tensor.tanh)
+        """
+        output = squash_fn(tensor.dot(input, w) + b)
+        update_locals(self, locals())
+
+    @classmethod
+    def new(cls, rng, input, n_in, n_out, squash_fn=tensor.tanh, dtype=None):
+        """Allocate a SigmoidLayer with weights to transform inputs with n_in dimensions, 
+        to outputs of n_out dimensions.  
+
+        Weights are initialized randomly using rng.
+
+        :param squash_fn: an op constructor function, or a string that has been registered as a
+        `squashing_function`.
+
+        :param dtype: the numerical type to use for the parameters (i.e. 'float32', 'float64')
+
+        """
+        if dtype is None:
+            dtype = input.dtype
+        cls._debug('allocating weights and biases', n_in, n_out, dtype)
+        w = shared(
+                numpy.asarray(
+                    rng.uniform(low=-2/numpy.sqrt(n_in), high=2/numpy.sqrt(n_in),
+                    size=(n_in, n_out)), dtype=dtype))
+        b = shared(numpy.asarray(numpy.zeros(n_out), dtype=dtype))
+        return cls(input, w, b, squash(squash_fn), [w,b])
+
+add_logging(SigmoidalLayer)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/squash.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,24 @@
+"""Provides a convenient lookup mechanism for squashing functions.
+"""
+_squash_dct = {}
+
+def squashing_function(f):
+    _squash_dct[f.__name__] = f
+    return f
+
+def squash(s):
+    """Return the registered squashing function named `s`, or `s` itself if it is not a name."""
+    try:
+        return _squash_dct[s]
+    except (KeyError, TypeError):
+        return s
+
+## initialize the dct
+
+import theano
+import theano.tensor.nnet
+import theano.sandbox.softsign
+
+_squash_dct['tanh'] = theano.tensor.tanh
+_squash_dct['sigmoid'] = theano.tensor.nnet.sigmoid
+_squash_dct['softsign'] = theano.sandbox.softsign.softsign
+
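Registration and lookup can be used like this (the `rectifier` function is a made-up example,
not something registered by the library):

.. code-block:: python

    from pylearn.shared.layers.squash import squashing_function, squash

    @squashing_function
    def rectifier(x):
        return x * (x > 0)

    assert squash('rectifier') is rectifier   # look up by name
    assert squash(rectifier) is rectifier     # callables pass through unchanged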
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/tests/test_kording2004.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,24 @@
+
+from pylearn.shared.layers.kording2004 import *
+
+def test_cov_sum_of_squares():
+    z = numpy.random.RandomState(5234).randn(15, 30)
+    z -= numpy.mean(z, axis=0)
+    z /= numpy.std(z, axis=0)
+
+    cov_z = numpy.cov(z, rowvar=0)
+    print cov_z.shape
+
+    real_val = numpy.sum(numpy.cov(z.T)**2)
+
+    s = theano.tensor.dmatrix()
+    tall_val = theano.function([s], cov_sum_of_squares(s, 'tall'))(z)
+    fat_val = theano.function([s], cov_sum_of_squares(s, 'fat'))(z)
+
+    print real_val
+    print tall_val
+    print fat_val
+
+    assert numpy.allclose(real_val, tall_val)
+    assert numpy.allclose(real_val, fat_val)
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/tests/test_kouh2008.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,107 @@
+import numpy
+import theano.compile.debugmode
+from theano import tensor
+from theano.compile.sandbox import pfunc
+from pylearn.shared.layers import LogisticRegression, Kouh2008
+
+def test_dtype():
+    n_in = 10
+    n_out = 10
+    n_terms = 3
+    rng = numpy.random.RandomState(23455)
+    layer = Kouh2008.new_filters(rng, tensor.dmatrix(), n_in, n_out, n_terms, dtype='float64')
+    assert layer.output.dtype =='float64'
+    layer = Kouh2008.new_filters(rng, tensor.fmatrix(), n_in, n_out, n_terms, dtype='float32')
+    assert layer.output.dtype =='float32'
+
+def run_w_random(bsize=10, n_iter=200, n_in = 1024, n_out = 100, n_terms=2, dtype='float64'):
+    x = tensor.dmatrix()
+    y = tensor.lvector()
+    rng = numpy.random.RandomState(23455)
+
+    layer = Kouh2008.new_filters(rng, x, n_in, n_out, n_terms, dtype='float64')
+    out = LogisticRegression.new(layer.output, n_out, 2)
+    cost = out.nll(y).sum()
+
+    #isolated optimization
+    for ii in xrange(len(layer.params)):
+        params = out.params+ [layer.params[ii]]
+        print 'PARAMS', params
+        updates = [(p, p - numpy.asarray(0.001, dtype=dtype)*gp) for p,gp in zip(params, tensor.grad(cost, params)) ]
+        print 'COMPILING'
+        f = pfunc([x, y], cost, updates=updates)
+        print 'DONE'
+        if False:
+            for i, n in enumerate(f.maker.env.toposort()):
+                print i, n
+
+        xval = numpy.asarray(rng.rand(bsize, n_in), dtype=dtype)
+        yval = numpy.asarray(rng.randint(0,2,bsize), dtype='int64')
+        f0 = f(xval, yval)
+        for i in xrange(n_iter):
+            fN = f(xval, yval)
+            assert fN  < f0
+            f0 = fN
+            if 0 ==  i % 5: print i, 'rval', fN
+
+    return fN
+
+def test_A(bsize=10, n_iter=2, n_in = 10, n_out = 10, n_terms=2, dtype='float64'):
+
+    x = tensor.dmatrix()
+    y = tensor.lvector()
+    rng = numpy.random.RandomState(23455)
+
+    layer = Kouh2008.new_filters(rng, x, n_in, n_out, n_terms, dtype='float64')
+    out = LogisticRegression.new(layer.output, n_out, 2)
+    cost = out.nll(y).sum()
+    #joint optimization except for one of the linear filters
+    out.w.value += 0.1 * rng.rand(*out.w.value.shape)
+    params = layer.params[:-2]
+    mode = None
+    updates = [(p, p - numpy.asarray(0.001, dtype=dtype)*gp) for p,gp in zip(params, tensor.grad(cost, params)) ]
+    for p, newp in updates:
+        if p is layer.r:
+            theano.compile.debugmode.debugprint(newp, depth=5)
+    f = pfunc([x, y], [cost], mode, updates=updates)
+    env_r = f.maker.env.inputs[9]
+    order = f.maker.env.toposort()
+
+    assert str(f.maker.env.outputs[6].owner.inputs[0]) == 'r'
+    assert str(f.maker.env.inputs[9]) == 'r'
+    assert f.maker.env.outputs[6].owner.inputs[0] is env_r
+    assert (f.maker.env.outputs[6].owner,0) in env_r.clients
+
+    if False:
+        for i, n in enumerate(f.maker.env.toposort()):
+            print i, n, n.inputs
+
+    xval = numpy.asarray(rng.rand(bsize, n_in), dtype=dtype)
+    yval = numpy.asarray(rng.randint(0,2,bsize), dtype='int64')
+    for i in xrange(n_iter):
+        fN = f(xval, yval)
+        if 0 == i:
+            f0 = fN
+        #if 0 ==  i % 5: print i, 'rval', fN
+        print i, 'rval', fN
+        for p0 in params:
+            for p1 in params:
+                assert p0 is p1 or not numpy.may_share_memory(p0.value, p1.value)
+        assert not numpy.may_share_memory(layer.r.value, xval)
+    print 'XVAL SUM', xval.sum(), layer.r.value.sum()
+
+    assert f0 > 6
+    assert fN < f0 # TODO: assert more improvement
+
+if __name__ == '__main__':
+    test_A()
+
+def test_smaller():
+    assert run_w_random(n_in=10, n_out=8) < 6.1
+
+def test_smaller32():
+    assert run_w_random(n_in=10, n_out=8, dtype='float32') < 6.1
+
+def test_big():
+    assert run_w_random() < 0.1
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/tests/test_lecun1998.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,44 @@
+from pylearn.shared.layers.lecun1998 import *
+from pylearn.shared.layers import LogisticRegression
+import theano.sandbox.softsign
+
+def test_w_random(bsize=10, n_iter=100, dtype='float64'):
+    ishape=(28,28)
+    fshape=(5,5)
+    if dtype == 'float64':
+        x = tensor.dtensor4()
+    else:
+        x = tensor.ftensor4()
+    y = tensor.lvector()
+    rng = numpy.random.RandomState(23455)
+
+    layer = LeNetConvPool.new(rng, x, bsize, 1, ishape, 6, fshape, (2,2))
+    out = LogisticRegression.new(layer.output.flatten(2), 6*144, 2)
+    cost = out.nll(y).sum()
+    params = out.params+layer.params
+    updates = [(p, p - numpy.asarray(0.01, dtype=dtype)*gp) for p,gp in zip(params, tensor.grad(cost, params)) ]
+    f = pfunc([x, y], cost, updates=updates)
+    if True:
+        for i, n in enumerate(f.maker.env.toposort()):
+            print i, n
+
+    xval = numpy.asarray(rng.rand(bsize, 1, ishape[0], ishape[1]), dtype=dtype)
+    yval = numpy.asarray(rng.randint(0,2,bsize), dtype='int64')
+    f0 = f(xval, yval)
+    for i in xrange(n_iter):
+        fN = f(xval, yval)
+        print i, 'rval', fN
+
+    assert f0 > 6
+    assert fN < .3
+
+
+def test_squash():
+    ishape=(28,28)
+    fshape=(5,5)
+    x = tensor.ftensor4()
+    y = tensor.lvector()
+    rng = numpy.random.RandomState(23455)
+
+    layer = LeNetConvPool.new(rng, x, 10, 1, ishape, 6, fshape, (2,2), squash_fn='softsign')
+    assert layer.squash_op == theano.sandbox.softsign.softsign
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/tests/test_sigmoidal_layer.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,41 @@
+import numpy
+from pylearn.shared.layers import SigmoidalLayer, LogisticRegression
+from theano import tensor
+from theano.compile.sandbox import shared, pfunc
+
+def test_w_random(dtype='float64'):
+    if dtype == 'float64':
+        x = tensor.dmatrix()
+    else:
+        x = tensor.fmatrix()
+    y = tensor.lvector()
+    rng = numpy.random.RandomState(23455)
+
+    bsize=10
+    n_in = 10
+    n_hid = 12
+    n_out = 2
+    n_iter=100
+
+    layer = SigmoidalLayer.new(rng, x, n_in, n_hid, squash_fn='tanh', dtype=dtype)
+    out = LogisticRegression.new(layer.output, n_hid, 2)
+    cost = out.nll(y).sum()
+    params = out.params+layer.params
+    updates = [(p, p - numpy.asarray(0.01, dtype=dtype)*gp) for p,gp in zip(params, tensor.grad(cost, params)) ]
+    f = pfunc([x, y], cost, updates=updates)
+
+    w0 = layer.w.value.copy()
+    b0 = layer.b.value.copy()
+
+    xval = numpy.asarray(rng.rand(bsize, n_in), dtype=dtype)
+    yval = numpy.asarray(rng.randint(0,2,bsize), dtype='int64')
+    f0 = f(xval, yval)
+    for i in xrange(n_iter):
+        fN = f(xval, yval)
+        print i, 'rval', fN
+
+    assert f0 > 6
+    assert fN < 2 
+
+    assert numpy.all(w0 != layer.w.value)
+    assert numpy.all(b0 != layer.b.value)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/shared/layers/util.py	Fri Oct 16 12:14:43 2009 -0400
@@ -0,0 +1,45 @@
+"""A few little internal functions"""
+import logging
+
+def update_locals(obj, dct):
+    if 'self' in dct:
+        del dct['self']
+    obj.__dict__.update(dct)
+
+def LogFn(f):
+    def rval(*args):
+        f(' '.join(str(a) for a in args))
+    return staticmethod(rval)
+
+def add_logging(cls, name=None, level=None):
+    """ Add logging functions to a class: self._debug, self._info, self._warn, self._warning,
+    self._error.
+
+    All of these functions have the same signature.  They accept a variable number of positional
+    arguments, cast them all to strings, and log them joined by single spaces.
+
+    :type name: str
+    :param name: the name of the logger.
+
+    :type level: None, str, type(logging.INFO)
+    :param level: a logging level (e.g. logging.INFO), or the name of a logging level (e.g.
+    'INFO').  If level is None, this function does not set the logging level.
+
+    """
+    if name is None:
+        name = "layers.%s" % cls.__name__
+    cls._logger = logging.getLogger(name)
+    if level:
+        try:
+            level = getattr(logging, level)
+        except (AttributeError, TypeError):
+            pass  # level was presumably already a numeric logging level
+        cls._logger.setLevel(level)
+
+    print 'adding loggers to ', cls
+    cls._debug = LogFn(cls._logger.debug)
+    cls._info = LogFn(cls._logger.info)
+    cls._warn = cls._warning = LogFn(cls._logger.warn)
+    cls._error = LogFn(cls._logger.error)
+    cls._critical = LogFn(cls._logger.critical)
+    cls._fatal = LogFn(cls._logger.fatal)
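A small usage sketch for `add_logging` (the layer class here is hypothetical):

.. code-block:: python

    import logging
    from pylearn.shared.layers.util import add_logging

    class MyLayer(object):
        pass

    add_logging(MyLayer, level='DEBUG')   # attaches MyLayer._debug, ._info, ._warn, ...
    logging.basicConfig()
    MyLayer._info('allocating layer with', 10, 'outputs')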