changeset 871:fafe796ad5ff

merge
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 11 Nov 2009 10:47:15 -0500
parents bd7d540db70d (current diff) 2fffbfa41920 (diff)
children b2821fce15de
files
diffstat 7 files changed, 256 insertions(+), 65 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/dataset_ops/cifar10.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/dataset_ops/cifar10.py	Wed Nov 11 10:47:15 2009 -0500
@@ -13,7 +13,7 @@
 import theano
 
 from protocol import TensorFnDataset # protocol.py __init__.py
-from .memo import memo
+from .memo import memo # memo.py
 
 def _unpickle(filename, dtype):
     #implements loading as well as dtype-conversion and dtype-scaling
--- a/pylearn/dataset_ops/memo.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/dataset_ops/memo.py	Wed Nov 11 10:47:15 2009 -0500
@@ -20,6 +20,7 @@
     def forget():
         for k in cache.keys():
             del cache[k]
+    rval.cache = cache
     rval.forget = forget
     rval.__name__ = 'memo@%s'%f.__name__
     return rval
--- a/pylearn/shared/layers/__init__.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/shared/layers/__init__.py	Wed Nov 11 10:47:15 2009 -0500
@@ -14,7 +14,7 @@
 from kording2004 import Kording2004
 
 # rust2005.py
-from rust2005 import Rust2005
+from rust2005 import Rust2005, Rust2005Conv
 
 # lecun1998.py
 from lecun1998 import LeNetConvPool
--- a/pylearn/shared/layers/kouh2008.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/shared/layers/kouh2008.py	Wed Nov 11 10:47:15 2009 -0500
@@ -12,12 +12,9 @@
 
 """
 
-## optimizing this model may be difficult-- paper talks about using exponents p and q in
-# in the range 1-3, but gradient descent may overstep that range.
-
-# TODO: Use updates() to clamp exponents p and q to sensible range
 import logging
 _logger = logging.getLogger('pylearn.shared.layers.kouh2008')
+
 import numpy
 import theano
 from theano import tensor
@@ -50,26 +47,56 @@
     output - a tensor of activations of shape (n_examples, n_out)
     """
 
+    input = None #optional - symbolic variable of input
+    f_list = None # optional - list of filter shared variables
+    filter_l1 = None # optional - l1 of filters
+    filter_l2_sqr = None # optional - l2**2 of filters
+
+    exp_l1 = None
+    exp_l2_sqr = None
+
+    w_l1 = None
+    w_l2_sqr = None
+
+    p_unbounded = None
+    q_unbounded = None
+    r_unbounded = None
+    k_unbounded = None
+
+    p_range_default=(1.0, 3.0)
+    q_range_default=(1.0, 3.0)
+    r_range_default=(0.333, 1.0)
+    k_range_default=(0.0, 1.0)
+    x_range_default=(0.01, 1.0)
+
     def __init__(self, w_list, x_list, p, q, r, k, params, updates, eps=1.0e-6):
         """Transcription of equation 2.1 from paper (page 1434).
         """
         if len(w_list) != len(x_list):
             raise ValueError('w_list must have same len as x_list')
-        output = (sum(w * tensor.pow(x, p) for (w,x) in zip(w_list, x_list)))\
-                / (numpy.asarray(eps, dtype=k.type.dtype) + k + tensor.pow(sum(tensor.pow(x, q) for x in x_list), r))
+        numerator = sum(w_i * tensor.pow(x_i, p) for (w_i,x_i) in zip(w_list, x_list))
+        denominator = k + tensor.pow(sum(tensor.pow(x_i, q) for x_i in x_list), r)
+        output = numerator / (numpy.asarray(eps, dtype=k.type.dtype) + denominator)
 
         assert output.type.ndim == 2
         update_locals(self, locals())
         _logger.debug('output dtype %s' % output.dtype)
 
     @classmethod
-    def new_expbounds(cls, rng, x_list, n_out, dtype=None, params=[], updates=[], exponent_range=(1.0, 3.0)):
+    def new_expbounds(cls, rng, x_list, n_out, dtype=None, params=[], updates=[], 
+            p_range=p_range_default,
+            q_range=q_range_default,
+            r_range=r_range_default,
+            k_range=k_range_default,
+            ):
         """
         """
         if dtype is None:
             dtype = x_list[0].dtype
         n_terms = len(x_list)
 
+        new_params = []
+
         def shared_uniform(low, high, size, name): 
             return _shared_uniform(rng, low, high, size, dtype, name)
 
@@ -81,40 +108,32 @@
             w_list = [w_sm[:,i] for i in xrange(n_terms)]
             w_l1 = abs(w).sum()
             w_l2_sqr = (w**2).sum()
+            new_params.append(w)
         else:
-            w_list = [shared_uniform(low=-2.0/n_terms, high=2.0/n_terms, size=(n_out,), name='w_%i'%i)
+            w_list = [shared_uniform(low=-2.0/n_terms, high=2.0/n_terms, size=(n_out,), name='Kouh2008::w_%i'%i)
                     for i in xrange(n_terms)]
             w_l1 = sum(abs(wi).sum() for wi in w_list)
             w_l2_sqr = sum((wi**2).sum() for wi in w_list)
-
-        e_range_low, e_range_high = exponent_range
-        e_range_low = numpy.asarray(e_range_low, dtype=dtype)
-        e_range_high = numpy.asarray(e_range_high, dtype=dtype)
-        e_range_mag = e_range_high - e_range_low
-        if e_range_mag < 0:
-            raise ValueError('exponent range must have low <= high')
+            new_params.extend(w_list)
 
         p_unbounded = shared_uniform(low=-0.1, high=0.1, size=(n_out,), name='p')
         q_unbounded = shared_uniform(low=-0.1, high=0.1, size=(n_out,), name='q') 
         r_unbounded = shared_uniform(low=-0.1, high=0.1, size=(n_out,), name='r')
-        k_unbounded = shared_uniform(low=-0.2, high=0.2, size=(n_out,), name='k') # biases
+        k_unbounded = shared_uniform(low=-0.1, high=0.1, size=(n_out,), name='k') # biases
+        new_params.extend([p_unbounded, q_unbounded, r_unbounded, k_unbounded])
 
-        p = tensor.nnet.sigmoid(p_unbounded) * e_range_mag + e_range_low
-        q = tensor.nnet.sigmoid(q_unbounded) * e_range_mag + e_range_low
-        r = tensor.nnet.sigmoid(r_unbounded) * \
-                numpy.asarray(1.0/e_range_low - 1.0/e_range_high, dtype=dtype) \
-                + numpy.asarray(1.0/e_range_high, dtype=dtype)
+        def d(a):
+            return numpy.asarray(a, dtype=dtype)
 
-        k = softsign(k_unbounded)
+        p = softsign(p_unbounded) * d(p_range[1] - p_range[0]) + d(p_range[0])
+        q = softsign(q_unbounded) * d(q_range[1] - q_range[0]) + d(q_range[0])
+        r = softsign(r_unbounded) * d(r_range[1] - r_range[0]) + d(r_range[0])
+        k = softsign(k_unbounded) * d(k_range[1] - k_range[0]) + d(k_range[0])
 
-        if use_softmax_w:
-            rval = cls(w_list, x_list, p, q, r, k,
-                    params = [p_unbounded, q_unbounded, r_unbounded, k, w] + params,
-                    updates=updates)
-        else:
-            rval = cls(w_list, x_list, p, q, r, k,
-                    params = [p_unbounded, q_unbounded, r_unbounded, k_unbounded] + w_list + params,
-                    updates=updates)
+        rval = cls(w_list, x_list, p, q, r, k,
+                params = params + new_params,
+                updates=updates)
+
         rval.p_unbounded = p_unbounded
         rval.q_unbounded = q_unbounded
         rval.r_unbounded = r_unbounded
@@ -126,9 +145,14 @@
         return rval
 
     @classmethod
-    def new_filters_expbounds(cls, rng, input, n_in, n_out, n_terms, dtype=None, eps=1e-1,
-            exponent_range=(1.0, 3.0), filter_range=1.0):
-        """Return a KouhLayer instance with random parameters
+    def new_filters_expbounds(cls, rng, input, n_in, n_out, n_terms, dtype=None,
+            p_range=p_range_default,
+            q_range=q_range_default,
+            r_range=r_range_default,
+            k_range=k_range_default,
+            x_range=x_range_default,
+            ):
+        """Return a Kouh2008 instance with random parameters
 
         The parameters are drawn on a range [typically] suitable for fine-tuning by gradient
         descent. 
@@ -145,10 +169,27 @@
         :param nterms: each (of n_out) complex-cell firing rate will be determined from this
         many 'simple cell' responses.
 
-        :param eps: this amount is added to the softplus of filter responses as a baseline
-        firing rate (that prevents a subsequent error from ``pow(0, p)``) 
+        :param eps: this amount is added to the filter responses as a baseline
+        firing rate (that prevents a subsequent error from ``pow(0, p)``)
+        The eps must be large enough so that eps**p_range[1] does not underflow.
+
+        :param p_range: See `new_expbounds`.
+        :type p_range: tuple([low, high])
+
+        :param q_range: See `new_expbounds`.
+        :type q_range: tuple([low, high])
 
-        :returns: KouhLayer instance with freshly-allocated random weights.
+        :param r_range: See `new_expbounds`.
+        :type r_range: tuple([low, high])
+
+        :param k_range: See `new_expbounds`.
+        :type k_range: tuple([low, high])
+
+        :param x_range: Filter responses are affine-transformed softsigns lying between these
+        values.
+        :type x_range: tuple([low, high])
+
+        :returns: Kouh2008 instance with freshly-allocated random weights.
 
         """
         if input.type.ndim != 2:
@@ -161,19 +202,30 @@
         def shared_uniform(low, high, size, name): 
             return _shared_uniform(rng, low, high, size, dtype, name)
 
-        f_list = [shared_uniform(low=-2.0/numpy.sqrt(n_in), high=2.0/numpy.sqrt(n_in), size=(n_in, n_out), name='f_%i'%i)
+        f_list = [shared_uniform(low=-2.0/numpy.sqrt(n_in), high=2.0/numpy.sqrt(n_in), 
+            size=(n_in, n_out), name='Kouh2008::f_%i'%i)
+                for i in xrange(n_terms)]
+
+        b_list = [shared_uniform(low=0, high=.01,
+            size=(n_out,), name='Kouh::2008::b_%i'%i)
                 for i in xrange(n_terms)]
 
-        b_list = [shared_uniform(low=0, high=.01, size=(n_out,), name='b_%i'%i)
-                for i in xrange(n_terms)]
-        #x_list = [numpy.asarray(eps, dtype=dtype)+softplus(tensor.dot(input, f_list[i])) for i in xrange(n_terms)]
-        filter_range = numpy.asarray(filter_range, dtype=dtype)
-        half_filter_range = numpy.asarray(filter_range/2, dtype=dtype)
-        x_list = [numpy.asarray(filter_range + eps, dtype=dtype)+half_filter_range *softsign(tensor.dot(input, f_list[i]) +
-            b_list[i]) for i in xrange(n_terms)]
+        def d(a):
+            return numpy.asarray(a, dtype=dtype)
+
+        x_low = d(x_range[0])
+        x_high = d(x_range[1])
+
+        #softsign's range is (-1, 1)
+        # we want filter responses to span (x_low, x_high)
+        x_list = [x_low + (x_high-x_low)*(d(0.5) + d(0.5)*softsign(tensor.dot(input, f_list[i])+b_list[i]))
+                    for i in xrange(n_terms)]
 
         rval = cls.new_expbounds(rng, x_list, n_out, dtype=dtype, params=f_list + b_list,
-                exponent_range=exponent_range)
+                p_range=p_range,
+                q_range=q_range,
+                r_range=r_range,
+                k_range=k_range)
         rval.f_list = f_list
         rval.input = input #add the input to the returned object
         rval.filter_l1 = sum(abs(fi).sum() for fi in f_list)
@@ -182,6 +234,8 @@
 
     def img_from_weights(self, rows=None, cols=None, row_gap=1, col_gap=1, eps=1e-4):
         """ Return an image that visualizes all the weights in the layer.
+
+        WRITEME: how does the image relate to the weights
         """
 
         n_in, n_out = self.f_list[0].value.shape
--- a/pylearn/shared/layers/lecun1998.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/shared/layers/lecun1998.py	Wed Nov 11 10:47:15 2009 -0500
@@ -96,8 +96,12 @@
         w_shp = (n_filters, n_imgs) + filter_shape
         b_shp = (n_filters,)
 
-        w = shared(numpy.asarray(rng.uniform(low=-.05, high=.05, size=w_shp), dtype=dtype))
-        b = shared(numpy.asarray(rng.uniform(low=-.05, high=.05, size=b_shp), dtype=dtype))
+        #TODO: make w_range a parameter to new as well?
+        w_range = (-1.0 / numpy.sqrt(filter_shape[0] * filter_shape[1] * n_imgs),
+                   1.0 / numpy.sqrt(filter_shape[0] * filter_shape[1] * n_imgs))
+
+        w = shared(numpy.asarray(rng.uniform(low=w_range[0], high=w_range[1], size=w_shp), dtype=dtype))
+        b = shared(numpy.asarray(rng.uniform(low=-.0, high=0., size=b_shp), dtype=dtype))
 
         if isinstance(squash_fn, str):
             squash_fn = squash(squash_fn)
--- a/pylearn/shared/layers/rust2005.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/shared/layers/rust2005.py	Wed Nov 11 10:47:15 2009 -0500
@@ -4,6 +4,10 @@
 
 This layer implements a model of simple and complex cell firing rate responses.
 
+
+:TODO: implement full model with variable exponents.  The current implementation fixes internal
+exponents to 2 and the external exponent to 1/2.
+
 """
 
 import numpy
@@ -18,25 +22,31 @@
 from theano.compile.sandbox import shared
 from theano.sandbox.softsign import softsign
 from theano.tensor.nnet import softplus
+from theano.sandbox.conv import ConvOp
 
 from .util import update_locals, add_logging
 
-def rust2005_act_from_filters(linpart, E_quad, S_quad):
+def rust2005_act_from_filters(linpart, E_quad, S_quad, eps):
+    """Return rust2005 activation from linear filter responses, as well as E and S terms
+
+    :param linpart: a single tensor of linear filter responses
+    :param E_quad: a list of tensors of linear filter responses
+    :param S_quad: a list of tensors of linear filter responses
+    :param eps: a scalar to add to the sum of squares before the sqrt
+
+    """
+    if isinstance(E_quad, theano.Variable):
+        raise TypeError('E_quad should be a list of Variables, not a Variable itself')
+    if isinstance(S_quad, theano.Variable):
+        raise TypeError('E_quad should be a list of Variables, not a Variable itself')
     sqrt = theano.tensor.sqrt
     softlin = theano.tensor.nnet.softplus(linpart)
-    E = sqrt(sum([E_quad_i**2 for E_quad_i in E_quad] + [1e-8, softlin**2]))
-    S = sqrt(sum([S_quad_i**2 for S_quad_i in S_quad] + [1e-8]))
-    return (E-S) / (1+E+S)
+    E = sqrt(sum([E_quad_i**2 for E_quad_i in E_quad] + [softlin**2], eps))
+    S = sqrt(sum([S_quad_i**2 for S_quad_i in S_quad], eps))
+    return (E-S) / (1+E+S), E, S
 
 class Rust2005(object):
-    """
-    shared variable version.
-
-    w is 3-tensor n_in x n_out x (1+n_E_quadratic + n_S_quadratic)
-
-    w 
-
-    """
+    """Energy-like complex cell activation function described in Rust et al. 2005 """
     #logging methods come from the add_logging() call below
     # _info, _debug, _warn, _error, _fatal
 
@@ -99,7 +109,7 @@
         b = shared(numpy.zeros((n_out,), dtype=dtype))
         return cls(input, w, b, n_out, n_E, n_S, epsilon, filter_shape, [w,b])
 
-    def img_from_weights(self, rows=12, cols=24, row_gap=1, col_gap=1, eps=1e-4):
+    def img_from_weights(self, rows=12, cols=25, row_gap=1, col_gap=1, eps=1e-4, triplegap=0):
         """ Return an image that visualizes all the weights in the layer.
 
         The current implentation returns a tiling in which every triple of columns is a logical
@@ -110,9 +120,12 @@
         """
         if cols % 3: #because there are three kinds of filters: linear, excitatory, inhibitory
             raise ValueError("cols must be multiple of 3")
+
+        n_triples = cols / 3
+
         filter_shape = self.filter_shape
         height = rows * (row_gap + filter_shape[0]) - row_gap
-        width = cols * (col_gap + filter_shape[1]) - col_gap
+        width = cols * (col_gap + filter_shape[1]) - col_gap + (n_triples-1) * triplegap
 
         out_array = numpy.zeros((height, width, 3), dtype='uint8')
 
@@ -123,10 +136,14 @@
         for r in xrange(rows):
             out_r_low = r*(row_gap + filter_shape[0])
             out_r_high = out_r_low + filter_shape[0]
+            extra_col_gap = 0 # a counter we'll use for the triplegap
             for c in xrange(cols):
-                out_c_low = c*(col_gap + filter_shape[1])
+                if c and (c%3==0):
+                    extra_col_gap += triplegap
+                out_c_low = c*(col_gap + filter_shape[1]) + extra_col_gap
                 out_c_high = out_c_low + filter_shape[1]
                 out_tile = out_array[out_r_low:out_r_high, out_c_low:out_c_high,:]
+                assert out_tile.shape[1] == filter_shape[1]
 
                 if c % 3 == 0: # linear filter
                     if w_col < w.shape[1]:
@@ -148,3 +165,119 @@
                         w_col += self.n_S_quadratic
         return Image.fromarray(out_array, 'RGB')
 add_logging(Rust2005)
+
+class Rust2005Conv(object):
+    """Convolutional version of `Rust2005`
+
+    :note:
+    The layer doesn't contain an option for downsampling. It makes sense to downsample the output
+    using DownsampleMaxPool, but downsampling is orthogonal to the behaviour of this
+    layer so it is not included.
+    
+    """
+
+    l1 = 0.0
+    l2_sqr = 0.0
+
+    eps=1e-6 # default epsilon to prevent sqrt(0) in quadratic filters
+
+    image_shape = None
+    filter_shape = None
+    output_shape = None
+    output_channels = None
+    output_examples = None
+
+    def __init__(self, linpart, Es, Ss, params, eps=eps):
+        """
+        """
+        eps = numpy.asarray(self.eps, linpart.dtype)
+        output, E, S = rust2005_act_from_filters(linpart, Es, Ss, eps)
+
+        update_locals(self, locals())
+
+
+    @classmethod
+    def new(cls, rng, input, image_shape, filter_shape, n_examples, n_filters, n_E, n_S,
+            n_channels=1,
+            eps=1.0e-6, dtype=None, conv_mode='valid',
+            w_range=None,
+            q_range=None
+            ):
+        """ Return Rust2005Conv layer
+
+        layer.output will be 4D tensor with shape (n_examples, n_filters, R, C) where R and C
+        depend on the image_shape, the filter_shape and the convolution mode.
+
+
+        :param rng: generator for randomized initial filters
+        :param input: symbolic input (4D tensor)
+        :type input: 4D tensor with shape (n_examples, n_channels, image_shape[0], image_shape[1])
+
+        :param image_shape:  rows, cols of every channel of every image
+        :param filter_shape: rows, cols of every filter
+        :param bsize: number of images to be treated
+        :param n_filters: number of filters (output will have this many channels)
+        :param n_channels: number of channels in each image and filter
+        :param n_E: number of squared exciting terms
+        :param n_S: number of squared inhibition terms
+        :param eps: epsilon to add to sum-of-squares in sqrt
+        :param dtype: dtype to use for new variables (Default: input.dtype)
+        :param conv_mode: convolution mode
+        :param w_range: linear weights will be drawn uniformly from this range
+        :type w_range: pair (lower_bound, upper_bound
+        :param q_range: quadratic weights will be drawn uniformly from this range
+        :type q_range: pair (lower_bound, upper_bound
+        """
+        if dtype is None:
+            dtype = input.dtype
+
+        irows, icols = image_shape
+        krows, kcols = filter_shape
+
+        conv = ConvOp((n_channels,irows, icols), (krows, kcols), n_filters, n_examples,
+                dx=1, dy=1, output_mode=conv_mode)
+
+        w_shp = (n_filters, n_channels, krows, kcols)
+        b_shp = (n_filters,)
+
+        if w_range is None:
+            w_low = -2.0/numpy.sqrt(image_shape[0] * image_shape[1] * n_channels)
+            w_high = 2.0/numpy.sqrt(image_shape[0] * image_shape[1] * n_channels)
+        else:
+            w_low, w_high = w_range
+
+        if q_range is None:
+            q_low, q_high = w_low, w_high
+        else:
+            q_low, q_high = w_range
+
+        w = shared(numpy.asarray(rng.uniform(low=w_low, high=w_high, size=w_shp), dtype=dtype))
+        b = shared(numpy.asarray(rng.uniform(low=w_low, high=w_low, size=b_shp), dtype=dtype))
+
+        E_w = [
+                shared(numpy.asarray(rng.uniform(low=q_low, high=q_high, size=w_shp), dtype=dtype))
+                for i in xrange(n_E)
+                ]
+        S_w = [
+                shared(numpy.asarray(rng.uniform(low=q_low, high=q_high, size=w_shp), dtype=dtype))
+                for i in xrange(n_S)
+                ]
+
+        rval = cls(
+                linpart=conv(input, w) + b.dimshuffle(0,'x','x'),
+                Es=[conv(input, e) for e in E_w],
+                Ss=[conv(input, s) for s in S_w],
+                params=[w,b]+E_w + S_w,
+                eps=eps)
+
+        # ignore bias in l1 (Yoshua's habit)
+        rval.l1 = sum(abs(p) for p in ([w]+E_w+S_w))
+        rval.l2_sqr = sum(p**2 for p in ([w]+E_w+S_w))
+        rval.image_shape = image_shape
+        rval.filter_shape = filter_shape
+        rval.output_shape = conv.outshp
+        rval.output_channels = n_filters # how many channels of *output*
+        rval.output_examples = n_examples
+
+        return rval
+add_logging(Rust2005Conv) # _debug, _info, _warn, _error, _fatal
--- a/pylearn/shared/layers/util.py	Mon Nov 09 16:12:09 2009 -0500
+++ b/pylearn/shared/layers/util.py	Wed Nov 11 10:47:15 2009 -0500
@@ -36,7 +36,6 @@
             pass
         cls._logger.setLevel(level)
 
-    print 'adding loggers to ', cls
     cls._debug = LogFn(cls._logger.debug)
     cls._info = LogFn(cls._logger.info)
     cls._warn = cls._warning = LogFn(cls._logger.warn)