changeset 550:b52c1a8811a6
init rnn.py
| author | James Bergstra <bergstrj@iro.umontreal.ca> |
| --- | --- |
| date | Thu, 27 Nov 2008 23:16:52 -0500 |
| parents | 16894d38ce48 |
| children | 7de7fa19fb9b |
| files | pylearn/algorithms/rnn.py |
| diffstat | 1 files changed, 193 insertions(+), 0 deletions(-) |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/algorithms/rnn.py	Thu Nov 27 23:16:52 2008 -0500
@@ -0,0 +1,193 @@

import numpy as N
from theano import Op, Apply, tensor as T, Module, Member, Method, Mode, compile
from theano.gof import OpSub, TopoOptimizer

from .minimizer import make_minimizer # minimizer
from theano.printing import Print
import sgd #until Olivier's module-import thing works better

####################
# Library-type stuff
####################

class TanhRnn(Op):
    """
    This class implements the recurrent part of a recurrent neural network.

    There is no neat way to express this in a more fine-grained form in Theano
    at the moment, so to get something working I'm implementing a relatively
    complicated Op that could be broken down later into constituents.

    This Op implements the recursive computation

    .. latex-eqn:
        z_t = \tanh(z_{t-1} A + x_{t-1})

    For z0 a vector and x a TxM matrix, it returns a matrix z of shape (T+1, M),
    in which z[0] = z0.
    """

    def make_node(self, x, z0, A):
        """
        :type x: matrix (each row is an x_t) (shape: (T, M))
        :type z0: vector (the first row of output) (shape: M)
        :type A: matrix (M by M)
        """
        x = T.as_tensor(x)
        z0 = T.as_tensor(z0)
        A = T.as_tensor(A)
        z = x.type()  # make a new symbolic result with the same type as x
        return Apply(self, [x, z0, A], [z])

    def perform(self, node, (x, z0, A), out):
        T, M = x.shape
        z = N.zeros((T+1, M))
        z[0] = z0
        for i in xrange(T):
            z[i+1] = N.tanh(N.dot(z[i], A) + x[i])
        out[0][0] = z

    def grad(self, (x, z0, A), (gz,)):
        z = tanh_rnn(x, z0, A)
        gz_incl_rnn, gx = tanh_rnn_grad(A, z, gz)
        return [gx, gz_incl_rnn[0], (T.dot(z[:-1].T, gx))]

tanh_rnn = TanhRnn()


class TanhRnnGrad(Op):
    """Gradient calculation for TanhRnn"""

    def __init__(self, inplace):
        self.inplace = inplace

        if self.inplace:
            self.destroy_map = {0: [2]}

    def __eq__(self, other):
        return (type(self) == type(other)) and (self.inplace == other.inplace)

    def __hash__(self):
        return hash(type(self)) ^ hash(self.inplace)

    def make_node(self, A, z, gz):
        return Apply(self, [A, z, gz], (z.type(), gz.type()))

    def perform(self, node, (A, z, gz), out):
        Tp1, M = z.shape
        T = Tp1 - 1
        gx = N.zeros((T, M))

        if not self.inplace:
            gz = gz.copy()

        for i in xrange(T-1, -1, -1):
            # back through the tanh
            gx[i] = gz[i+1] * (1.0 - z[i+1] * z[i+1])
            gz[i] += N.dot(A, gx[i])

        out[0][0] = gz
        out[1][0] = gx

    def __str__(self):
        if self.inplace:
            return 'Inplace' + super(TanhRnnGrad, self).__str__()
        else:
            return super(TanhRnnGrad, self).__str__()

tanh_rnn_grad = TanhRnnGrad(inplace=False)
tanh_rnn_grad_inplace = TanhRnnGrad(inplace=True)

compile.optdb.register('inplace_rnngrad', TopoOptimizer(OpSub(tanh_rnn_grad, tanh_rnn_grad_inplace)), 60, 'fast_run', 'inplace')
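(Editor's aside, not part of the changeset.) The recursion that TanhRnn.perform and TanhRnnGrad.perform implement can be checked in isolation with plain NumPy. The helper names below (tanh_rnn_forward, tanh_rnn_backward) are invented for this sketch; they simply restate the two perform loops and compare the hand-written backward pass against a finite-difference estimate.

# Editor's sketch (not part of changeset 550): plain-NumPy restatement of the
# TanhRnn forward/backward loops, with a finite-difference check.
import numpy as np

def tanh_rnn_forward(x, z0, A):
    # z[0] = z0; z[t+1] = tanh(z[t] A + x[t])  -- same loop as TanhRnn.perform
    T, M = x.shape
    z = np.zeros((T + 1, M))
    z[0] = z0
    for t in range(T):
        z[t + 1] = np.tanh(np.dot(z[t], A) + x[t])
    return z

def tanh_rnn_backward(A, z, gz):
    # same loop as the non-inplace TanhRnnGrad.perform
    Tp1, M = z.shape
    T = Tp1 - 1
    gz = gz.copy()
    gx = np.zeros((T, M))
    for t in range(T - 1, -1, -1):
        gx[t] = gz[t + 1] * (1.0 - z[t + 1] ** 2)  # back through the tanh
        gz[t] += np.dot(A, gx[t])                  # back through z[t] A
    return gz, gx

if __name__ == '__main__':
    rng = np.random.RandomState(0)
    T, M = 5, 3
    x, z0, A = rng.randn(T, M), rng.randn(M), rng.randn(M, M) * 0.5
    z = tanh_rnn_forward(x, z0, A)
    # for cost = z.sum(), the gradient on every z[t] starts out as ones
    gz, gx = tanh_rnn_backward(A, z, np.ones_like(z))
    # finite-difference estimate of d cost / d x[0, 0]
    eps = 1e-6
    xp = x.copy()
    xp[0, 0] += eps
    approx = (tanh_rnn_forward(xp, z0, A).sum() - z.sum()) / eps
    print(abs(approx - gx[0, 0]))  # small; limited only by forward-difference error

With a cost that is just the sum of all entries of z, the backward loop accumulates the through-time terms exactly as the Op does: gz[t+1] already holds the total derivative by the time gx[t] is formed.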
#######################
# Experiment-type stuff
#######################


class ExampleRNN(Module):

    def __init__(self, n_vis, n_hid, n_out, minimizer):
        super(ExampleRNN, self).__init__()

        def affine(weight, bias):
            return (lambda a: T.dot(a, weight) + bias)

        self.n_vis = n_vis
        self.n_hid = n_hid
        self.n_out = n_out

        # affine transformation x -> latent space
        self.v, self.b = Member(T.dmatrix()), Member(T.dvector())
        input_transform = affine(self.v, self.b)

        # recurrent weight matrix in latent space
        self.z0 = Member(T.dvector())
        self.w = Member(T.dmatrix())

        # affine transformation latent -> output space
        self.u, self.c = Member(T.dmatrix()), Member(T.dvector())
        output_transform = affine(self.u, self.c)

        self.params = [self.v, self.b, self.w, self.u, self.c]

        # input and target
        x, y = T.dmatrix(), T.dmatrix()

        z = tanh_rnn(input_transform(x), self.z0, self.w)
        yhat = output_transform(z[1:])
        self.cost = T.sum((y - yhat)**2)

        self.blah = Method([x, y], self.cost)

        # using the make_minimizer protocol
        self.minimizer = minimizer([x, y], self.cost, self.params)

    def _instance_initialize(self, obj):
        n_vis = self.n_vis
        n_hid = self.n_hid
        n_out = self.n_out

        rng = N.random.RandomState(2342)

        obj.z0 = N.zeros(n_hid)
        obj.v = rng.randn(n_vis, n_hid) * 0.01
        obj.b = N.zeros(n_hid)
        obj.w = rng.randn(n_hid, n_hid) * 0.01
        obj.u = rng.randn(n_hid, n_out) * 0.01
        obj.c = N.zeros(n_out)
        obj.minimizer.initialize()


def test_example_rnn():
    minimizer_fn = make_minimizer('sgd', stepsize=0.001)

    n_vis = 5
    n_out = 3
    n_hid = 4
    rnn_module = ExampleRNN(n_vis, n_hid, n_out, minimizer_fn)

    rnn = rnn_module.make(mode='FAST_RUN')

    rng = N.random.RandomState(7722342)
    x = rng.randn(10, n_vis)
    y = rng.randn(10, n_out)

    # set y to be like x with a lag of LAG
    LAG = 4
    y[LAG:] = x[:-LAG, 0:n_out]

    if 1:
        for i, node in enumerate(rnn.minimizer.step_cost.maker.env.toposort()):
            print i, node

    niter = 1500
    for i in xrange(niter):
        if i % 100 == 0:
            print i, rnn.minimizer.step_cost(x, y), rnn.minimizer.stepsize
        else:
            rnn.minimizer.step_cost(x, y)
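(Editor's aside, not part of the changeset.) For readers who want to see the computation ExampleRNN builds symbolically, here is the same forward pass and cost written directly in NumPy, set up on the lagged-copy task from test_example_rnn. The function name example_rnn_cost and the explicit parameter list are inventions of this sketch; the shapes and initial values mirror _instance_initialize.

# Editor's sketch (not part of changeset 550): the ExampleRNN cost in plain NumPy.
import numpy as np

def example_rnn_cost(params, x, y):
    # v, b: input transform; w: recurrent matrix; u, c: output transform; z0: initial state
    v, b, w, u, c, z0 = params
    h = np.dot(x, v) + b                        # input_transform(x)
    z = [z0]
    for t in range(h.shape[0]):                 # the TanhRnn recursion
        z.append(np.tanh(np.dot(z[-1], w) + h[t]))
    yhat = np.dot(np.array(z[1:]), u) + c       # output_transform(z[1:])
    return ((y - yhat) ** 2).sum()              # T.sum((y - yhat)**2)

rng = np.random.RandomState(7722342)
n_vis, n_hid, n_out, LAG = 5, 4, 3, 4
x = rng.randn(10, n_vis)
y = rng.randn(10, n_out)
y[LAG:] = x[:-LAG, 0:n_out]                     # target = lagged copy of the first n_out inputs

params = [rng.randn(n_vis, n_hid) * 0.01,       # v
          np.zeros(n_hid),                      # b
          rng.randn(n_hid, n_hid) * 0.01,       # w
          rng.randn(n_hid, n_out) * 0.01,       # u
          np.zeros(n_out),                      # c
          np.zeros(n_hid)]                      # z0
print(example_rnn_cost(params, x, y))           # cost before any training

A gradient-descent loop on this cost is what the sgd minimizer drives through step_cost; note that the module's self.params list covers v, b, w, u and c only, so z0 is initialized but never updated.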