comparison sandbox/rbm/model.py @ 405:be4209cd568f

Added weight decay
author Joseph Turian <turian@gmail.com>
date Thu, 10 Jul 2008 01:17:40 -0400
parents ffdd2c199f2a
children c2e6a8fcc35e
404:8cc11ac97087 405:be4209cd568f
77 """ 77 """
78 Update the L{Model} using one training instance. 78 Update the L{Model} using one training instance.
79 @param instance: A dict from feature index to (non-zero) value. 79 @param instance: A dict from feature index to (non-zero) value.
80 @todo: Should assert that nonzero_indices and zero_indices 80 @todo: Should assert that nonzero_indices and zero_indices
81 are correct (i.e. are truly nonzero/zero). 81 are correct (i.e. are truly nonzero/zero).
82 @todo: Multiply WEIGHT_DECAY by LEARNING_RATE, as done in Semantic Hashing?
83 @todo: Decay the biases too?
82 """ 84 """
83 minibatch = len(instances) 85 minibatch = len(instances)
84 v0 = pylearn.sparse_instance.to_vector(instances, globals.INPUT_DIMENSION) 86 v0 = pylearn.sparse_instance.to_vector(instances, globals.INPUT_DIMENSION)
85 print "old XENT:", numpy.sum(self.deterministic_reconstruction_error(v0)) 87 print "old XENT:", numpy.sum(self.deterministic_reconstruction_error(v0))
86 q0 = sigmoid(self.parameters.b + dot(v0, self.parameters.w)) 88 q0 = sigmoid(self.parameters.b + dot(v0, self.parameters.w))
90 q1 = sigmoid(self.parameters.b + dot(v1, self.parameters.w)) 92 q1 = sigmoid(self.parameters.b + dot(v1, self.parameters.w))
91 93
92 dw = LR * (dot(v0.T, h0) - dot(v1.T, q1)) / minibatch + globals.MOMENTUM * self.prev_dw 94 dw = LR * (dot(v0.T, h0) - dot(v1.T, q1)) / minibatch + globals.MOMENTUM * self.prev_dw
93 db = LR * numpy.sum(h0 - q1, axis=0) / minibatch + globals.MOMENTUM * self.prev_db 95 db = LR * numpy.sum(h0 - q1, axis=0) / minibatch + globals.MOMENTUM * self.prev_db
94 dc = LR * numpy.sum(v0 - v1, axis=0) / minibatch + globals.MOMENTUM * self.prev_dc 96 dc = LR * numpy.sum(v0 - v1, axis=0) / minibatch + globals.MOMENTUM * self.prev_dc
97
98 self.parameters.w *= (1 - globals.WEIGHT_DECAY)
95 99
96 self.parameters.w += dw 100 self.parameters.w += dw
97 self.parameters.b += db 101 self.parameters.b += db
98 self.parameters.c += dc 102 self.parameters.c += dc
99 103
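
The inserted line shrinks every weight by a factor of (1 - WEIGHT_DECAY) before the gradient step dw is added. As a minimal sketch of that update, and of the variant raised in the first @todo (scaling the decay coefficient by the learning rate), the following may help. The function name apply_update and its learning_rate argument are hypothetical and are not part of model.py; only the multiplicative shrink followed by "+= dw" is taken from the changeset.

```python
import numpy as np


def apply_update(w, dw, weight_decay, learning_rate=None):
    """Shrink w multiplicatively, then add the gradient step dw.

    Hypothetical helper, not part of model.py. If learning_rate is given,
    the decay coefficient is multiplied by it (the variant asked about in
    the @todo); otherwise the coefficient is used as-is, as in the changeset.
    """
    decay = weight_decay if learning_rate is None else weight_decay * learning_rate
    return w * (1.0 - decay) + dw


# Illustrative values only; dw is zero so that the effect of decay is isolated.
w = np.array([[1.0, -2.0], [0.5, 4.0]])
dw = np.zeros_like(w)

plain = apply_update(w, dw, weight_decay=0.01)                      # w * 0.99
scaled = apply_update(w, dw, weight_decay=0.01, learning_rate=0.1)  # w * 0.999
```

With a learning rate below 1, the scaled variant decays the weights more slowly per update, so the two choices are not interchangeable without retuning WEIGHT_DECAY.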