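"""Train a binary RBM on MNIST, logging training statistics to HDF5.

Contrastive-divergence statistics and parameter statistics are recorded with
pylearn's seriestables, and samples drawn from the trained model are saved
as tiled PNG images.
"""
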
import sys
import os, os.path
import time

import numpy

import theano
import theano.tensor as T
from theano.tensor.shared_randomstreams import RandomStreams

from crbm import CRBM, ConvolutionParams
# The script instantiates a plain RBM and calls load_data below; we assume
# the RBM class and load_data helper from the Deep Learning Tutorials modules.
from rbm import RBM
from logistic_sgd import load_data

from pylearn.datasets import MNIST
from pylearn.io.image_tiling import tile_raster_images

import PIL.Image

from pylearn.io.seriestables import *
import tables

IMAGE_OUTPUT_DIR = 'img/'

REDUCE_EVERY = 100

def filename_from_time(suffix):
    import datetime
    return str(datetime.datetime.now()) + suffix + ".png"

# Just a shortcut for a common case where we need a few
# related Error (float) series
def get_accumulator_series_array(
        hdf5_file, group_name, series_names,
        reduce_every,
        index_names=('epoch', 'minibatch'),
        stdout_too=True,
        skip_hdf5_append=False):
    all_series = []

    hdf5_file.createGroup('/', group_name)

    other_targets = []
    if stdout_too:
        other_targets = [StdoutAppendTarget()]

    for sn in series_names:
        series_base = \
            ErrorSeries(error_name=sn,
                        table_name=sn,
                        hdf5_file=hdf5_file,
                        hdf5_group='/' + group_name,
                        index_names=index_names,
                        other_targets=other_targets,
                        skip_hdf5_append=skip_hdf5_append)

        all_series.append(
            AccumulatorSeriesWrapper(
                base_series=series_base,
                reduce_every=reduce_every))

    ret_wrapper = SeriesArrayWrapper(all_series)

    return ret_wrapper

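# A minimal usage sketch (hypothetical names; assumes the seriestables
# convention that a wrapped series is fed with append(index, values)):
#
#   h5f = tables.openFile("series.h5", "w")
#   errs = get_accumulator_series_array(h5f, 'recon', ['recon_error'],
#                                       reduce_every=REDUCE_EVERY)
#   errs.append((epoch, minibatch), [0.25])
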
class ExperienceRbm(object):
    def __init__(self):
        self.mnist = MNIST.full()#first_10k()

        # load_data and the dataset path follow the Deep Learning Tutorials
        # convention (an assumption; `dataset` was otherwise undefined)
        dataset = 'mnist.pkl.gz'
        datasets = load_data(dataset)

        self.train_set_x, train_set_y = datasets[0]
        self.test_set_x, test_set_y = datasets[2]

        self.batch_size = 100    # size of the minibatch

        # compute number of minibatches for training, validation and testing
        self.n_train_batches = self.train_set_x.value.shape[0] / self.batch_size

        # allocate symbolic variables for the data
        self.index = T.lscalar()   # index to a [mini]batch
        self.x = T.matrix('x')     # the data is presented as rasterized images

        self.rng = numpy.random.RandomState(123)
        theano_rng = RandomStreams(self.rng.randint(2 ** 30))

        # construct the RBM class
        self.rbm = RBM(input=self.x, n_visible=28 * 28, n_hidden=500,
                       numpy_rng=self.rng, theano_rng=theano_rng)

        self.init_series()

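    # init_series() sets up the HDF5 layout used during training: a '/cd'
    # group with one table per value returned by the RBM's CD step, and a
    # '/params' group holding statistics over the arrays W, b_h and b_x.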
    def init_series(self):
        series = {}

        basedir = os.getcwd()

        h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")

        cd_series_names = self.rbm.cd_return_desc
        series['cd'] = \
            get_accumulator_series_array(
                h5f, 'cd', cd_series_names,
                REDUCE_EVERY,
                stdout_too=True)

        # create the name of each table based on the position
        # of each param in the array
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
            new_group_name="params",
            base_group="/",
            arrays_names=['W', 'b_h', 'b_x'],
            hdf5_file=h5f,
            index_names=('epoch', 'minibatch'),
            other_targets=[params_stdout])

        self.series = series

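    # Plain CD-k restarts the negative-phase Gibbs chain from the training
    # data at every minibatch; persistent CD (PCD) keeps the chain state in a
    # shared variable across updates, which is what persistent=True selects
    # below.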
    def train(self, persistent, learning_rate):

        training_epochs = 15

        # get the cost and the gradient corresponding to one step of CD
        if persistent:
            # initialize storage for the persistent chain
            # (state = hidden layer of chain)
            persistent_chain = theano.shared(
                numpy.zeros((self.batch_size, self.rbm.n_hidden)))
            cost, updates = self.rbm.cd(lr=learning_rate,
                                        persistent=persistent_chain)
        else:
            cost, updates = self.rbm.cd(lr=learning_rate)

        dirname = 'lr=%.5f' % learning_rate
        os.makedirs(dirname)
        os.chdir(dirname)

        # the purpose of train_rbm is solely to update the RBM parameters
        train_rbm = theano.function([self.index], cost,
                updates=updates,
                givens={self.x:
                        self.train_set_x[self.index * self.batch_size:
                                         (self.index + 1) * self.batch_size]})

        plotting_time = 0.
        start_time = time.clock()

        # go through training epochs
        for epoch in xrange(training_epochs):

            # go through the training set
            mean_cost = []
            for batch_index in xrange(self.n_train_batches):
                mean_cost += [train_rbm(batch_index)]

        end_time = time.clock()
        pretraining_time = (end_time - start_time)

    def sample_from_rbm(self, gibbs_steps, test_set_x):

        # find out the number of test samples
        number_of_test_samples = test_set_x.value.shape[0]

        # pick random test examples, with which to initialize the persistent chain
        test_idx = self.rng.randint(number_of_test_samples - 20)
        persistent_vis_chain = theano.shared(
            test_set_x.value[test_idx:test_idx + 20])

        # define one step of Gibbs sampling (mf = mean-field)
        [hid_mf, hid_sample, vis_mf, vis_sample] = \
            self.rbm.gibbs_vhv(persistent_vis_chain)

        # the sample at the end of the chain is computed as a binomial draw,
        # so it is made of ints (0 and 1) and needs to be cast to the same
        # dtype as ``persistent_vis_chain``
        vis_sample = T.cast(vis_sample, dtype=theano.config.floatX)

        # construct the function that implements our persistent chain;
        # we generate the "mean field" activations for plotting and the
        # actual samples for reinitializing the state of our persistent chain
        sample_fn = theano.function([], [vis_mf, vis_sample],
                updates={persistent_vis_chain: vis_sample})

        # sample the RBM, plotting every `plot_every`-th sample; do this
        # until we have plotted at least `n_samples`
        n_samples = 10
        plot_every = 1000

        for idx in xrange(n_samples):

            # do `plot_every` intermediate samplings that we do not care about
            for jdx in xrange(plot_every):
                vis_mf, vis_sample = sample_fn()

            # construct image
            image = PIL.Image.fromarray(tile_raster_images(
                X=vis_mf,
                img_shape=(28, 28),
                tile_shape=(10, 10),
                tile_spacing=(1, 1)))

            image.save('sample_%i_step_%i.png' % (idx, (idx + 1) * plot_every))

if __name__ == '__main__':
    mc = ExperienceRbm()
    # illustrative hyperparameters; train() requires both arguments
    mc.train(persistent=True, learning_rate=0.1)
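
# Sampling after training (a sketch; relies on the test set stored by
# __init__):
#
#   mc.sample_from_rbm(gibbs_steps=1000, test_set_x=mc.test_set_x)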