annotate deep/rbm/rbm.py @ 367:f24b10e43a6f

correction d'un petit bug dans la fonction traduire()
author SylvainPL <sylvain.pannetier.lebeuf@umontreal.ca>
date Fri, 23 Apr 2010 11:39:55 -0400
parents 9685e9d94cc4
children d81284e13d77
rev   line source
347
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
1 """This tutorial introduces restricted boltzmann machines (RBM) using Theano.
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
2
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
3 Boltzmann Machines (BMs) are a particular form of energy-based model which
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
4 contain hidden variables. Restricted Boltzmann Machines further restrict BMs
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
5 to those without visible-visible and hidden-hidden connections.
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
6 """
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
7
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
8
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
9 import numpy, time, cPickle, gzip, PIL.Image
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
10
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
11 import theano
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
12 import theano.tensor as T
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
13 import os
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
14
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
15 from theano.tensor.shared_randomstreams import RandomStreams
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
16
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
17 from utils import tile_raster_images
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
18 from logistic_sgd import load_data
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
19
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
20
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
class RBM(object):
    """Restricted Boltzmann Machine (RBM) """
    def __init__(self, input=None, n_visible=784, n_hidden=1000,
                 W=None, hbias=None, vbias=None, numpy_rng=None,
                 theano_rng=None):
        """
        RBM constructor. Defines the parameters of the model along with
        basic operations for inferring hidden from visible (and vice-versa),
        as well as for performing CD updates.

        :param input: None for standalone RBMs or symbolic variable if RBM is
        part of a larger graph. When None, a new ``T.dmatrix('input')`` is
        created (one example per row).

        :param n_visible: number of visible units

        :param n_hidden: number of hidden units

        :param W: None for standalone RBMs or symbolic variable pointing to a
        shared weight matrix in case RBM is part of a DBN network; in a DBN,
        the weights are shared between RBMs and layers of a MLP. When None, a
        shared matrix of shape (n_visible, n_hidden) is created.

        :param hbias: None for standalone RBMs or symbolic variable pointing
        to a shared hidden units bias vector in case RBM is part of a
        different network. When None, a zero vector of length n_hidden is
        created.

        :param vbias: None for standalone RBMs or a symbolic variable
        pointing to a shared visible units bias. When None, a zero vector of
        length n_visible is created.

        :param numpy_rng: numpy.random.RandomState used to draw the initial
        weights (when W is None) and to seed ``theano_rng`` (when it is None).
        Defaults to ``numpy.random.RandomState(1234)``.

        :param theano_rng: RandomStreams used to sample hidden and visible
        units. Defaults to a RandomStreams seeded from ``numpy_rng``.
        """

        self.n_visible = n_visible
        self.n_hidden = n_hidden

        # The random generators are created before the weights so that the
        # weight initialisation below draws from `numpy_rng` (reproducible)
        # rather than from the global numpy.random state.
        if numpy_rng is None:
            # create a number generator
            numpy_rng = numpy.random.RandomState(1234)

        if theano_rng is None:
            theano_rng = RandomStreams(numpy_rng.randint(2 ** 30))

        if W is None:
            # W is initialized with `initial_W`, sampled uniformly from
            # [-sqrt(6./(n_hidden+n_visible)), sqrt(6./(n_hidden+n_visible))].
            # The sample is converted with asarray to dtype
            # theano.config.floatX so that the code is runnable on GPU.
            bound = numpy.sqrt(6. / (n_hidden + n_visible))
            initial_W = numpy.asarray(
                numpy_rng.uniform(low=-bound, high=bound,
                                  size=(n_visible, n_hidden)),
                dtype=theano.config.floatX)
            # theano shared variables for weights and biases
            W = theano.shared(value=initial_W, name='W')

        if hbias is None:
            # create shared variable for hidden units bias
            hbias = theano.shared(value=numpy.zeros(n_hidden,
                                  dtype=theano.config.floatX), name='hbias')

        if vbias is None:
            # create shared variable for visible units bias
            vbias = theano.shared(value=numpy.zeros(n_visible,
                                  dtype=theano.config.floatX), name='vbias')

        # initialize input layer for standalone RBM or layer0 of DBN.
        # Compare against None explicitly: the truth value of a symbolic
        # variable is not a reliable "was an input given" test.
        self.input = input if input is not None else T.dmatrix('input')

        self.W = W
        self.hbias = hbias
        self.vbias = vbias
        self.theano_rng = theano_rng
        # **** WARNING: It is not a good idea to put things in this list
        # other than shared variables created in this function.
        # NOTE: `cd` zips this list with its gradients, so the order
        # [W, hbias, vbias] must stay in sync with `gparams` there.
        self.params = [self.W, self.hbias, self.vbias]
        # symbolic number of rows (examples) in the current minibatch
        self.batch_size = self.input.shape[0]

    def free_energy(self, v_sample):
        ''' Function to compute the free energy, summed over all rows of
        `v_sample` (returns a symbolic scalar). '''
        wx_b = T.dot(v_sample, self.W) + self.hbias
        vbias_term = T.sum(T.dot(v_sample, self.vbias))
        hidden_term = T.sum(T.log(1 + T.exp(wx_b)))
        return -hidden_term - vbias_term

    def _free_energy_per_sample(self, v_sample):
        ''' Free energy of each row of the matrix `v_sample` (returns a
        symbolic vector with one entry per row). '''
        wx_b = T.dot(v_sample, self.W) + self.hbias
        vbias_term = T.dot(v_sample, self.vbias)
        # sum over hidden units only, keeping the minibatch axis
        hidden_term = T.sum(T.log(1 + T.exp(wx_b)), axis=1)
        return -hidden_term - vbias_term

    def sample_h_given_v(self, v0_sample):
        ''' This function infers state of hidden units given visible units.

        Returns [h1_mean, h1_sample]: the sigmoid activation probabilities
        and a binomial (n=1) sample drawn from them. '''
        # compute the activation of the hidden units given a sample of the visibles
        h1_mean = T.nnet.sigmoid(T.dot(v0_sample, self.W) + self.hbias)
        # get a sample of the hiddens given their activation
        h1_sample = self.theano_rng.binomial(size=h1_mean.shape, n=1, prob=h1_mean)
        return [h1_mean, h1_sample]

    def sample_v_given_h(self, h0_sample):
        ''' This function infers state of visible units given hidden units.

        Returns [v1_mean, v1_sample]: the sigmoid activation probabilities
        and a binomial (n=1) sample drawn from them. '''
        # compute the activation of the visible given the hidden sample
        v1_mean = T.nnet.sigmoid(T.dot(h0_sample, self.W.T) + self.vbias)
        # get a sample of the visible given their activation
        v1_sample = self.theano_rng.binomial(size=v1_mean.shape, n=1, prob=v1_mean)
        return [v1_mean, v1_sample]

    def gibbs_hvh(self, h0_sample):
        ''' This function implements one step of Gibbs sampling,
        starting from the hidden state (h -> v -> h).

        Returns [v1_mean, v1_sample, h1_mean, h1_sample]. '''
        v1_mean, v1_sample = self.sample_v_given_h(h0_sample)
        h1_mean, h1_sample = self.sample_h_given_v(v1_sample)
        return [v1_mean, v1_sample, h1_mean, h1_sample]

    def gibbs_vhv(self, v0_sample):
        ''' This function implements one step of Gibbs sampling,
        starting from the visible state (v -> h -> v).

        Returns [h1_mean, h1_sample, v1_mean, v1_sample]. '''
        h1_mean, h1_sample = self.sample_h_given_v(v0_sample)
        v1_mean, v1_sample = self.sample_v_given_h(h1_sample)
        return [h1_mean, h1_sample, v1_mean, v1_sample]

    def cd(self, lr=0.1, persistent=None):
        """
        This functions implements one step of CD-1 or PCD-1

        :param lr: learning rate used to train the RBM
        :param persistent: None for CD. For PCD, shared variable containing old state
        of Gibbs chain. This must be a shared variable of size (batch size, number of
        hidden units).

        Returns a pair (cost, updates). `cost` is a symbolic monitoring proxy:
        the pseudo-likelihood for PCD, or the reconstruction log-likelihood
        for CD. `updates` is a dictionary containing the update rules for
        weights and biases, plus (for PCD) the update of the shared variable
        used to store the persistent chain and of the pseudo-likelihood bit
        index.
        """

        # compute positive phase
        ph_mean, ph_sample = self.sample_h_given_v(self.input)

        # decide how to initialize persistent chain:
        # for CD, we use the newly generate hidden sample
        # for PCD, we initialize from the old state of the chain
        if persistent is None:
            chain_start = ph_sample
        else:
            chain_start = persistent

        # perform actual negative phase
        [nv_mean, nv_sample, nh_mean, nh_sample] = self.gibbs_hvh(chain_start)

        # determine gradients on RBM parameters (positive minus negative
        # statistics, averaged over the minibatch)
        g_vbias = T.sum(self.input - nv_mean, axis=0) / self.batch_size
        g_hbias = T.sum(ph_mean - nh_mean, axis=0) / self.batch_size
        # g_W is (n_hidden, n_visible); it is transposed below to match W
        g_W = T.dot(ph_mean.T, self.input) / self.batch_size - \
            T.dot(nh_mean.T, nv_mean) / self.batch_size

        gparams = [g_W.T, g_hbias, g_vbias]

        # constructs the update dictionary (gradient ascent on the
        # log-likelihood estimate, hence `+`)
        updates = {}
        for gparam, param in zip(gparams, self.params):
            updates[param] = param + gparam * lr

        if persistent is not None:
            # Note that this works only if persistent is a shared variable
            updates[persistent] = T.cast(nh_sample, dtype=theano.config.floatX)
            # pseudo-likelihood is a better proxy for PCD
            cost = self.get_pseudo_likelihood_cost(updates)
        else:
            # reconstruction cross-entropy is a better proxy for CD
            cost = self.get_reconstruction_cost(updates, nv_mean)

        return cost, updates

    def get_pseudo_likelihood_cost(self, updates):
        """Stochastic approximation to the pseudo-likelihood.

        Each call evaluates a single bit index (stored in a shared variable)
        and adds to `updates` the rule that advances that index modulo
        n_visible. Returns the minibatch mean of
        n_visible * log p(x_i | x_{-i}).
        """

        # index of bit i in expression p(x_i | x_{\i})
        bit_i_idx = theano.shared(value=0, name='bit_i_idx')

        # binarize the input image by rounding to nearest integer
        xi = T.iround(self.input)

        # calculate free energy for the given bit configuration, one value
        # per example (summing over the minibatch here would be wrong: the
        # sigmoid below must be applied to each example separately)
        fe_xi = self._free_energy_per_sample(xi)

        # flip bit x_i of matrix xi and preserve all other bits x_{\i}
        # Equivalent to xi[:,bit_i_idx] = 1-xi[:, bit_i_idx]
        # NB: slice(start,stop,step) is the python object used for
        # slicing, e.g. to index matrix x as follows: x[start:stop:step]
        xi_flip = T.setsubtensor(xi, 1 - xi[:, bit_i_idx],
                                 idx_list=(slice(None, None, None), bit_i_idx))

        # calculate free energy with bit flipped, one value per example
        fe_xi_flip = self._free_energy_per_sample(xi_flip)

        # p(x_i | x_{\i}) = e^(-FE(x)) / (e^(-FE(x)) + e^(-FE(x_flip)))
        #                = sigmoid(FE(x_flip) - FE(x))
        # scaled by n_visible and averaged over the minibatch
        cost = T.mean(self.n_visible *
                      T.log(T.nnet.sigmoid(fe_xi_flip - fe_xi)))

        # increment bit_i_idx % number as part of updates
        updates[bit_i_idx] = (bit_i_idx + 1) % self.n_visible

        return cost

    def get_reconstruction_cost(self, updates, nv_mean):
        """Approximation to the reconstruction error.

        Returns the minibatch mean of
        sum_j [x_j log(nv_mean_j) + (1 - x_j) log(1 - nv_mean_j)],
        i.e. the negative of the reconstruction cross-entropy (values are
        <= 0, higher is better). `updates` is accepted for interface
        symmetry with `get_pseudo_likelihood_cost` and is not modified.
        """

        cross_entropy = T.mean(
            T.sum(self.input * T.log(nv_mean) +
                  (1 - self.input) * T.log(1 - nv_mean), axis=1))

        return cross_entropy
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
224
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
225
9685e9d94cc4 base class for an rbm
goldfinger
parents:
diff changeset
226