Mercurial > pylearn
diff sparse_random_autoassociator/model.py @ 370:a1bbcde6b456
Moved sparse_random_autoassociator from my repository
author | Joseph Turian <turian@gmail.com> |
---|---|
date | Mon, 07 Jul 2008 01:54:46 -0400 |
parents | |
children | 75bab24bb2d8 |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/sparse_random_autoassociator/model.py Mon Jul 07 01:54:46 2008 -0400 @@ -0,0 +1,37 @@ +""" +The model for an autoassociator for sparse inputs, using Ronan Collobert + Jason +Weston's sampling trick (2008). +""" + +from graph import trainfn +import parameters +import numpy +from globals import LR + +class Model: + def __init__(self): + self.parameters = parameters.Parameters(randomly_initialize=True) + + def update(self, instance, nonzero_indexes, zero_indexes): + xnonzero = numpy.asarray([instance[idx] for idx in nonzero_indexes]) + print + print "xnonzero:", xnonzero + + (ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = trainfn(xnonzero, self.parameters.w1[nonzero_indexes, :], self.parameters.b1, self.parameters.w2[:, nonzero_indexes], self.parameters.w2[:, zero_indexes], self.parameters.b2[nonzero_indexes], self.parameters.b2[zero_indexes]) + print "OLD ynonzero:", ynonzero + print "OLD yzero:", yzero + print "OLD total loss:", loss + + # SGD update + self.parameters.w1[nonzero_indexes, :] -= LR * gw1nonzero + self.parameters.b1 -= LR * gb1 + self.parameters.w2[:, nonzero_indexes] -= LR * gw2nonzero + self.parameters.w2[:, zero_indexes] -= LR * gw2zero + self.parameters.b2[nonzero_indexes] -= LR * gb2nonzero + self.parameters.b2[zero_indexes] -= LR * gb2zero + + # Recompute the loss, to make sure it's descreasing + (ynonzero, yzero, loss, gw1nonzero, gb1, gw2nonzero, gw2zero, gb2nonzero, gb2zero) = trainfn(xnonzero, self.parameters.w1[nonzero_indexes, :], self.parameters.b1, self.parameters.w2[:, nonzero_indexes], self.parameters.w2[:, zero_indexes], self.parameters.b2[nonzero_indexes], self.parameters.b2[zero_indexes]) + print "NEW ynonzero:", ynonzero + print "NEW yzero:", yzero + print "NEW total loss:", loss