annotate code_tutoriel/rbm.py @ 225:eb78a695ad7a

Merge
author fsavard
date Thu, 11 Mar 2010 16:29:39 -0500
parents 4bc5eeec6394
children
rev   line source
165
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
1 """This tutorial introduces restricted boltzmann machines (RBM) using Theano.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
2
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
3 Boltzmann Machines (BMs) are a particular form of energy-based model which
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
4 contain hidden variables. Restricted Boltzmann Machines further restrict BMs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
5 to those without visible-visible and hidden-hidden connections.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
6 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
7
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
8
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
9 import numpy, time, cPickle, gzip, PIL.Image
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
10
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
11 import theano
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
12 import theano.tensor as T
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
13 import os
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
14
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
15 from theano.tensor.shared_randomstreams import RandomStreams
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
16
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
17 from utils import tile_raster_images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
18 from logistic_sgd import load_data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
19
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
20 class RBM(object):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
21 """Restricted Boltzmann Machine (RBM) """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
22 def __init__(self, input=None, n_visible=784, n_hidden=500, \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
23 W = None, hbias = None, vbias = None, numpy_rng = None,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
24 theano_rng = None):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
25 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
26 RBM constructor. Defines the parameters of the model along with
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
27 basic operations for inferring hidden from visible (and vice-versa),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
28 as well as for performing CD updates.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
29
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
30 :param input: None for standalone RBMs or symbolic variable if RBM is
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
31 part of a larger graph.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
32
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
33 :param n_visible: number of visible units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
34
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
35 :param n_hidden: number of hidden units
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
36
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
37 :param W: None for standalone RBMs or symbolic variable pointing to a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
38 shared weight matrix in case RBM is part of a DBN network; in a DBN,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
39 the weights are shared between RBMs and layers of a MLP
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
40
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
41 :param hbias: None for standalone RBMs or symbolic variable pointing
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
42 to a shared hidden units bias vector in case RBM is part of a
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
43 different network
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
44
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
45 :param vbias: None for standalone RBMs or a symbolic variable
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
46 pointing to a shared visible units bias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
47 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
48
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
49 self.n_visible = n_visible
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
50 self.n_hidden = n_hidden
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
51
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
52
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
53 if W is None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
54 # W is initialized with `initial_W` which is uniformely sampled
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
55 # from -6./sqrt(n_visible+n_hidden) and 6./sqrt(n_hidden+n_visible)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
56 # the output of uniform if converted using asarray to dtype
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
57 # theano.config.floatX so that the code is runable on GPU
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
58 initial_W = numpy.asarray( numpy.random.uniform(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
59 low = -numpy.sqrt(6./(n_hidden+n_visible)),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
60 high = numpy.sqrt(6./(n_hidden+n_visible)),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
61 size = (n_visible, n_hidden)),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
62 dtype = theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
63 # theano shared variables for weights and biases
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
64 W = theano.shared(value = initial_W, name = 'W')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
65
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
66 if hbias is None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
67 # create shared variable for hidden units bias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
68 hbias = theano.shared(value = numpy.zeros(n_hidden,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
69 dtype = theano.config.floatX), name='hbias')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
70
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
71 if vbias is None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
72 # create shared variable for visible units bias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
73 vbias = theano.shared(value =numpy.zeros(n_visible,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
74 dtype = theano.config.floatX),name='vbias')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
75
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
76 if numpy_rng is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
77 # create a number generator
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
78 numpy_rng = numpy.random.RandomState(1234)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
79
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
80 if theano_rng is None :
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
81 theano_rng = RandomStreams(numpy_rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
82
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
83
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
84 # initialize input layer for standalone RBM or layer0 of DBN
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
85 self.input = input if input else T.dmatrix('input')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
86
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
87 self.W = W
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
88 self.hbias = hbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
89 self.vbias = vbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
90 self.theano_rng = theano_rng
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
91 # **** WARNING: It is not a good idea to put things in this list
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
92 # other than shared variables created in this function.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
93 self.params = [self.W, self.hbias, self.vbias]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
94 self.batch_size = self.input.shape[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
95
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
96 def free_energy(self, v_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
97 ''' Function to compute the free energy '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
98 wx_b = T.dot(v_sample, self.W) + self.hbias
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
99 vbias_term = T.sum(T.dot(v_sample, self.vbias))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
100 hidden_term = T.sum(T.log(1+T.exp(wx_b)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
101 return -hidden_term - vbias_term
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
102
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
103 def sample_h_given_v(self, v0_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
104 ''' This function infers state of hidden units given visible units '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
105 # compute the activation of the hidden units given a sample of the visibles
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
106 h1_mean = T.nnet.sigmoid(T.dot(v0_sample, self.W) + self.hbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
107 # get a sample of the hiddens given their activation
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
108 h1_sample = self.theano_rng.binomial(size = h1_mean.shape, n = 1, prob = h1_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
109 return [h1_mean, h1_sample]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
110
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
111 def sample_v_given_h(self, h0_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
112 ''' This function infers state of visible units given hidden units '''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
113 # compute the activation of the visible given the hidden sample
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
114 v1_mean = T.nnet.sigmoid(T.dot(h0_sample, self.W.T) + self.vbias)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
115 # get a sample of the visible given their activation
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
116 v1_sample = self.theano_rng.binomial(size = v1_mean.shape,n = 1,prob = v1_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
117 return [v1_mean, v1_sample]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
118
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
119 def gibbs_hvh(self, h0_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
120 ''' This function implements one step of Gibbs sampling,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
121 starting from the hidden state'''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
122 v1_mean, v1_sample = self.sample_v_given_h(h0_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
123 h1_mean, h1_sample = self.sample_h_given_v(v1_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
124 return [v1_mean, v1_sample, h1_mean, h1_sample]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
125
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
126 def gibbs_vhv(self, v0_sample):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
127 ''' This function implements one step of Gibbs sampling,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
128 starting from the visible state'''
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
129 h1_mean, h1_sample = self.sample_h_given_v(v0_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
130 v1_mean, v1_sample = self.sample_v_given_h(h1_sample)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
131 return [h1_mean, h1_sample, v1_mean, v1_sample]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
132
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
133 def cd(self, lr = 0.1, persistent=None):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
134 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
135 This functions implements one step of CD-1 or PCD-1
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
136
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
137 :param lr: learning rate used to train the RBM
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
138 :param persistent: None for CD. For PCD, shared variable containing old state
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
139 of Gibbs chain. This must be a shared variable of size (batch size, number of
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
140 hidden units).
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
141
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
142 Returns the updates dictionary. The dictionary contains the update rules for weights
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
143 and biases but also an update of the shared variable used to store the persistent
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
144 chain, if one is used.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
145 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
146
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
147 # compute positive phase
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
148 ph_mean, ph_sample = self.sample_h_given_v(self.input)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
149
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
150 # decide how to initialize persistent chain:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
151 # for CD, we use the newly generate hidden sample
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
152 # for PCD, we initialize from the old state of the chain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
153 if persistent is None:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
154 chain_start = ph_sample
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
155 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
156 chain_start = persistent
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
157
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
158 # perform actual negative phase
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
159 [nv_mean, nv_sample, nh_mean, nh_sample] = self.gibbs_hvh(chain_start)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
160
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
161 # determine gradients on RBM parameters
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
162 g_vbias = T.sum( self.input - nv_mean, axis = 0)/self.batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
163 g_hbias = T.sum( ph_mean - nh_mean, axis = 0)/self.batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
164 g_W = T.dot(ph_mean.T, self.input )/ self.batch_size - \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
165 T.dot(nh_mean.T, nv_mean )/ self.batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
166
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
167 gparams = [g_W.T, g_hbias, g_vbias]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
168
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
169 # constructs the update dictionary
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
170 updates = {}
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
171 for gparam, param in zip(gparams, self.params):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
172 updates[param] = param + gparam * lr
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
173
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
174 if persistent:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
175 # Note that this works only if persistent is a shared variable
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
176 updates[persistent] = T.cast(nh_sample, dtype=theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
177 # pseudo-likelihood is a better proxy for PCD
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
178 cost = self.get_pseudo_likelihood_cost(updates)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
179 else:
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
180 # reconstruction cross-entropy is a better proxy for CD
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
181 cost = self.get_reconstruction_cost(updates, nv_mean)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
182
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
183 return cost, updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
184
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
185 def get_pseudo_likelihood_cost(self, updates):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
186 """Stochastic approximation to the pseudo-likelihood"""
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
187
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
188 # index of bit i in expression p(x_i | x_{\i})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
189 bit_i_idx = theano.shared(value=0, name = 'bit_i_idx')
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
190
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
191 # binarize the input image by rounding to nearest integer
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
192 xi = T.iround(self.input)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
193
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
194 # calculate free energy for the given bit configuration
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
195 fe_xi = self.free_energy(xi)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
196
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
197 # flip bit x_i of matrix xi and preserve all other bits x_{\i}
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
198 # Equivalent to xi[:,bit_i_idx] = 1-xi[:, bit_i_idx]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
199 # NB: slice(start,stop,step) is the python object used for
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
200 # slicing, e.g. to index matrix x as follows: x[start:stop:step]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
201 xi_flip = T.setsubtensor(xi, 1-xi[:, bit_i_idx],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
202 idx_list=(slice(None,None,None),bit_i_idx))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
203
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
204 # calculate free energy with bit flipped
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
205 fe_xi_flip = self.free_energy(xi_flip)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
206
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
207 # equivalent to e^(-FE(x_i)) / (e^(-FE(x_i)) + e^(-FE(x_{\i})))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
208 cost = self.n_visible * T.log(T.nnet.sigmoid(fe_xi_flip - fe_xi))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
209
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
210 # increment bit_i_idx % number as part of updates
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
211 updates[bit_i_idx] = (bit_i_idx + 1) % self.n_visible
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
212
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
213 return cost
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
214
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
215 def get_reconstruction_cost(self, updates, nv_mean):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
216 """Approximation to the reconstruction error"""
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
217
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
218 cross_entropy = T.mean(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
219 T.sum(self.input*T.log(nv_mean) +
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
220 (1 - self.input)*T.log(1-nv_mean), axis = 1))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
221
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
222 return cross_entropy
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
223
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
224
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
225
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
226 def test_rbm(learning_rate=0.1, training_epochs = 15,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
227 dataset='mnist.pkl.gz'):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
228 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
229 Demonstrate ***
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
230
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
231 This is demonstrated on MNIST.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
232
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
233 :param learning_rate: learning rate used for training the RBM
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
234
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
235 :param training_epochs: number of epochs used for training
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
236
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
237 :param dataset: path the the pickled dataset
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
238
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
239 """
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
240 datasets = load_data(dataset)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
241
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
242 train_set_x, train_set_y = datasets[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
243 test_set_x , test_set_y = datasets[2]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
244
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
245
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
246 batch_size = 20 # size of the minibatch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
247
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
248 # compute number of minibatches for training, validation and testing
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
249 n_train_batches = train_set_x.value.shape[0] / batch_size
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
250
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
251 # allocate symbolic variables for the data
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
252 index = T.lscalar() # index to a [mini]batch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
253 x = T.matrix('x') # the data is presented as rasterized images
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
254
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
255 rng = numpy.random.RandomState(123)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
256 theano_rng = RandomStreams( rng.randint(2**30))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
257
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
258 # initialize storage fot the persistent chain (state = hidden layer of chain)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
259 persistent_chain = theano.shared(numpy.zeros((batch_size, 500)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
260
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
261 # construct the RBM class
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
262 rbm = RBM( input = x, n_visible=28*28, \
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
263 n_hidden = 500,numpy_rng = rng, theano_rng = theano_rng)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
264
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
265 # get the cost and the gradient corresponding to one step of CD
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
266 cost, updates = rbm.cd(lr=learning_rate, persistent=persistent_chain)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
267
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
268
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
269 #################################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
270 # Training the RBM #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
271 #################################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
272 dirname = 'lr=%.5f'%learning_rate
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
273 os.makedirs(dirname)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
274 os.chdir(dirname)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
275
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
276 # it is ok for a theano function to have no output
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
277 # the purpose of train_rbm is solely to update the RBM parameters
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
278 train_rbm = theano.function([index], cost,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
279 updates = updates,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
280 givens = { x: train_set_x[index*batch_size:(index+1)*batch_size]})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
281
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
282 plotting_time = 0.
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
283 start_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
284
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
285
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
286 # go through training epochs
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
287 for epoch in xrange(training_epochs):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
288
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
289 # go through the training set
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
290 mean_cost = []
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
291 for batch_index in xrange(n_train_batches):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
292 mean_cost += [train_rbm(batch_index)]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
293
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
294 print 'Training epoch %d, cost is '%epoch, numpy.mean(mean_cost)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
295
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
296 # Plot filters after each training epoch
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
297 plotting_start = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
298 # Construct image from the weight matrix
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
299 image = PIL.Image.fromarray(tile_raster_images( X = rbm.W.value.T,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
300 img_shape = (28,28),tile_shape = (10,10),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
301 tile_spacing=(1,1)))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
302 image.save('filters_at_epoch_%i.png'%epoch)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
303 plotting_stop = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
304 plotting_time += (plotting_stop - plotting_start)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
305
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
306 end_time = time.clock()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
307
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
308 pretraining_time = (end_time - start_time) - plotting_time
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
309
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
310 print ('Training took %f minutes' %(pretraining_time/60.))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
311
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
312
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
313 #################################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
314 # Sampling from the RBM #
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
315 #################################
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
316
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
317 # find out the number of test samples
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
318 number_of_test_samples = test_set_x.value.shape[0]
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
319
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
320 # pick random test examples, with which to initialize the persistent chain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
321 test_idx = rng.randint(number_of_test_samples-20)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
322 persistent_vis_chain = theano.shared(test_set_x.value[test_idx:test_idx+20])
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
323
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
324 # define one step of Gibbs sampling (mf = mean-field)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
325 [hid_mf, hid_sample, vis_mf, vis_sample] = rbm.gibbs_vhv(persistent_vis_chain)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
326
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
327 # the sample at the end of the channel is returned by ``gibbs_1`` as
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
328 # its second output; note that this is computed as a binomial draw,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
329 # therefore it is formed of ints (0 and 1) and therefore needs to
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
330 # be converted to the same dtype as ``persistent_vis_chain``
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
331 vis_sample = T.cast(vis_sample, dtype=theano.config.floatX)
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
332
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
333 # construct the function that implements our persistent chain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
334 # we generate the "mean field" activations for plotting and the actual samples for
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
335 # reinitializing the state of our persistent chain
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
336 sample_fn = theano.function([], [vis_mf, vis_sample],
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
337 updates = { persistent_vis_chain:vis_sample})
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
338
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
339 # sample the RBM, plotting every `plot_every`-th sample; do this
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
340 # until you plot at least `n_samples`
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
341 n_samples = 10
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
342 plot_every = 1000
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
343
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
344 for idx in xrange(n_samples):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
345
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
346 # do `plot_every` intermediate samplings of which we do not care
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
347 for jdx in xrange(plot_every):
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
348 vis_mf, vis_sample = sample_fn()
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
349
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
350 # construct image
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
351 image = PIL.Image.fromarray(tile_raster_images(
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
352 X = vis_mf,
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
353 img_shape = (28,28),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
354 tile_shape = (10,10),
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
355 tile_spacing = (1,1) ) )
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
356 print ' ... plotting sample ', idx
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
357 image.save('sample_%i_step_%i.png'%(idx,idx*jdx))
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
358
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
359 if __name__ == '__main__':
4bc5eeec6394 Updating the tutorial code to the latest revisions.
Dumitru Erhan <dumitru.erhan@gmail.com>
parents:
diff changeset
360 test_rbm()