view sandbox/simple_autoassociator/model.py @ 404:8cc11ac97087

Debugging simple AA a bit
author Joseph Turian <turian@gmail.com>
date Thu, 10 Jul 2008 00:51:32 -0400
parents 36baeb7125a4
children faffaae0d2f9
line wrap: on
line source

"""
The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
Weston's sampling trick (2008).
"""

from graph import trainfn
import parameters

import globals
from globals import LR

import numpy
import random
# Seed Python's stdlib `random` module from the project-wide globals.SEED,
# presumably so that runs are repeatable.
# NOTE(review): numpy's RNG (numpy.random) is NOT seeded here. If
# parameters.Parameters(randomly_initialize=True) draws from numpy.random,
# its initialization is not controlled by this seed -- confirm.
random.seed(globals.SEED)

class Model:
    """
    An autoassociator whose weights (w1, b1, w2, b2) are held in a
    L{parameters.Parameters} object and trained one instance at a time by
    stochastic gradient descent with learning rate C{LR} (see L{update}).
    """

    def __init__(self):
        self.parameters = parameters.Parameters(randomly_initialize=True)

    def _evaluate(self, x):
        """Return C{trainfn}'s outputs for input C{x} under the current parameters."""
        p = self.parameters
        return trainfn(x, p.w1, p.b1, p.w2, p.b2)

    @staticmethod
    def _print_labelled(pairs):
        """Print each C{(label, value)} pair on its own line as C{label: value}."""
        for label, value in pairs:
            print("%s: %s" % (label, value))

    def update(self, instance, verbose=True):
        """
        Update the L{Model} using one training instance.

        The sparse C{instance} is expanded into a dense vector of length
        C{globals.INPUT_DIMENSION}, C{trainfn} is evaluated on it, and one
        SGD step of size C{LR} is applied to w1, b1, w2 and b2.

        @param instance: A dict from feature index to (non-zero) value.
        Indices absent from the dict are zero in the dense input.
        @param verbose: If true (the default, and the behaviour before this
        flag existed), print the input, outputs, losses and gradients before
        the step. Then re-evaluate C{trainfn} with the updated parameters,
        print the new outputs, loss and parameters, and print a warning if
        the total loss went up. If false, nothing is printed and the extra
        C{trainfn} evaluation is skipped.
        @todo: Should assert that the values in C{instance} are truly
        nonzero and that every index is below C{globals.INPUT_DIMENSION}.
        """
        # Densify the sparse instance.
        x = numpy.zeros(globals.INPUT_DIMENSION)
        for idx, value in instance.items():
            x[idx] = value

        (y, h, loss, loss_unsummed,
         gw1, gb1, gw2, gb2, gy, gh) = self._evaluate(x)

        if verbose:
            print("")
            self._print_labelled([
                ("instance", instance),
                ("x", x),
                ("OLD y", y),
                ("OLD loss (unsummed)", loss_unsummed),
                ("gy", gy),
                ("gh", gh),
                ("OLD total loss", loss),
                ("gw1", gw1),
                ("gb1", gb1),
                ("gw2", gw2),
                ("gb2", gb2),
            ])

        # SGD step.
        self.parameters.w1 -= LR * gw1
        self.parameters.b1 -= LR * gb1
        self.parameters.w2 -= LR * gw2
        self.parameters.b2 -= LR * gb2

        if not verbose:
            # The re-evaluation below exists only for reporting; skip it.
            return

        # Recompute the loss with the updated parameters, to check that this
        # step actually decreased it.
        old_loss = loss
        (y, h, loss, loss_unsummed, _, _, _, _, gy, _) = self._evaluate(x)
        self._print_labelled([
            ("NEW y", y),
            ("NEW loss (unsummed)", loss_unsummed),
            ("gy", gy),
            ("NEW total loss", loss),
            ("h", h),
        ])
        print(self.parameters)

        # numpy.sum keeps the comparison well-defined whether trainfn returns
        # the total loss as a scalar or as an array.
        # NOTE(review): scalar expected from the "total loss" label -- confirm.
        if numpy.sum(loss) > numpy.sum(old_loss):
            print("WARNING: total loss increased from %s to %s"
                  % (old_loss, loss))