pylearn: changeset 823:e53c06901f8f
added daa example
| author | James Bergstra <bergstrj@iro.umontreal.ca> |
|---|---|
| date | Thu, 10 Sep 2009 10:30:35 -0400 |
| parents | a7dc8b28f4bc |
| children | bfc5914642ce |
| files | pylearn/examples/daa.py |
| diffstat | 1 files changed, 206 insertions(+), 0 deletions(-) |
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/examples/daa.py	Thu Sep 10 10:30:35 2009 -0400
@@ -0,0 +1,206 @@
+import math, sys, time, copy, cPickle, shutil, functools
+import numpy
+
+import theano
+from theano import tensor
+from theano.compile import module
+from theano.sandbox.softsign import softsign
+from theano import tensor as T, sparse as S
+from theano.tensor import nnet, as_tensor
+from theano.tensor.randomstreams import RandomStreams
+from theano.printing import Print
+from theano.compile.mode import Mode
+
+from ..algorithms import cost, minimizer, logistic_regression, sgd, stopper
+from ..io import filetensor
+from ..datasets import MNIST
+
+
+class AbstractFunction(Exception): """Override me"""
+
+#
+# default corruption function for BasicDAA
+# to extend
+#
+def corrupt_random_zeros(x, rng, p_zero):
+    mask = rng.binomial(T.shape(x), 1, 1.0 - p_zero)
+    return mask * x
+
+#
+# BasicDAA
+# ========
+#
+# Re-usable module, could be in pylearn.algorithms.daa
+#
+class BasicDAA(theano.Module):
+    def __init__(self, n_visible, n_hidden, seed,
+            w=None,
+            vbias=None,
+            hbias=None,
+            sigmoid=nnet.sigmoid,
+            corrupt=functools.partial(corrupt_random_zeros, p_zero=0.1),
+            reconstruction_cost=cost.cross_entropy,
+            w_scale=None
+            ):
+        """
+        :param w_scale: should be a floating-point value or None.  The weights are initialized
+        to a random value on the interval (-1, 1) and scaled by this value.
+        None means the default of 1/sqrt(max(n_visible, n_hidden)).
+
+
+        """
+        super(BasicDAA, self).__init__()
+
+        self.n_visible = n_visible
+        self.n_hidden = n_hidden
+        self.random = RandomStreams()
+        self.seed = seed
+        self.w_scale = w_scale if w_scale is not None else 1.0 / math.sqrt(max(n_visible, n_hidden))
+        self.sigmoid = sigmoid
+        self.corrupt = corrupt
+        self.reconstruction_cost = reconstruction_cost
+
+        self.w = tensor.dmatrix() if w is None else w
+        self.vbias = tensor.dvector() if vbias is None else vbias
+        self.hbias = tensor.dvector() if hbias is None else hbias
+
+        self.params = [self.w, self.vbias, self.hbias]
+
+    def _instance_initialize(self, obj):
+        #consider offering override parameters for seed, n_visible, n_hidden, w_scale
+        super(BasicDAA, self)._instance_initialize(obj)
+        rng = numpy.random.RandomState(self.seed)
+        s = rng.randint(2**30)
+        obj.w = (rng.rand(self.n_visible, self.n_hidden) * 2.0 - 1.0) * self.w_scale
+        obj.vbias = numpy.zeros(self.n_visible)
+        obj.hbias = numpy.zeros(self.n_hidden)
+        obj.random.initialize(int(s))
+
+    def hidden_act(self, visible):
+        return theano.dot(visible, self.w) + self.hbias
+
+    def hidden(self, visible):
+        return self.sigmoid(self.hidden_act(visible))
+
+    def visible_act(self, hidden):
+        return theano.dot(hidden, self.w.T) + self.vbias
+
+    def visible(self, hidden):
+        return self.sigmoid(self.visible_act(hidden))
+
+    def reconstruction(self, visible):
+        return self.visible(self.hidden(self.corrupt(visible, self.random)))
+
+    def daa_cost(self, visible):
+        return self.reconstruction_cost(visible, self.reconstruction(visible))
+
+    def l2_cost(self):
+        return self.w.norm(2)
+
+
+
+#
+# StackedDAA
+# ==========
+#
+# Hacky experiment-type code, which would *not* be in pylearn.
+#
+# This Module can be extended / parametrized so many ways that I think this code is best cut &
+# pasted.
+#
+class StackedDAA(theano.Module):
+    def __init__(self, layer_widths, seed, finetune_lr=1e-3, pretrain_lr=1e-4):
+        super(StackedDAA, self).__init__()
+        input = theano.tensor.dmatrix()
+
+        #the parameters of this function, required for the minimizer
+        self.params = []
+        daa_widths = layer_widths[:-1]
+
+        #create the stack of DAA modules, and the self.params list
+        self.daa_list = []
+        daa_input = input
+        for i in xrange(len(daa_widths) - 1):
+            self.daa_list.append(BasicDAA(daa_widths[i], daa_widths[i+1], seed + i))
+            daa_input = self.daa_list[-1].hidden(daa_input)
+            self.params.extend(self.daa_list[-1].params)
+
+        #put a logistic regression module on top for classification
+        self.classif = logistic_regression.LogRegN(input=daa_input,
+                n_in=layer_widths[-2], n_out=layer_widths[-1])
+        self.params.extend(self.classif.params)
+
+        #set up the fine-tuning function (minimizer.step)
+        FineTuneMinimizer = sgd.sgd_minimizer(stepsize=finetune_lr)
+        self.finetune = FineTuneMinimizer([input, self.classif.target], self.classif.unregularized_cost, self.params)
+
+        #set up the pre-training function
+        pretrain_cost, pretrain_input, pretrain_params = reduce(
+                lambda (c, i, p), daa: (c + daa.daa_cost(i), daa.hidden(i), p + daa.params),
+                self.daa_list,
+                (0.0, input, []))
+        PreTrainMinimizer = sgd.sgd_minimizer(stepsize=pretrain_lr)
+        self.pretrain = PreTrainMinimizer([input], pretrain_cost, pretrain_params)
+
+    def _instance_initialize(self, obj):
+        #consider offering override parameters for seed, n_visible, n_hidden, w_scale
+        super(StackedDAA, self)._instance_initialize(obj)
+
+        #ugh... why do i need to do this???
+        for daa in obj.daa_list:
+            daa.initialize()
+        obj.classif.initialize()
+        obj.finetune.initialize()
+        obj.pretrain.initialize()
+
+#
+# DRIVER
+# ======
+#
+# This learning algorithm is the sort of thing that we've put in 'experiment' functions, that
+# can be run using dbdict.
+#
+def demo_random(layer_widths=[3, 4, 5]):
+    sdaa = StackedDAA(layer_widths, seed=666).make(mode='FAST_COMPILE')
+
+    # create some training data
+    rng = numpy.random.RandomState(7832)
+    input_data = rng.randn(10, 3)
+    targ_data = rng.binomial(1, .5, size=10)
+
+    print 'Pre-training ...'
+    for i in xrange(5):
+        print sdaa.pretrain.step_cost(input_data)
+
+    print 'Fine-tuning ...'
+    for i in xrange(5):
+        print sdaa.finetune.step_cost(input_data, targ_data)
+
+
+def demo_mnist(layer_widths=[784, 500, 500]):
+    sdaa = StackedDAA(layer_widths, seed=666).make()
+
+    mnist = MNIST.full()
+    batchsize = 10
+    n_pretrain_batches = 10000
+    n_finetune_batches = 10000
+
+    t0 = time.time()
+    print 'Pre-training ...'
+    for i in xrange(n_pretrain_batches):
+        ii = (i * batchsize) % len(mnist.train.x)
+        x = mnist.train.x[ii:ii + batchsize]
+        c = sdaa.pretrain.step_cost(x)
+        if not i % 100:
+            print i, n_pretrain_batches, time.time() - t0, c
+
+    t1 = time.time()
+    print 'Fine-tuning ...'
+    for i in xrange(n_finetune_batches):
+        ii = (i * batchsize) % len(mnist.train.x)
+        x = mnist.train.x[ii:ii + batchsize]
+        y = mnist.train.y[ii:ii + batchsize]
+        c = sdaa.finetune.step_cost(x, y)
+        if not i % 100:
+            print i, n_finetune_batches, time.time() - t1, c
+
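
For readers unfamiliar with the old Theano Module API used in this changeset, the following is a rough plain-numpy sketch (not part of the changeset) of the forward computation that BasicDAA builds symbolically: zero out a random fraction of the input, encode with sigmoid(x.W + hbias), decode with the tied weights W.T and vbias, and score the reconstruction against the clean input with a cross-entropy cost. All names and values below are illustrative; the real module delegates the cost to pylearn's cost.cross_entropy and is trained through the sgd minimizer.

    import numpy

    def sigmoid(a):
        return 1.0 / (1.0 + numpy.exp(-a))

    def daa_forward(x, w, vbias, hbias, rng, p_zero=0.1):
        # corruption: randomly zero a fraction p_zero of the input entries
        corrupted = x * rng.binomial(1, 1.0 - p_zero, size=x.shape)
        hidden = sigmoid(numpy.dot(corrupted, w) + hbias)    # encoder
        recon = sigmoid(numpy.dot(hidden, w.T) + vbias)      # tied-weight decoder
        # elementwise cross-entropy between the clean input and its reconstruction
        xent = -numpy.mean(x * numpy.log(recon) + (1.0 - x) * numpy.log(1.0 - recon))
        return recon, xent

    rng = numpy.random.RandomState(0)
    n_visible, n_hidden = 784, 500
    x = rng.rand(10, n_visible)   # a batch of inputs in (0, 1), e.g. MNIST-like pixels
    # weights initialized as in BasicDAA: uniform on (-1, 1), scaled by 1/sqrt(max(n_visible, n_hidden))
    w = (rng.rand(n_visible, n_hidden) * 2.0 - 1.0) / numpy.sqrt(max(n_visible, n_hidden))
    recon, xent = daa_forward(x, w, numpy.zeros(n_visible), numpy.zeros(n_hidden), rng)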