view sandbox/simple_autoassociator/model.py @ 416:8849eba55520

Can now do minibatch update
author Joseph Turian <turian@iro.umontreal.ca>
date Fri, 11 Jul 2008 16:34:46 -0400
parents faffaae0d2f9
children 4f61201fa9a9
line wrap: on
line source

"""
The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
Weston's sampling trick (2008).
"""

from graph import trainfn
import parameters

import globals
from globals import LR

import numpy
import random
random.seed(globals.SEED)

import pylearn.sparse_instance

class Model:
    """
    Autoassociator whose parameters are trained by minibatch SGD.

    The trainable values (w1, b1, w2, b2) are held in C{self.parameters};
    L{update} feeds them to C{trainfn} and applies the returned gradients.
    """

    def __init__(self):
        """
        Create the model with a randomly initialized parameter set.
        """
        self.parameters = parameters.Parameters(randomly_initialize=True)

#    def deterministic_reconstruction(self, x):
#        (y, h, loss, gw1, gb1, gw2, gb2) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
#        return y

    def _evaluate(self, x):
        """
        Run C{trainfn} on x with the current parameters.
        @param x: The vectorized minibatch, as built in L{update}.
        @return: The tuple (y, h, loss, gw1, gb1, gw2, gb2) exactly as
        C{trainfn} returns it.
        """
        p = self.parameters
        return trainfn(x, p.w1, p.b1, p.w2, p.b2)

    def update(self, instances):
        """
        Update the L{Model} with one SGD step on a minibatch of instances.

        The total loss is printed before the step and again after it, so
        that a non-decreasing loss is visible on the console. The second
        evaluation exists only for that diagnostic.
        @param instances: A list of dict from feature index to (non-zero) value.
        @todo: Should assert that the instances are well-formed. (The
        earlier form of this note referred to nonzero_indices and
        zero_indices, which this method does not receive.)
        """
        x = pylearn.sparse_instance.to_vector(instances, globals.INPUT_DIMENSION)

        (_y, _h, loss, gw1, gb1, gw2, gb2) = self._evaluate(x)
        # A single pre-formatted argument prints the same text under the
        # Python 2 print statement and the Python 3 print function.
        print("OLD total loss: %s" % (loss,))

        # SGD update: step each parameter against its gradient, scaled by LR.
        self.parameters.w1 -= LR * gw1
        self.parameters.b1 -= LR * gb1
        self.parameters.w2 -= LR * gw2
        self.parameters.b2 -= LR * gb2

        # Recompute the loss, to make sure it's decreasing.
        loss = self._evaluate(x)[2]
        print("NEW total loss: %s" % (loss,))