changeset 823:e53c06901f8f

added daa example
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 10 Sep 2009 10:30:35 -0400
parents a7dc8b28f4bc
children bfc5914642ce
files pylearn/examples/daa.py
diffstat 1 files changed, 206 insertions(+), 0 deletions(-)
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/pylearn/examples/daa.py	Thu Sep 10 10:30:35 2009 -0400
@@ -0,0 +1,206 @@
+import math, sys, time, copy, cPickle, shutil, functools
+import numpy 
+
+import theano
+from theano import tensor
+from theano.compile import module
+from theano.sandbox.softsign import softsign
+from theano import tensor as T, sparse as S
+from theano.tensor import nnet, as_tensor
+from theano.tensor.randomstreams import RandomStreams
+from theano.printing import Print
+from theano.compile.mode import Mode
+
+from ..algorithms import cost, minimizer, logistic_regression, sgd, stopper
+from ..io import filetensor
+from ..datasets import MNIST
+
+
+class AbstractFunction(Exception): """Override me"""
+
+#
+# Default corruption function for BasicDAA: randomly zero out entries.
+# An alternative corruption, showing how to extend or replace it, is sketched below.
+#
+def corrupt_random_zeros(x, rng, p_zero):
+    mask = rng.binomial(T.shape(x), 1, 1.0 - p_zero)
+    return mask * x
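+
+# Illustrative sketch: a corruption that adds Gaussian noise instead of zeroing
+# entries.  It assumes RandomStreams.normal takes (size, avg, std) arguments,
+# analogous to the binomial call above; the noise level here is arbitrary.
+def corrupt_gaussian_noise(x, rng, std=0.1):
+    return x + rng.normal(T.shape(x), 0.0, std)
+# e.g. BasicDAA(n_visible, n_hidden, seed,
+#               corrupt=functools.partial(corrupt_gaussian_noise, std=0.25))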
+
+#
+# BasicDAA
+# ========
+#
+# Re-usable module, could be in pylearn.algorithms.daa
+#
+class BasicDAA(theano.Module):
+    def __init__(self, n_visible, n_hidden, seed,
+            w=None, 
+            vbias=None, 
+            hbias=None,
+            sigmoid=nnet.sigmoid,
+            corrupt=functools.partial(corrupt_random_zeros, p_zero=0.1),
+            reconstruction_cost=cost.cross_entropy,
+            w_scale = None
+            ):
+        """
+        :param w_scale: a floating-point value or None.  The weights are initialized
+            to random values on the interval (-1, 1) and scaled by this value.
+            None means the default of 1/sqrt(max(n_visible, n_hidden)); e.g. with
+            n_visible=784 and n_hidden=500 the default scale is 1/sqrt(784).
+        """
+        super(BasicDAA, self).__init__()
+
+        self.n_visible = n_visible
+        self.n_hidden = n_hidden
+        self.random = RandomStreams()
+        self.seed = seed
+        self.w_scale = w_scale if w_scale is not None else 1.0 / math.sqrt(max(n_visible,n_hidden))
+        self.sigmoid = sigmoid
+        self.corrupt = corrupt
+        self.reconstruction_cost = reconstruction_cost
+
+        self.w = tensor.dmatrix() if w is None else w
+        self.vbias = tensor.dvector() if vbias is None else vbias
+        self.hbias = tensor.dvector() if hbias is None else hbias
+
+        self.params = [self.w, self.vbias, self.hbias]
+
+    def _instance_initialize(self, obj):
+        #consider offering override parameters for seed, n_visible, n_hidden, w_scale
+        super(BasicDAA, self)._instance_initialize(obj)
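+        # weights start uniform on (-w_scale, w_scale) and biases at zero;
+        # the module's RandomStreams is seeded from the same numpy RandomState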
+        rng = numpy.random.RandomState(self.seed)
+        s = rng.randint(2**30)
+        obj.w = (rng.rand(self.n_visible, self.n_hidden) * 2.0 - 1.0) * self.w_scale
+        obj.vbias = numpy.zeros(self.n_visible)
+        obj.hbias = numpy.zeros(self.n_hidden)
+        obj.random.initialize(int(s))
+
+    def hidden_act(self, visible):
+        return theano.dot(visible, self.w) + self.hbias
+
+    def hidden(self, visible):
+        return self.sigmoid(self.hidden_act(visible))
+
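+    # decoding reuses the transpose of the encoder weights (tied weights);
+    # only the visible bias is decoder-specific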
+    def visible_act(self, hidden):
+        return theano.dot(hidden, self.w.T) + self.vbias
+
+    def visible(self, hidden):
+        return self.sigmoid(self.visible_act(hidden))
+
+    def reconstruction(self, visible):
+        return self.visible(self.hidden(self.corrupt(visible, self.random)))
+
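+    # denoising autoencoder objective: corrupt the input, encode, decode, and
+    # score the reconstruction against the *uncorrupted* input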
+    def daa_cost(self, visible):
+        return self.reconstruction_cost(visible, self.reconstruction(visible))
+
+    def l2_cost(self):
+        return self.w.norm(2)
+
+
+
+#
+# StackedDAA
+# ==========
+#
+# Hacky experiment-type code, which would *not* go in pylearn.
+#
+# This Module can be extended / parametrized in so many ways that the code is
+# best cut & pasted rather than generalized.
+#
+class StackedDAA(theano.Module):
+    def __init__(self, layer_widths, seed, finetune_lr=1e-3, pretrain_lr=1e-4):
+        super(StackedDAA, self).__init__()
+        input = theano.tensor.dmatrix()
+
+        #the parameters of this function, required for the minimizer
+        self.params = [] 
+        daa_widths = layer_widths[:-1]
+
+        #create the stack of DAA modules, and the self.params list
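+        #each DAA's symbolic input is the previous DAA's hidden output, so the
+        #layers are chained within a single computation graph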
+        self.daa_list = []
+        daa_input = input
+        for i in xrange(len(daa_widths)-1):
+            self.daa_list.append(BasicDAA(daa_widths[i], daa_widths[i+1], seed+i))
+            daa_input = self.daa_list[-1].hidden(daa_input)
+            self.params.extend(self.daa_list[-1].params)
+
+        #put a logistic regression module on top for classification
+        self.classif = logistic_regression.LogRegN(input=daa_input,
+                n_in=layer_widths[-2], n_out = layer_widths[-1])
+        self.params.extend(self.classif.params)
+
+        #set up the fine-tuning function (minimizer.step)
+        FineTuneMinimizer = sgd.sgd_minimizer(stepsize=finetune_lr)
+        self.finetune = FineTuneMinimizer([input, self.classif.target], self.classif.unregularized_cost, self.params)
+
+        #set up the pre-training function
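+        #the reduce below walks the DAA stack, accumulating each layer's
+        #reconstruction cost on its own (symbolic) input and collecting its
+        #parameters, so pre-training minimizes the sum of all layer costs at once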
+        pretrain_cost, pretrain_input, pretrain_params = reduce(
+                lambda (c,i,p), daa: (c + daa.daa_cost(i), daa.hidden(i), p + daa.params),
+                self.daa_list,
+                (0.0, input, []))
+        PreTrainMinimizer = sgd.sgd_minimizer(stepsize=pretrain_lr)
+        self.pretrain = PreTrainMinimizer([input], pretrain_cost, pretrain_params)
+        
+    def _instance_initialize(self, obj):
+        #consider offering override parameters for seed, n_visible, n_hidden, w_scale
+        super(StackedDAA, self)._instance_initialize(obj)
+
+        #sub-module and minimizer instances are not initialized automatically, so do it explicitly
+        for daa in obj.daa_list:
+            daa.initialize()
+        obj.classif.initialize()
+        obj.finetune.initialize()
+        obj.pretrain.initialize()
+
+#
+# DRIVER
+# ======
+#
+# This learning algorithm is the sort of thing that we've put in 'experiment' functions
+# that can be run using dbdict.
+#
+def demo_random(layer_widths=[3,4,5]):
+    sdaa = StackedDAA(layer_widths, seed=666).make(mode='FAST_COMPILE')
+
+    # create some training data
+    rng = numpy.random.RandomState(7832)
+    input_data = rng.randn(10,3)
+    targ_data = rng.binomial(1,.5, size=10)
+
+    print 'Pre-training ...'
+    for i in xrange(5):
+        print sdaa.pretrain.step_cost(input_data)
+
+    print 'Fine-tuning ...'
+    for i in xrange(5):
+        print sdaa.finetune.step_cost(input_data, targ_data)
+
+
+def demo_mnist(layer_widths=[784,500,500]):
+    sdaa = StackedDAA(layer_widths, seed=666).make()
+
+    mnist = MNIST.full()
+    batchsize=10
+    n_pretrain_batches=10000
+    n_finetune_batches=10000
+
+    t0 = time.time()
+    print 'Pre-training ...'
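+    # each step draws the next contiguous minibatch, wrapping around the training set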
+    for i in xrange(n_pretrain_batches):
+        ii = (i*batchsize) % len(mnist.train.x)
+        x = mnist.train.x[ii:ii+batchsize]
+        c = sdaa.pretrain.step_cost(x)
+        if not i % 100:
+            print i, n_pretrain_batches, time.time() - t0, c
+
+    t1 = time.time()
+    print 'Fine-tuning ...'
+    for i in xrange(n_finetune_batches):
+        ii = (i*batchsize) % len(mnist.train.x)
+        x = mnist.train.x[ii:ii+batchsize]
+        y = mnist.train.y[ii:ii+batchsize]
+        c = sdaa.finetune.step_cost(x, y)
+        if not i % 100:
+            print i, n_finetune_batches, time.time() - t1, c
+
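+
+# Minimal entry point, a sketch: demo_random needs no external data, while demo_mnist
+# assumes the MNIST dataset is available via pylearn.datasets.MNIST.full().
+if __name__ == '__main__':
+    demo_random()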