changeset 679:05a3800389e4

adding scan1 op
author James Bergstra <bergstrj@iro.umontreal.ca>
date Tue, 07 Apr 2009 17:56:58 -0400
parents 790c5e44906c
children ade894b06471
files pylearn/lib/__init__.py pylearn/lib/scan.py pylearn/lib/test_scan.py
diffstat 2 files changed, 211 insertions(+), 0 deletions(-)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/lib/scan.py	Tue Apr 07 17:56:58 2009 -0400
@@ -0,0 +1,143 @@
+"""Provide Scan and related functions"""
+__docformat__ = 'restructuredtext en'
+
+import traceback
+import numpy
+import theano
+
+class Scan:
+    """A Theano loop 
+
+    :todo: Implement this, and rewrite `Scan1` to use `Scan`
+
+
+    """
+    def __init__(self):
+        raise NotImplementedError()
+
+
+class Scan1(theano.Op):
+    """A Theano loop over one variable
+
+    Scan1 is less general than `Scan` because it permits looping only over one tensor.
+
+    Scan1 is defined mathematically as follows:
+
+    input - an iterable x
+    input - an initial value u, with the shape and dtype of one element of y
+    input - a function f mapping (x[i], y[i-1]) to y[i]
+    output - an iterable y, with one element for each element of x
+
+    .. code-block:: python
+        
+        #inputs
+        x #a tensor with ndim >= 1
+        u #a tensor that is like a row of y
+        f #the function to scan over x
+
+        for i in xrange(len(x)):
+            if i > 0:
+                y[i] = f(x[i], y[i-1])
+            else:
+                y[0] = f(x[0], u)
+
+        #outputs
+        y # a tensor whose leading dimension matches that of x;
+          # each element y[i] has the shape and dtype of u
+
+    The Scan1 Op works by representing `f` as an `Env`.
+
+    :note:
+    Internally, the representation is off-by-one with respect to the documentation
+    above.  This Op creates a tensor y whose length is one greater than that of x,
+    and whose first element is a copy of u.  `Scan1.__call__()` returns a subtensor
+    view of this internal tensor that drops the first element, so the copy of `u`
+    is not visible to the caller.
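+
+    In numpy terms, the internal buffer and the returned view relate as in the
+    following sketch (illustrative only; `f` stands for the scanned function):
+
+    .. code-block:: python
+
+        y_internal = numpy.empty((len(x) + 1,) + u.shape, dtype=u.dtype)
+        y_internal[0] = u                          # hidden leading copy of u
+        for i in xrange(len(x)):
+            y_internal[i + 1] = f(x[i], y_internal[i])
+        y_visible = y_internal[1:]                 # what __call__ returns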
+
+
+    :todo:
+    Optimize for the case where y_this is not required to compute y_next.
+    That would make all the updates possible in parallel, and would also make the
+    `u` argument to `make_node` unnecessary.
+
+    """
+    
+    destroy_map = {}
+    view_map = {}
+
+    def __init__(self, x_i, y_this, y_next):
+        if y_this.type != y_next.type:
+            raise TypeError('y_this and y_next must match', (y_this.type, y_next.type))
+        
+        #the Env is necessary to create the gradient Scan later.
+        self.env = theano.gof.Env(inputs=[x_i, y_this], outputs=[y_next])
+
+        #this is the function that we use in recursion in perform
+        self.fn = theano.function(self.env.inputs, self.env.outputs[0])
+
+
+    def make_node(self, x, u):
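+        #the output type has one extra leading (non-broadcastable) dimension
+        #relative to u: one position per element of x, plus one for the internal
+        #leading copy of u (see the note in the class docstring)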
+        out_type = theano.tensor.Tensor(dtype=u.dtype, 
+                broadcastable=[False] + list(u.broadcastable))
+        return theano.Apply(self, [x,u], [out_type()])
+
+    def __call__(self, *args, **kwargs):
+        node = self.make_node(*args, **kwargs)
+        node.tag.trace = traceback.extract_stack()[:-1]
+        all_y = node.outputs[0]
+        return all_y[1:] #cut out the leading copy of u
+
+
+
+    def perform(self, node, (x,u), (y_out,)):
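+        #allocate the internal y with one extra leading slot for the copy of u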
+        y_shape = (x.shape[0]+1,) + u.shape
+        y = numpy.empty(y_shape, dtype=u.dtype)
+
+        y[0] = u
+        for i, x_i in enumerate(x):
+            y[i+1] = self.fn(x_i, y[i])
+        y_out[0] = y
+
+    def grad(self, (x,u), (g_y,)):
+        if not hasattr(self, 'grad_op'):
+            self.grad_op = Scan1Grad(self)
+            
+        return self.grad_op(x, u, g_y)
+
+
+class Scan1Grad(theano.Op):
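+    """Gradient Op for `Scan1`
+
+    Runs the `Scan1` recursion backwards: at each step i (from the last to the
+    first) it back-propagates the gradient arriving at y[i+1] through one
+    application of `f`, producing the gradient with respect to x[i] and an
+    increment that is accumulated into the gradient at y[i].  After the loop,
+    the accumulated gradient at position 0 is the gradient with respect to `u`.
+    """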
+    def __init__(self, scan1, inplace=False):
+        self.scan = scan1
+        self.inplace = inplace
+        if inplace:
+            #output 1 (the gradient wrt u) is computed in the buffer of input 2 (g_y)
+            self.destroy_map = {1: [2]}
+
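+        #compile a function that back-propagates through a single step of f:
+        #(x_i, y_this, grad(y_next)) -> (grad(x_i), grad(y_this))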
+        xi, y_this = self.scan.env.inputs
+        y_next = self.scan.env.outputs[0]
+        gy_next = y_next.type()
+        gxi, gy_this = theano.tensor.grad(
+                y_next,
+                [xi, y_this],
+                g_cost=gy_next)
+
+        self.fn = theano.function([xi, y_this, gy_next], [gxi, gy_this])
+
+    def make_node(self, x, u, g_y):
+        #use the full internal y (with the leading copy of u) rather than the
+        #trimmed view returned by __call__, so that perform() can pair y[i],
+        #the input of step i, with g_y[i+1], the gradient of its output
+        y = self.scan.make_node(x, u).outputs[0]
+        return theano.Apply(self, [x, y, g_y], [x.type(), u.type()])
+
+    def perform(self, node, (x, y, g_y), (gx_out, gu_out)):
+        if not self.inplace:
+            g_y = g_y.copy()
+
+        gx = numpy.zeros_like(x)
+
+        #walk backwards through the sequence, back-propagating one step at a time
+        for i in xrange(len(x)-1, -1, -1):
+            gx[i], gy_i = self.fn(x[i], y[i], g_y[i+1])
+            g_y[i] += gy_i
+            
+        gx_out[0] = gx
+        gu_out[0] = g_y[0]
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/lib/test_scan.py	Tue Apr 07 17:56:58 2009 -0400
@@ -0,0 +1,68 @@
+import numpy
+import theano
+from theano.tensor import dscalar, dvector
+from scan import Scan1
+
+def test_0():
+    x_i = dscalar()
+    u = dscalar()
+
+
+    scan_add = Scan1(x_i, u, x_i + u)
+
+    x = dvector()
+
+    y = scan_add(x, u)
+
+    f = theano.function([x,u], y)
+
+    xval = numpy.asarray([1., 1., 1., 1., 1.])
+    uval = numpy.asarray(2.)
+
+    yval = f(xval, uval)
+    print yval
+    #running sum of five ones starting from u=2
+    assert numpy.all(yval == numpy.asarray([3., 4., 5., 6., 7.]))
+
+
+def test_grad():
+    x_i = dscalar()
+    u = dscalar()
+
+    scan_add = Scan1(x_i, u, x_i + u)
+
+    x = dvector()
+
+    y = scan_add(x, u)
+
+    sum_y = theano.tensor.sum(y)
+
+    g_x = theano.tensor.grad(sum_y, x)
+    g_u = theano.tensor.grad(sum_y, u)
+
+    f = theano.function([x,u], y)
+    gf = theano.function([x, u], [g_x, g_u])
+
+    xval = numpy.asarray([1., 1., 1., 1., 1.])
+    uval = numpy.asarray(2.)
+
+    yval = f(xval, uval)
+    print 'yval', yval
+
+    gxval, guval = gf(xval, uval)
+
+    print 'gxval', gxval
+    print 'guval', guval
+    #x[i] contributes to outputs y[i]..y[4], so g_x = [5, 4, 3, 2, 1];
+    #u contributes to all five outputs, so g_u = 5
+    assert numpy.all(gxval == numpy.asarray([5., 4., 3., 2., 1.]))
+    assert guval == 5.
+
+
+def test_verify_scan_grad():
+    x_i = dvector()
+    y_prev = dvector()
+    scan_add = Scan1(x_i, y_prev, x_i + y_prev)
+
+    rng = numpy.random.RandomState(456)
+
+    xval = rng.rand(4, 3)
+    uval = rng.rand(3)
+
+    print theano.tensor.verify_grad(scan_add, (xval, uval), rng=rng)
+
+