diff sandbox/simple_autoassociator/model.py @ 393:36baeb7125a4

Made sandbox directory
author Joseph Turian <turian@gmail.com>
date Tue, 08 Jul 2008 18:46:26 -0400
parents simple_autoassociator/model.py@e2cb8d489908
children 8cc11ac97087
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/sandbox/simple_autoassociator/model.py	Tue Jul 08 18:46:26 2008 -0400
@@ -0,0 +1,57 @@
+"""
+The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason
+Weston's sampling trick (2008).
+"""
+
+from graph import trainfn
+import parameters
+
+import globals
+from globals import LR
+
+import numpy
+import random
+random.seed(globals.SEED)  # import-time side effect: seeds only the stdlib `random` module
+
+class Model:  # Holds the parameters (w1, b1, w2, b2) and applies per-instance SGD updates.
+    def __init__(self):  # Parameters are requested with randomly_initialize=True.
+        self.parameters = parameters.Parameters(randomly_initialize=True)
+
+    def update(self, instance):
+        """
+        Take one SGD step on a single training instance, printing diagnostics.
+        @param instance: A dict from feature index to (non-zero) value.
+        @todo: Should assert that nonzero_indices and zero_indices are correct
+        (truly nonzero/zero). NOTE(review): neither name is used in this method.
+        """
+        x = numpy.zeros(globals.INPUT_DIMENSION)  # dense input; indices absent from instance stay 0
+        for idx in instance.keys():
+            x[idx] = instance[idx]
+        # Evaluate trainfn with the current (pre-update) parameters.
+        (y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
+        print  # Python 2 print statement: emits a blank separator line
+        print "instance:", instance
+        print "x:", x
+        print "OLD y:", y
+        print "OLD loss (unsummed):", loss_unsummed
+        print "gy:", gy
+        print "OLD total loss:", loss
+        print "gw1:", gw1
+        print "gb1:", gb1
+        print "gw2:", gw2
+        print "gb2:", gb2
+
+        # SGD update: move each parameter by -LR times the gradient returned for it
+        self.parameters.w1  -= LR * gw1
+        self.parameters.b1  -= LR * gb1
+        self.parameters.w2  -= LR * gw2
+        self.parameters.b2  -= LR * gb2
+
+        # Recompute the loss, to make sure it's decreasing (these gradients are not applied)
+        (y, h, loss, loss_unsummed, gw1, gb1, gw2, gb2, gy) = trainfn(x, self.parameters.w1, self.parameters.b1, self.parameters.w2, self.parameters.b2)
+        print "NEW y:", y
+        print "NEW loss (unsummed):", loss_unsummed
+        print "gy:", gy
+        print "NEW total loss:", loss
+        print "h:", h
+        print self.parameters