Mercurial > ift6266
view deep/crbm/mnist_crbm.py @ 347:9685e9d94cc4
base class for an rbm
author | goldfinger |
---|---|
date | Mon, 19 Apr 2010 08:16:56 -0400 |
parents | 8d116d4a7593 |
children | ffbf0e41bcee |
line wrap: on
line source
#!/usr/bin/python import sys import os, os.path import numpy as N import theano import theano.tensor as T from crbm import CRBM, ConvolutionParams from pylearn.datasets import MNIST from pylearn.io.image_tiling import tile_raster_images import Image from pylearn.io.seriestables import * import tables IMAGE_OUTPUT_DIR = 'img/' REDUCE_EVERY = 100 def filename_from_time(suffix): import datetime return str(datetime.datetime.now()) + suffix + ".png" # Just a shortcut for a common case where we need a few # related Error (float) series def get_accumulator_series_array( \ hdf5_file, group_name, series_names, reduce_every, index_names=('epoch','minibatch'), stdout_too=True, skip_hdf5_append=False): all_series = [] hdf5_file.createGroup('/', group_name) other_targets = [] if stdout_too: other_targets = [StdoutAppendTarget()] for sn in series_names: series_base = \ ErrorSeries(error_name=sn, table_name=sn, hdf5_file=hdf5_file, hdf5_group='/'+group_name, index_names=index_names, other_targets=other_targets, skip_hdf5_append=skip_hdf5_append) all_series.append( \ AccumulatorSeriesWrapper( \ base_series=series_base, reduce_every=reduce_every)) ret_wrapper = SeriesArrayWrapper(all_series) return ret_wrapper class MnistCrbm(object): def __init__(self): self.mnist = MNIST.full()#first_10k() self.cp = ConvolutionParams( \ num_filters=40, num_input_planes=1, height_filters=12, width_filters=12) self.image_size = (28,28) self.minibatch_size = 10 self.lr = 0.01 self.sparsity_lambda = 1.0 # about 1/num_filters, so only one filter active at a time # 40 * 0.05 = ~2 filters active for any given pixel self.sparsity_p = 0.05 self.crbm = CRBM( \ minibatch_size=self.minibatch_size, image_size=self.image_size, conv_params=self.cp, learning_rate=self.lr, sparsity_lambda=self.sparsity_lambda, sparsity_p=self.sparsity_p) self.num_epochs = 10 self.init_series() def init_series(self): series = {} basedir = os.getcwd() h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w") cd_series_names = self.crbm.cd_return_desc series['cd'] = \ get_accumulator_series_array( \ h5f, 'cd', cd_series_names, REDUCE_EVERY, stdout_too=True) sparsity_series_names = self.crbm.sparsity_return_desc series['sparsity'] = \ get_accumulator_series_array( \ h5f, 'sparsity', sparsity_series_names, REDUCE_EVERY, stdout_too=True) # so first we create the names for each table, based on # position of each param in the array params_stdout = StdoutAppendTarget("\n------\nParams") series['params'] = SharedParamsStatisticsWrapper( new_group_name="params", base_group="/", arrays_names=['W','b_h','b_x'], hdf5_file=h5f, index_names=('epoch','minibatch'), other_targets=[params_stdout]) self.series = series def train(self): num_minibatches = len(self.mnist.train.x) / self.minibatch_size print_every = 1000 visualize_every = 5000 gibbs_steps_from_random = 1000 for epoch in xrange(self.num_epochs): for mb_index in xrange(num_minibatches): mb_x = self.mnist.train.x \ [mb_index : mb_index+self.minibatch_size] mb_x = mb_x.reshape((self.minibatch_size, 1, 28, 28)) #E_h = crbm.E_h_given_x_func(mb_x) #print "Shape of E_h", E_h.shape cd_return = self.crbm.CD_step(mb_x) sp_return = self.crbm.sparsity_step(mb_x) self.series['cd'].append( \ (epoch, mb_index), cd_return) self.series['sparsity'].append( \ (epoch, mb_index), sp_return) total_idx = epoch*num_minibatches + mb_index if (total_idx+1) % REDUCE_EVERY == 0: self.series['params'].append( \ (epoch, mb_index), self.crbm.params) if total_idx % visualize_every == 0: self.visualize_gibbs_result(\ mb_x, gibbs_steps_from_random) self.visualize_gibbs_result(mb_x, 1) self.visualize_filters() def visualize_gibbs_result(self, start_x, gibbs_steps): # Run minibatch_size chains for gibbs_steps x_samples = None if not start_x is None: x_samples = self.crbm.gibbs_samples_from(start_x, gibbs_steps) else: x_samples = self.crbm.random_gibbs_samples(gibbs_steps) x_samples = x_samples.reshape((self.minibatch_size, 28*28)) tile = tile_raster_images(x_samples, self.image_size, (1, self.minibatch_size), output_pixel_vals=True) filepath = os.path.join(IMAGE_OUTPUT_DIR, filename_from_time("gibbs")) img = Image.fromarray(tile) img.save(filepath) print "Result of running Gibbs", \ gibbs_steps, "times outputed to", filepath def visualize_filters(self): cp = self.cp # filter size fsz = (cp.height_filters, cp.width_filters) tile_shape = (cp.num_filters, cp.num_input_planes) filters_flattened = self.crbm.W.value.reshape( (tile_shape[0]*tile_shape[1], fsz[0]*fsz[1])) tile = tile_raster_images(filters_flattened, fsz, tile_shape, output_pixel_vals=True) filepath = os.path.join(IMAGE_OUTPUT_DIR, filename_from_time("filters")) img = Image.fromarray(tile) img.save(filepath) print "Filters (as images) outputed to", filepath print "b_h is", self.crbm.b_h.value if __name__ == '__main__': mc = MnistCrbm() mc.train()