view algorithms/aa.py @ 520:82bafb80ba65

merge
author Joseph Turian <turian@iro.umontreal.ca>
date Fri, 14 Nov 2008 02:09:23 -0500
parents 8fcd0f3d9a17
children
line wrap: on
line source


import theano
from theano import tensor as T
from theano.tensor import nnet as NN
import numpy as N

class AutoEncoder(theano.FancyModule):
    """Single-hidden-layer autoencoder expressed as a Theano module.

    This base class wires up the parameters, the encoder/decoder affine maps,
    the total cost, SGD updates and the compiled interface methods. Subclasses
    supply the nonlinearities and reconstruction cost by defining
    ``build_hidden``, ``build_output`` and ``build_reconstruction_cost``, and
    may override ``build_regularization``.

    Compiled methods available on instances:
      - ``update(input)``: one gradient step (``p <- p - lr * dcost/dp``) on
        every parameter in ``params``; returns ``cost``.
      - ``reconstruction(input)``: returns ``output``.
      - ``representation(input)``: returns ``hidden``.
    """

    def __init__(self, input = None, regularize = True, tie_weights = True):
        """Build the symbolic graph.

        :param input: symbolic input variable. When None, a fresh
            ``T.matrix('input')`` is created.
        :param regularize: if true, ``cost`` is reconstruction cost plus
            ``regularization``; otherwise only the reconstruction cost.
        :param tie_weights: if true, the decoder weights are ``w1.T`` rather
            than an independent ``w2`` parameter.
        """
        super(AutoEncoder, self).__init__()

        # MODEL CONFIGURATION
        self.regularize = regularize
        self.tie_weights = tie_weights

        # ACQUIRE/MAKE INPUT
        # Explicit None check: a symbolic variable should not be truth-tested
        # to decide whether the caller supplied one.
        if input is None:
            input = T.matrix('input')
        self.input = theano.External(input)

        # HYPER-PARAMETERS
        self.lr = theano.Member(T.scalar())

        # PARAMETERS
        self.w1 = theano.Member(T.matrix())
        if not tie_weights:
            self.w2 = theano.Member(T.matrix())
        else:
            # Tied weights: decoder is the transpose of the encoder.
            self.w2 = self.w1.T
        self.b1 = theano.Member(T.vector())
        self.b2 = theano.Member(T.vector())

        # HIDDEN LAYER (nonlinearity chosen by the subclass)
        self.hidden_activation = T.dot(input, self.w1) + self.b1
        self.hidden = self.build_hidden()

        # RECONSTRUCTION LAYER (nonlinearity chosen by the subclass)
        self.output_activation = T.dot(self.hidden, self.w2) + self.b2
        self.output = self.build_output()

        # RECONSTRUCTION COST
        self.reconstruction_cost = self.build_reconstruction_cost()

        # REGULARIZATION COST
        # Always built (subclasses may create members here), but only added
        # to the total cost when regularize is true.
        self.regularization = self.build_regularization()

        # TOTAL COST
        self.cost = self.reconstruction_cost
        if self.regularize:
            self.cost = self.cost + self.regularization

        # GRADIENTS AND UPDATES
        # With tied weights w2 is not a separate parameter, so it is excluded.
        if self.tie_weights:
            self.params = self.w1, self.b1, self.b2
        else:
            self.params = self.w1, self.w2, self.b1, self.b2
        gradients = T.grad(self.cost, self.params)
        updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients))

        # INTERFACE METHODS
        self.update = theano.Method(input, self.cost, updates)
        self.reconstruction = theano.Method(input, self.output)
        self.representation = theano.Method(input, self.hidden)

    def _instance_initialize(self, obj, input_size = None, hidden_size = None, seed = None, **init):
        """Initialize the member values of a compiled instance ``obj``.

        :param input_size: number of input units (rows of ``w1``).
        :param hidden_size: number of hidden units (columns of ``w1``).
        :param seed: seed for a private ``numpy.random.RandomState``. When
            None, the global ``numpy.random`` state is used.
        :param init: forwarded to the parent ``_instance_initialize``;
            presumably initial values for members such as ``lr``.
        :raises ValueError: if exactly one of ``input_size`` and
            ``hidden_size`` is given.

        When both sizes are given, ``w1`` (and ``w2`` if weights are untied)
        is drawn uniformly from ``[-1/sqrt(input_size), 1/sqrt(input_size)]``
        and both biases are zeroed.
        """
        if (input_size is None) ^ (hidden_size is None):
            raise ValueError("Must specify input_size and hidden_size or neither.")
        super(AutoEncoder, self)._instance_initialize(obj, **init)
        if seed is not None:
            R = N.random.RandomState(seed)
        else:
            R = N.random
        if input_size is not None:
            sz = (input_size, hidden_size)
            bound = 1/N.sqrt(input_size)
            obj.w1 = R.uniform(size = sz, low = -bound, high = bound)
            if not self.tie_weights:
                # NOTE(review): w2 has fan-in hidden_size but reuses the
                # 1/sqrt(input_size) bound; confirm this is intended.
                obj.w2 = R.uniform(size = list(reversed(sz)), low = -bound, high = bound)
            obj.b1 = N.zeros(hidden_size)
            obj.b2 = N.zeros(input_size)

    def build_regularization(self):
        """Return the regularization term; the base class applies none."""
        # NOTE(review): confirm T.zero exists in the targeted Theano version;
        # T.constant(0) would be an equivalent fallback.
        return T.zero() # no regularization!


class SigmoidXEAutoEncoder(AutoEncoder):
    """Autoencoder with sigmoid units and a cross-entropy reconstruction cost.

    Regularization is an L2 penalty on the weight matrices, scaled by the
    scalar member ``l2_coef``. It defaults to 0 at instance initialization.
    """

    def build_hidden(self):
        """Return the elementwise sigmoid of the hidden pre-activation."""
        return NN.sigmoid(self.hidden_activation)

    def build_output(self):
        """Return the elementwise sigmoid of the output pre-activation."""
        return NN.sigmoid(self.output_activation)

    def build_reconstruction_cost(self):
        """Return the cross-entropy reconstruction cost summed over all rows.

        Side effects: stores the elementwise log-likelihood terms in
        ``reconstruction_cost_matrix`` and the per-row negated sums in
        ``reconstruction_costs``.
        """
        # x*log(y) + (1-x)*log(1-y); assumes input values lie in [0, 1]
        # (TODO confirm against callers).
        self.reconstruction_cost_matrix = self.input * T.log(self.output) + (1.0 - self.input) * T.log(1.0 - self.output)
        self.reconstruction_costs = -T.sum(self.reconstruction_cost_matrix, axis=1)
        return T.sum(self.reconstruction_costs)

    def build_regularization(self):
        """Create the ``l2_coef`` member and return the scaled L2 weight penalty."""
        self.l2_coef = theano.Member(T.scalar())
        if self.tie_weights:
            return self.l2_coef * T.sum(self.w1 * self.w1)
        else:
            # The coefficient must scale both weight penalties; without the
            # parentheses the w2 term was unscaled (and active even when
            # l2_coef == 0).
            return self.l2_coef * (T.sum(self.w1 * self.w1) + T.sum(self.w2 * self.w2))

    def _instance_initialize(self, obj, input_size = None, hidden_size = None, **init):
        """Initialize ``obj``, defaulting ``l2_coef`` to 0 (no regularization)."""
        init.setdefault('l2_coef', 0)
        super(SigmoidXEAutoEncoder, self)._instance_initialize(obj, input_size, hidden_size, **init)