Mercurial > pylearn
comparison sandbox/rbm/model.py @ 398:6e55ccb7e2bf
Better output
author | Joseph Turian <turian@gmail.com> |
---|---|
date | Tue, 08 Jul 2008 20:48:56 -0400 |
parents | e0c9357456e0 |
children | 8796b91a9f09 |
comparison
equal
deleted
inserted
replaced
397:25a3212287cd | 398:6e55ccb7e2bf |
---|---|
28 assert v[j][i] >= 0 and v[j][i] <= 1 | 28 assert v[j][i] >= 0 and v[j][i] <= 1 |
29 if random.random() < v[j][i]: x[j][i] = 1 | 29 if random.random() < v[j][i]: x[j][i] = 1 |
30 else: x[j][i] = 0 | 30 else: x[j][i] = 0 |
31 return x | 31 return x |
32 | 32 |
33 def crossentropy(output, target): | |
34 """ | |
35 Compute the crossentropy of binary output wrt binary target. | |
36 @note: We do not sum, crossentropy is computed by component. | |
37 @todo: Rewrite as a scalar, and then broadcast to tensor. | |
38 """ | |
39 return -(target * numpy.log(output) + (1 - target) * numpy.log(1 - output)) | |
40 | |
41 | |
33 class Model: | 42 class Model: |
34 def __init__(self): | 43 def __init__(self): |
35 self.parameters = parameters.Parameters(randomly_initialize=True) | 44 self.parameters = parameters.Parameters(randomly_initialize=True) |
36 | 45 |
37 def update(self, instance): | 46 def update(self, instance): |
53 print | 62 print |
54 print "v[0]:", v0 | 63 print "v[0]:", v0 |
55 print "Q(h[0][i] = 1 | v[0]):", q0 | 64 print "Q(h[0][i] = 1 | v[0]):", q0 |
56 print "h[0]:", h0 | 65 print "h[0]:", h0 |
57 print "P(v[1][j] = 1 | h[0]):", p0 | 66 print "P(v[1][j] = 1 | h[0]):", p0 |
67 print "XENT(P(v[1][j] = 1 | h[0]) | v0):", numpy.sum(crossentropy(p0, v0)) | |
58 print "v[1]:", v1 | 68 print "v[1]:", v1 |
59 print "Q(h[1][i] = 1 | v[1]):", q1 | 69 print "Q(h[1][i] = 1 | v[1]):", q1 |
60 | 70 |
61 print self.parameters.w.shape | |
62 self.parameters.w += LR * (dot(v0.T, h0) - dot(v1.T, q1)) | 71 self.parameters.w += LR * (dot(v0.T, h0) - dot(v1.T, q1)) |
63 self.parameters.b += LR * (h0 - q1) | 72 self.parameters.b += LR * (h0 - q1) |
64 self.parameters.c += LR * (v0 - v1) | 73 self.parameters.c += LR * (v0 - v1) |
65 print self.parameters | 74 # print self.parameters |