comparison deep/crbm/mnist_crbm.py @ 337:8d116d4a7593

Added convolutional RBM (ala Lee09) code, imported from my working dir elsewhere. Seems to work for one layer. No subsampling yet.
author fsavard
date Fri, 16 Apr 2010 16:05:55 -0400
parents
children ffbf0e41bcee
comparison
equal deleted inserted replaced
336:a79db7cee035 337:8d116d4a7593
1 #!/usr/bin/python
2
3 import sys
4 import os, os.path
5
6 import numpy as N
7
8 import theano
9 import theano.tensor as T
10
11 from crbm import CRBM, ConvolutionParams
12
13 from pylearn.datasets import MNIST
14 from pylearn.io.image_tiling import tile_raster_images
15
16 import Image
17
18 from pylearn.io.seriestables import *
19 import tables
20
# Relative directory that receives the PNG visualizations (Gibbs samples and
# filters) saved by MnistCrbm.
IMAGE_OUTPUT_DIR = "img/"

# Passed as ``reduce_every`` to the accumulator series; the parameter
# statistics are also appended once every REDUCE_EVERY minibatches.
REDUCE_EVERY = 100
24
def filename_from_time(suffix):
    """Build a PNG file name from the current local time and ``suffix``.

    The result is ``str(datetime.datetime.now())`` immediately followed by
    ``suffix`` and the ``.png`` extension.
    """
    from datetime import datetime as _datetime

    timestamp = str(_datetime.now())
    return "".join((timestamp, suffix, ".png"))
28
29 # Just a shortcut for a common case where we need a few
30 # related Error (float) series
def get_accumulator_series_array(hdf5_file, group_name, series_names,
                                 reduce_every,
                                 index_names=('epoch', 'minibatch'),
                                 stdout_too=True,
                                 skip_hdf5_append=False):
    """Create a group of accumulated error series in ``hdf5_file``.

    A new group ``group_name`` is created under the root of ``hdf5_file``.
    For each entry of ``series_names`` an ``ErrorSeries`` (table named after
    the entry, stored in that group) is built and wrapped in an
    ``AccumulatorSeriesWrapper`` configured with ``reduce_every``.

    When ``stdout_too`` is true, a single ``StdoutAppendTarget`` is shared by
    all the series as an extra target.  ``index_names`` and
    ``skip_hdf5_append`` are forwarded unchanged to every ``ErrorSeries``.

    Returns a ``SeriesArrayWrapper`` holding the wrapped series, in the same
    order as ``series_names``.
    """
    hdf5_file.createGroup('/', group_name)

    # One target list, shared by every series of the group.
    extra_targets = [StdoutAppendTarget()] if stdout_too else []

    def make_accumulator(name):
        base = ErrorSeries(error_name=name,
                           table_name=name,
                           hdf5_file=hdf5_file,
                           hdf5_group='/' + group_name,
                           index_names=index_names,
                           other_targets=extra_targets,
                           skip_hdf5_append=skip_hdf5_append)
        return AccumulatorSeriesWrapper(base_series=base,
                                        reduce_every=reduce_every)

    accumulators = [make_accumulator(name) for name in series_names]
    return SeriesArrayWrapper(accumulators)
63
class MnistCrbm(object):
    """Train a ``CRBM`` on MNIST while logging statistics and images.

    ``__init__`` builds the dataset, the model and the monitoring series
    (stored in ``series.h5`` in the current working directory).  ``train()``
    runs the training loop and periodically writes sample and filter images
    under ``IMAGE_OUTPUT_DIR``.
    """

    def __init__(self):
        self.mnist = MNIST.full()#first_10k()

        self.cp = ConvolutionParams(
            num_filters=40,
            num_input_planes=1,
            height_filters=12,
            width_filters=12)

        self.image_size = (28, 28)

        self.minibatch_size = 10

        self.lr = 0.01
        self.sparsity_lambda = 1.0
        # 40 filters * 0.05 = ~2 filters active for any given pixel.
        # (1/num_filters would be 0.025, i.e. about one active filter.)
        self.sparsity_p = 0.05

        self.crbm = CRBM(
            minibatch_size=self.minibatch_size,
            image_size=self.image_size,
            conv_params=self.cp,
            learning_rate=self.lr,
            sparsity_lambda=self.sparsity_lambda,
            sparsity_p=self.sparsity_p)

        self.num_epochs = 10

        self.init_series()

    def init_series(self):
        """Open ``series.h5`` (truncating it) and create the monitoring series.

        Sets ``self.series``, a dict with three entries:

        * ``'cd'``: accumulated series named after ``crbm.cd_return_desc``;
        * ``'sparsity'``: accumulated series named after
          ``crbm.sparsity_return_desc``;
        * ``'params'``: statistics series for the arrays ``W``, ``b_h`` and
          ``b_x``.

        The open file handle is kept in ``self.h5f``.
        """
        basedir = os.getcwd()
        h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")
        # Keep the handle so train() can flush it when it finishes.
        self.h5f = h5f

        series = {}

        series['cd'] = get_accumulator_series_array(
            h5f, 'cd', self.crbm.cd_return_desc,
            REDUCE_EVERY,
            stdout_too=True)

        series['sparsity'] = get_accumulator_series_array(
            h5f, 'sparsity', self.crbm.sparsity_return_desc,
            REDUCE_EVERY,
            stdout_too=True)

        # One statistics table set per parameter array, named after the
        # position of each param in the array list.
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
            new_group_name="params",
            base_group="/",
            arrays_names=['W', 'b_h', 'b_x'],
            hdf5_file=h5f,
            index_names=('epoch', 'minibatch'),
            other_targets=[params_stdout])

        self.series = series

    def _get_minibatch(self, mb_index):
        """Return training minibatch number ``mb_index``.

        Minibatches are consecutive, non-overlapping runs of
        ``self.minibatch_size`` rows of ``self.mnist.train.x``.  The result is
        reshaped to ``(minibatch_size, 1, height, width)`` with the height and
        width taken from ``self.image_size``.
        """
        start = mb_index * self.minibatch_size
        stop = start + self.minibatch_size
        mb_x = self.mnist.train.x[start:stop]
        # Single input plane, matching num_input_planes=1 in __init__.
        return mb_x.reshape((self.minibatch_size, 1) + tuple(self.image_size))

    def train(self):
        """Run ``self.num_epochs`` passes over the training set.

        Each minibatch goes through one ``CD_step`` and one ``sparsity_step``
        whose return values are appended to the ``'cd'`` and ``'sparsity'``
        series.  Parameter statistics are appended every ``REDUCE_EVERY``
        minibatches, and visualizations are written every 5000 minibatches
        (including the very first one).  The HDF5 file is flushed when the
        loop ends, even if it ends with an exception.
        """
        # Floor division: a trailing partial minibatch is dropped.
        num_minibatches = len(self.mnist.train.x) // self.minibatch_size

        visualize_every = 5000
        long_chain_steps = 1000

        try:
            for epoch in xrange(self.num_epochs):
                for mb_index in xrange(num_minibatches):
                    mb_x = self._get_minibatch(mb_index)

                    cd_return = self.crbm.CD_step(mb_x)
                    sp_return = self.crbm.sparsity_step(mb_x)

                    self.series['cd'].append(
                        (epoch, mb_index), cd_return)
                    self.series['sparsity'].append(
                        (epoch, mb_index), sp_return)

                    total_idx = epoch * num_minibatches + mb_index

                    if (total_idx + 1) % REDUCE_EVERY == 0:
                        self.series['params'].append(
                            (epoch, mb_index), self.crbm.params)

                    if total_idx % visualize_every == 0:
                        # Chains start from the current minibatch: one long
                        # run and one single-step run.
                        self.visualize_gibbs_result(mb_x, long_chain_steps)
                        self.visualize_gibbs_result(mb_x, 1)
                        self.visualize_filters()
        finally:
            self.h5f.flush()

    def _save_tile(self, tile, suffix):
        """Save ``tile`` as a timestamped PNG and return its path.

        ``IMAGE_OUTPUT_DIR`` is created first if it does not exist yet.
        """
        if not os.path.isdir(IMAGE_OUTPUT_DIR):
            os.makedirs(IMAGE_OUTPUT_DIR)

        filepath = os.path.join(IMAGE_OUTPUT_DIR,
                                filename_from_time(suffix))
        Image.fromarray(tile).save(filepath)
        return filepath

    def visualize_gibbs_result(self, start_x, gibbs_steps):
        """Run ``minibatch_size`` Gibbs chains and save the samples as a PNG.

        The chains start from ``start_x`` when it is given; when it is
        ``None`` the model's ``random_gibbs_samples`` is used instead.  The
        samples are tiled on a single row.
        """
        if start_x is not None:
            x_samples = self.crbm.gibbs_samples_from(start_x, gibbs_steps)
        else:
            x_samples = self.crbm.random_gibbs_samples(gibbs_steps)

        num_pixels = self.image_size[0] * self.image_size[1]
        x_samples = x_samples.reshape((self.minibatch_size, num_pixels))

        tile = tile_raster_images(x_samples, self.image_size,
                                  (1, self.minibatch_size),
                                  output_pixel_vals=True)

        filepath = self._save_tile(tile, "gibbs")

        print("Result of running Gibbs %s times outputed to %s"
              % (gibbs_steps, filepath))

    def visualize_filters(self):
        """Save the filters ``crbm.W`` as a tiled PNG and print ``b_h``.

        The tiling has one row per filter and one column per input plane.
        """
        cp = self.cp

        # filter size
        fsz = (cp.height_filters, cp.width_filters)
        tile_shape = (cp.num_filters, cp.num_input_planes)

        filters_flattened = self.crbm.W.value.reshape(
            (tile_shape[0] * tile_shape[1],
             fsz[0] * fsz[1]))

        tile = tile_raster_images(filters_flattened, fsz,
                                  tile_shape, output_pixel_vals=True)

        filepath = self._save_tile(tile, "filters")

        print("Filters (as images) outputed to %s" % (filepath,))
        print("b_h is %s" % (self.crbm.b_h.value,))
208
209
210
211
if __name__ == '__main__':
    # Script entry point: build the experiment, then run the training loop.
    experiment = MnistCrbm()
    experiment.train()
215