view deep/crbm/mnist_crbm.py @ 347:9685e9d94cc4

base class for an rbm
author goldfinger
date Mon, 19 Apr 2010 08:16:56 -0400
parents 8d116d4a7593
children ffbf0e41bcee
line wrap: on
line source

#!/usr/bin/python

import sys
import os, os.path

import numpy as N

import theano
import theano.tensor as T

from crbm import CRBM, ConvolutionParams

from pylearn.datasets import MNIST
from pylearn.io.image_tiling import tile_raster_images

import Image

from pylearn.io.seriestables import *
import tables

# Training steps between reductions of the accumulator series (passed as
# reduce_every) and between appends of the parameter statistics in train().
REDUCE_EVERY = 100

# Directory, relative to the current working directory, that receives the
# Gibbs-sample and filter PNG images.
IMAGE_OUTPUT_DIR = 'img/'

def filename_from_time(suffix):
    """Return a PNG file name made of the current local time plus `suffix`.

    The time part is str(datetime.datetime.now()), e.g.
    '2010-04-19 08:16:56.123456', so the result looks like
    '2010-04-19 08:16:56.123456<suffix>.png'.
    """
    import datetime
    stamp = str(datetime.datetime.now())
    return "".join([stamp, suffix, ".png"])

def get_accumulator_series_array(
                hdf5_file, group_name, series_names,
                reduce_every,
                index_names=('epoch','minibatch'),
                stdout_too=True,
                skip_hdf5_append=False):
    """Create one HDF5 group holding several related Error (float) series.

    Shortcut for a common case.  A group named ``group_name`` is created
    under the root of ``hdf5_file``.  For each name in ``series_names``, an
    ErrorSeries is built whose error name and table name are both that
    name, stored in that group and indexed by ``index_names``.  Each one is
    wrapped in an AccumulatorSeriesWrapper constructed with
    ``reduce_every``.

    When ``stdout_too`` is true, a single StdoutAppendTarget is created and
    shared as an extra target by every series.  ``skip_hdf5_append`` is
    forwarded unchanged to each ErrorSeries.

    Returns a SeriesArrayWrapper over the wrappers, in the same order as
    ``series_names``.
    """
    hdf5_file.createGroup('/', group_name)

    # Every series receives the same list object (possibly empty).
    extra_targets = [StdoutAppendTarget()] if stdout_too else []

    def make_accumulator(name):
        # Base ErrorSeries stored as table `name` inside /<group_name>,
        # wrapped so that appended values are accumulated.
        base = ErrorSeries(error_name=name,
                           table_name=name,
                           hdf5_file=hdf5_file,
                           hdf5_group='/' + group_name,
                           index_names=index_names,
                           other_targets=extra_targets,
                           skip_hdf5_append=skip_hdf5_append)
        return AccumulatorSeriesWrapper(base_series=base,
                                        reduce_every=reduce_every)

    wrappers = [make_accumulator(name) for name in series_names]
    return SeriesArrayWrapper(wrappers)

class MnistCrbm(object):
    """Train a convolutional RBM (CRBM) on MNIST, logging and visualizing it.

    The constructor loads the full MNIST dataset and builds a CRBM with 40
    12x12 filters over single-plane 28x28 images.  It also opens
    ``series.h5`` in the current working directory for the training series.
    ``train()`` then runs ``num_epochs`` passes over the training set.
    """

    def __init__(self):
        self.mnist = MNIST.full()  # MNIST.first_10k() for a quicker run

        self.cp = ConvolutionParams(
                    num_filters=40,
                    num_input_planes=1,
                    height_filters=12,
                    width_filters=12)

        self.image_size = (28, 28)

        self.minibatch_size = 10

        self.lr = 0.01
        self.sparsity_lambda = 1.0
        # Sparsity target.  1/num_filters (= 0.025) would mean roughly one
        # active filter at a time.  0.05 allows about 40 * 0.05 = 2 filters
        # active for any given pixel.
        self.sparsity_p = 0.05

        self.crbm = CRBM(
                    minibatch_size=self.minibatch_size,
                    image_size=self.image_size,
                    conv_params=self.cp,
                    learning_rate=self.lr,
                    sparsity_lambda=self.sparsity_lambda,
                    sparsity_p=self.sparsity_p)

        self.num_epochs = 10

        self.init_series()

    def init_series(self):
        """Open ``series.h5`` in the cwd and create the logging series.

        Fills ``self.series`` with three entries:
          'cd'       -- accumulator series named after crbm.cd_return_desc
          'sparsity' -- accumulator series named after
                        crbm.sparsity_return_desc
          'params'   -- a SharedParamsStatisticsWrapper over W, b_h and b_x

        The file is opened in "w" mode, so an existing series.h5 is
        replaced.  The handle is kept as ``self.h5f`` so that it can be
        flushed or closed later.
        """
        series = {}

        basedir = os.getcwd()

        h5f = tables.openFile(os.path.join(basedir, "series.h5"), "w")
        # Keep a reference: previously the handle was dropped here, so
        # nothing could flush or close it.
        self.h5f = h5f

        cd_series_names = self.crbm.cd_return_desc
        series['cd'] = get_accumulator_series_array(
                h5f, 'cd', cd_series_names,
                REDUCE_EVERY,
                stdout_too=True)

        sparsity_series_names = self.crbm.sparsity_return_desc
        series['sparsity'] = get_accumulator_series_array(
                h5f, 'sparsity', sparsity_series_names,
                REDUCE_EVERY,
                stdout_too=True)

        # Parameter series live in the "/params" group, with one set of
        # tables per shared array named in arrays_names.  They are also
        # echoed to stdout under a "Params" header.
        params_stdout = StdoutAppendTarget("\n------\nParams")
        series['params'] = SharedParamsStatisticsWrapper(
                            new_group_name="params",
                            base_group="/",
                            arrays_names=['W', 'b_h', 'b_x'],
                            hdf5_file=h5f,
                            index_names=('epoch', 'minibatch'),
                            other_targets=[params_stdout])

        self.series = series

    def _get_minibatch(self, mb_index):
        """Return training minibatch number ``mb_index``, shaped for the CRBM.

        Minibatches are consecutive, non-overlapping blocks of
        ``minibatch_size`` rows of ``self.mnist.train.x``.  Each block is
        reshaped to (minibatch_size, num_input_planes, height, width).
        """
        # Scale the minibatch index to a row offset.  The old code sliced
        # from mb_index itself, so consecutive minibatches overlapped and
        # only the first ~num_minibatches rows were ever used.
        start = mb_index * self.minibatch_size
        mb_x = self.mnist.train.x[start:start + self.minibatch_size]
        return mb_x.reshape(
            (self.minibatch_size, self.cp.num_input_planes)
            + tuple(self.image_size))

    def train(self):
        """Run ``num_epochs`` epochs of CD and sparsity updates on MNIST.

        At every step, the CD and sparsity return values are appended to
        their series.  Parameter statistics are appended whenever the global
        step count (1-based) is a multiple of REDUCE_EVERY.  Every
        ``visualize_every`` steps, Gibbs samples and filters are written as
        images.  A trailing partial minibatch is skipped.
        """
        # Floor division: only full minibatches are used.
        num_minibatches = len(self.mnist.train.x) // self.minibatch_size

        visualize_every = 5000
        # NOTE(review): despite the name, the long chain below starts from
        # the current minibatch, not from random.  Pass None as start_x to
        # visualize_gibbs_result to use crbm.random_gibbs_samples instead.
        gibbs_steps_from_random = 1000

        for epoch in xrange(self.num_epochs):
            for mb_index in xrange(num_minibatches):
                mb_x = self._get_minibatch(mb_index)

                cd_return = self.crbm.CD_step(mb_x)
                sp_return = self.crbm.sparsity_step(mb_x)

                self.series['cd'].append((epoch, mb_index), cd_return)
                self.series['sparsity'].append((epoch, mb_index), sp_return)

                total_idx = epoch * num_minibatches + mb_index

                if (total_idx + 1) % REDUCE_EVERY == 0:
                    self.series['params'].append(
                        (epoch, mb_index), self.crbm.params)

                if total_idx % visualize_every == 0:
                    self.visualize_gibbs_result(
                        mb_x, gibbs_steps_from_random)
                    self.visualize_gibbs_result(mb_x, 1)
                    self.visualize_filters()

    def _image_path(self, suffix):
        """Return a timestamped PNG path in IMAGE_OUTPUT_DIR.

        The directory is created first if it does not exist, so that saving
        the image does not fail on a fresh checkout.
        """
        if not os.path.isdir(IMAGE_OUTPUT_DIR):
            os.makedirs(IMAGE_OUTPUT_DIR)
        return os.path.join(IMAGE_OUTPUT_DIR, filename_from_time(suffix))

    def visualize_gibbs_result(self, start_x, gibbs_steps):
        """Run ``minibatch_size`` Gibbs chains and save the samples as a PNG.

        If ``start_x`` is given, the chains start from it through
        crbm.gibbs_samples_from.  Otherwise crbm.random_gibbs_samples is
        used.  The samples are tiled in a single row.
        """
        if start_x is not None:
            x_samples = self.crbm.gibbs_samples_from(start_x, gibbs_steps)
        else:
            x_samples = self.crbm.random_gibbs_samples(gibbs_steps)
        num_pixels = self.image_size[0] * self.image_size[1]
        x_samples = x_samples.reshape((self.minibatch_size, num_pixels))

        tile = tile_raster_images(x_samples, self.image_size,
                    (1, self.minibatch_size), output_pixel_vals=True)

        filepath = self._image_path("gibbs")
        img = Image.fromarray(tile)
        img.save(filepath)

        print("Result of running Gibbs %s times outputed to %s"
              % (gibbs_steps, filepath))

    def visualize_filters(self):
        """Save the CRBM filters W as one tiled PNG and print b_h.

        W is laid out as a grid with num_filters rows and num_input_planes
        columns.  Each tile is height_filters x width_filters.
        """
        cp = self.cp

        # filter size
        fsz = (cp.height_filters, cp.width_filters)
        tile_shape = (cp.num_filters, cp.num_input_planes)

        filters_flattened = self.crbm.W.value.reshape(
                                (tile_shape[0] * tile_shape[1],
                                 fsz[0] * fsz[1]))

        tile = tile_raster_images(filters_flattened, fsz,
                                  tile_shape, output_pixel_vals=True)

        filepath = self._image_path("filters")
        img = Image.fromarray(tile)
        img.save(filepath)

        print("Filters (as images) outputed to %s" % (filepath,))
        print("b_h is %s" % (self.crbm.b_h.value,))




if __name__ == '__main__':
    # Script entry point: build the experiment (loads MNIST, builds the
    # CRBM, opens series.h5) and run the full training loop.
    experiment = MnistCrbm()
    experiment.train()