diff sandbox/simple_autoassociator/model.py @ 393:36baeb7125a4
Made sandbox directory
| author   | Joseph Turian <turian@gmail.com>            |
|----------|---------------------------------------------|
| date     | Tue, 08 Jul 2008 18:46:26 -0400             |
| parents  | simple_autoassociator/model.py@e2cb8d489908 |
| children | 8cc11ac97087                                |
```diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/model.py	Tue Jul 08 18:46:26 2008 -0400
@@ -0,0 +1,57 @@
+"""
+The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
+Weston's sampling trick (2008).
+"""
+
+from graph import trainfn
+import parameters
+
+import globals
+from globals import LR
+
+import numpy
+import random
+random.seed(globals.SEED)
+
+class Model:
+    def __init__(self):
+        self.parameters = parameters.Parameters(randomly_initialize=True)
+
+    def update(self, instance):
+        """
+        Update the L{Model} using one training instance.
+        @param instance: A dict from feature index to (non-zero) value.
+        @todo: Should assert that nonzero_indices and zero_indices
+        are correct (i.e. are truly nonzero/zero).
+        """
+        x = numpy.zeros(globals.INPUT_DIMENSION)
+        for idx in instance.keys():
+            x[idx] = instance[idx]
+
+        (y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
+        print
+        print "instance:", instance
+        print "x:", x
+        print "OLD y:", y
+        print "OLD loss (unsummed):", loss_unsummed
+        print "gy:", gy
+        print "OLD total loss:", loss
+        print "gw1:", gw1
+        print "gb1:", gb1
+        print "gw2:", gw2
+        print "gb2:", gb2
+
+        # SGD update
+        self.parameters.w1 -= LR * gw1
+        self.parameters.b1 -= LR * gb1
+        self.parameters.w2 -= LR * gw2
+        self.parameters.b2 -= LR * gb2
+
+        # Recompute the loss, to make sure it's decreasing
+        (y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
+        print "NEW y:", y
+        print "NEW loss (unsummed):", loss_unsummed
+        print "gy:", gy
+        print "NEW total loss:", loss
+        print "h:", h
+        print self.parameters
```
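For context, the sampling trick the module docstring credits to Collobert and Weston (2008) presumably means scoring the reconstruction only at the instance's nonzero indices plus a small random sample of zero indices, rather than over all `INPUT_DIMENSION` outputs; the @todo's mention of `nonzero_indices` and `zero_indices` points the same way. A minimal sketch of that index selection, where the helper name `sample_indices` and the `num_zero_samples` parameter are invented here for illustration and do not appear in the diff:

```python
import random

def sample_indices(instance, input_dimension, num_zero_samples=10):
    """Hypothetical helper (not part of the diff above): pick the output
    indices at which to compute the reconstruction loss -- every nonzero
    index of the sparse instance, plus a random sample of zero indices."""
    nonzero_indices = sorted(instance.keys())
    zero_pool = [i for i in range(input_dimension) if i not in instance]
    zero_indices = random.sample(zero_pool, min(num_zero_samples, len(zero_pool)))
    return nonzero_indices, zero_indices
```

And a minimal driver for the class itself, assuming the sandbox modules it imports (`graph`, `parameters`, `globals`) are on the path and define `trainfn`, `Parameters`, `SEED`, `LR`, and `INPUT_DIMENSION` as the code expects:

```python
from model import Model

model = Model()
# An instance is a dict from feature index to non-zero value; update()
# densifies it into a length-INPUT_DIMENSION vector, takes one SGD step
# of size LR, and prints the loss before and after the step.
model.update({0: 1.0, 3: 0.5, 7: 2.0})
```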