Mercurial > ift6266
comparison deep/crbm/mnist_crbm.py @ 337:8d116d4a7593
Added convolutional RBM code (à la Lee et al., 2009), imported from my working directory elsewhere. Seems to work for one layer. No subsampling yet.
author | fsavard |
---|---|
date | Fri, 16 Apr 2010 16:05:55 -0400 |
parents | |
children | ffbf0e41bcee |
comparison
equal
deleted
inserted
replaced
336:a79db7cee035 | 337:8d116d4a7593 |
---|---|
1 #!/usr/bin/python | |
2 | |
3 import sys | |
4 import os, os.path | |
5 | |
6 import numpy as N | |
7 | |
8 import theano | |
9 import theano.tensor as T | |
10 | |
11 from crbm import CRBM, ConvolutionParams | |
12 | |
13 from pylearn.datasets import MNIST | |
14 from pylearn.io.image_tiling import tile_raster_images | |
15 | |
16 import Image | |
17 | |
18 from pylearn.io.seriestables import * | |
19 import tables | |
20 | |
# Directory (relative path) into which the Gibbs-sample and filter images
# produced during training are written.
IMAGE_OUTPUT_DIR = 'img/'

# Passed as ``reduce_every`` to the accumulated 'cd' and 'sparsity' series;
# the parameter statistics are also appended once every this many minibatches.
REDUCE_EVERY = 100
24 | |
def filename_from_time(suffix):
    """Build a PNG file name from the current local time.

    The result is ``str(datetime.datetime.now())`` followed by ``suffix``
    and the ``.png`` extension.  The microsecond part of the timestamp is
    kept, so names generated in quick succession stay distinct.
    """
    from datetime import datetime as _datetime

    stamp = str(_datetime.now())
    return "".join((stamp, suffix, ".png"))
28 | |
def get_accumulator_series_array(
        hdf5_file, group_name, series_names,
        reduce_every,
        index_names=('epoch','minibatch'),
        stdout_too=True,
        skip_hdf5_append=False):
    """Shortcut for the common case of a few related error (float) series.

    Creates the group ``/<group_name>`` in ``hdf5_file`` and, for every name
    in ``series_names``, builds an ``ErrorSeries`` (table named after the
    series) wrapped in an ``AccumulatorSeriesWrapper`` configured with
    ``reduce_every``.  When ``stdout_too`` is true, each series additionally
    receives a ``StdoutAppendTarget``.

    Returns a ``SeriesArrayWrapper`` around all the wrapped series, in the
    same order as ``series_names``.

    NOTE(review): the exact reduction performed by the accumulator wrapper
    is defined in ``pylearn.io.seriestables`` -- confirm there.
    """
    hdf5_file.createGroup('/', group_name)

    # A single list instance is shared by every series, as before.
    extra_targets = [StdoutAppendTarget()] if stdout_too else []
    group_path = '/' + group_name

    def make_accumulator(name):
        base = ErrorSeries(error_name=name,
                           table_name=name,
                           hdf5_file=hdf5_file,
                           hdf5_group=group_path,
                           index_names=index_names,
                           other_targets=extra_targets,
                           skip_hdf5_append=skip_hdf5_append)
        return AccumulatorSeriesWrapper(base_series=base,
                                        reduce_every=reduce_every)

    return SeriesArrayWrapper(
        [make_accumulator(name) for name in series_names])
63 | |
class MnistCrbm(object):
    """Train a one-layer convolutional RBM on MNIST and log its progress.

    Construction loads the dataset, builds the ``CRBM`` with fixed
    hyper-parameters and opens the HDF5 series file.  ``train`` runs the
    CD and sparsity updates minibatch by minibatch, records the returned
    values in the series, and periodically writes sample and filter images
    to ``IMAGE_OUTPUT_DIR``.
    """

    def __init__(self):
        self.mnist = MNIST.full()#first_10k()

        self.cp = ConvolutionParams(
                    num_filters=40,
                    num_input_planes=1,
                    height_filters=12,
                    width_filters=12)

        self.image_size = (28,28)

        self.minibatch_size = 10

        self.lr = 0.01
        self.sparsity_lambda = 1.0
        # 40 filters * 0.05 = ~2 filters active for any given pixel
        # (i.e. about 2/num_filters rather than 1/num_filters).
        self.sparsity_p = 0.05

        self.crbm = CRBM(
                    minibatch_size=self.minibatch_size,
                    image_size=self.image_size,
                    conv_params=self.cp,
                    learning_rate=self.lr,
                    sparsity_lambda=self.sparsity_lambda,
                    sparsity_p=self.sparsity_p)

        self.num_epochs = 10

        self.init_series()

    def init_series(self):
        """Open ``series.h5`` in the working directory and create the series.

        Sets ``self.series`` to a dict with the entries ``'cd'``,
        ``'sparsity'`` and ``'params'``, and keeps the open file handle in
        ``self.h5f`` so that it can be flushed or closed later.
        """
        h5f = tables.openFile(os.path.join(os.getcwd(), "series.h5"), "w")
        # Previously the handle was a local and became unreachable once
        # this method returned.
        self.h5f = h5f

        series = {}

        series['cd'] = \
            get_accumulator_series_array(
                h5f, 'cd', self.crbm.cd_return_desc,
                REDUCE_EVERY,
                stdout_too=True)

        series['sparsity'] = \
            get_accumulator_series_array(
                h5f, 'sparsity', self.crbm.sparsity_return_desc,
                REDUCE_EVERY,
                stdout_too=True)

        # One statistics table per parameter array, named by position.
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
                            new_group_name="params",
                            base_group="/",
                            arrays_names=['W','b_h','b_x'],
                            hdf5_file=h5f,
                            index_names=('epoch','minibatch'),
                            other_targets=[params_stdout])

        self.series = series

    def _get_minibatch(self, mb_index):
        """Return minibatch number ``mb_index`` of the training inputs.

        Minibatches are consecutive, non-overlapping runs of
        ``minibatch_size`` rows of ``self.mnist.train.x``.  The result is
        reshaped to ``(minibatch_size, num_input_planes, height, width)``.
        """
        start = mb_index * self.minibatch_size
        stop = start + self.minibatch_size
        mb_x = self.mnist.train.x[start:stop]
        shape = (self.minibatch_size, self.cp.num_input_planes) \
                    + tuple(self.image_size)
        return mb_x.reshape(shape)

    def train(self):
        """Run ``num_epochs`` passes of CD and sparsity updates.

        Every minibatch appends the values returned by ``CD_step`` and
        ``sparsity_step`` to the corresponding series.  Parameter
        statistics are appended every ``REDUCE_EVERY`` minibatches and
        images are written every ``visualize_every`` minibatches
        (including the very first one).
        """
        # Floor division: a trailing partial minibatch is not used.
        num_minibatches = len(self.mnist.train.x) // self.minibatch_size

        visualize_every = 5000
        # NOTE(review): this was named ``gibbs_steps_from_random`` but the
        # chain is started from the current minibatch, not from noise --
        # pass None as start_x if a random start was intended.
        long_chain_steps = 1000

        try:
            for epoch in range(self.num_epochs):
                for mb_index in range(num_minibatches):
                    mb_x = self._get_minibatch(mb_index)

                    cd_return = self.crbm.CD_step(mb_x)
                    sp_return = self.crbm.sparsity_step(mb_x)

                    self.series['cd'].append(
                            (epoch, mb_index), cd_return)
                    self.series['sparsity'].append(
                            (epoch, mb_index), sp_return)

                    total_idx = epoch*num_minibatches + mb_index

                    if (total_idx+1) % REDUCE_EVERY == 0:
                        self.series['params'].append(
                            (epoch, mb_index), self.crbm.params)

                    if total_idx % visualize_every == 0:
                        self.visualize_gibbs_result(mb_x, long_chain_steps)
                        self.visualize_gibbs_result(mb_x, 1)
                        self.visualize_filters()
        finally:
            # Push whatever has been recorded so far to disk, even when
            # training is interrupted.  The file is left open so that the
            # series stay usable afterwards.
            h5f = getattr(self, 'h5f', None)
            if h5f is not None:
                h5f.flush()

    def _save_tile(self, tile, suffix):
        """Save ``tile`` as a timestamped PNG and return the file path.

        The output directory is created first if it does not exist yet.
        """
        if not os.path.isdir(IMAGE_OUTPUT_DIR):
            os.makedirs(IMAGE_OUTPUT_DIR)

        filepath = os.path.join(IMAGE_OUTPUT_DIR,
                        filename_from_time(suffix))
        Image.fromarray(tile).save(filepath)
        return filepath

    def visualize_gibbs_result(self, start_x, gibbs_steps):
        """Run ``minibatch_size`` Gibbs chains and save the samples as an image.

        The chains start from ``start_x`` when it is given; when it is
        None the CRBM's ``random_gibbs_samples`` is used instead.  The
        samples are tiled in a single row.
        """
        if start_x is not None:
            x_samples = self.crbm.gibbs_samples_from(start_x, gibbs_steps)
        else:
            x_samples = self.crbm.random_gibbs_samples(gibbs_steps)

        height, width = self.image_size
        x_samples = x_samples.reshape((self.minibatch_size, height*width))

        tile = tile_raster_images(x_samples, self.image_size,
                    (1, self.minibatch_size), output_pixel_vals=True)

        filepath = self._save_tile(tile, "gibbs")

        print("Result of running Gibbs %s times outputed to %s"
                    % (gibbs_steps, filepath))

    def visualize_filters(self):
        """Save the current filters as a tiled image and print ``b_h``.

        The tiling has one row per filter and one column per input plane.
        """
        cp = self.cp

        # filter size
        fsz = (cp.height_filters, cp.width_filters)
        tile_shape = (cp.num_filters, cp.num_input_planes)

        filters_flattened = self.crbm.W.value.reshape(
                                (tile_shape[0]*tile_shape[1],
                                fsz[0]*fsz[1]))

        tile = tile_raster_images(filters_flattened, fsz,
                                    tile_shape, output_pixel_vals=True)

        filepath = self._save_tile(tile, "filters")

        print("Filters (as images) outputed to %s" % (filepath,))
        print("b_h is %s" % (self.crbm.b_h.value,))
208 | |
209 | |
210 | |
211 | |
# Script entry point: build the experiment (loads MNIST, constructs the CRBM
# and opens the series file) and run the full training loop.
if __name__ == '__main__':
    trainer = MnistCrbm()
    trainer.train()
215 |