diff deep/convolutional_dae/scdae.py @ 276:727ed56fad12

Add reworked code for convolutional auto-encoder.
author Arnaud Bergeron <abergeron@gmail.com>
date Mon, 22 Mar 2010 13:33:29 -0400
parents
children 20ebc1f2a9fe
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/deep/convolutional_dae/scdae.py	Mon Mar 22 13:33:29 2010 -0400
@@ -0,0 +1,229 @@
+from pynnet import *
+# also pull in the utility hacks from pynnet.utils
+from pynnet.utils import *
+
+import numpy
+import theano
+import theano.tensor as T
+
+from itertools import izip
+
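+# One stage of the stacked network: a denoising convolutional autoencoder
+# followed by a max-pooling layer.  build() re-exposes the autoencoder's
+# cost so each stage can be pretrained on its own.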
+class cdae(LayerStack):
+    def __init__(self, filter_size, num_filt, num_in, subsampling, corruption,
+                 dtype, img_shape):
+        LayerStack.__init__(self, [ConvAutoencoder(filter_size=filter_size, 
+                                                   num_filt=num_filt,
+                                                   num_in=num_in,
+                                                   noisyness=corruption,
+                                                   dtype=dtype,
+                                                   image_shape=img_shape),
+                                   MaxPoolLayer(subsampling)])
+
+    def build(self, input):
+        LayerStack.build(self, input)
+        self.cost = self.layers[0].cost
+
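+# Output shape of a stage for an input of shape (channels, h, w): a valid
+# convolution shrinks each spatial dimension by filt_size-1, then the
+# subsampling divides it.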
+def cdae_out_size(in_size, filt_size, num_filt, num_in, subs):
+    out = [None] * 3
+    out[0] = num_filt
+    out[1] = (in_size[1]-filt_size[0]+1)/subs[0]
+    out[2] = (in_size[2]-filt_size[1]+1)/subs[1]
+    return out
+
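+# Chain cdae stages, threading each stage's computed output shape into the
+# next one.  Returns the stack and the final output shape.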
+def scdae(in_size, num_in, filter_sizes, num_filts,
+          subsamplings, corruptions, dtype):
+    layers = []
+    old_nfilt = 1
+    for fsize, nfilt, subs, corr in izip(filter_sizes, num_filts,
+                                         subsamplings, corruptions):
+        layers.append(cdae(fsize, nfilt, old_nfilt, subs, corr, dtype,
+                           (num_in, in_size[0], in_size[1], in_size[2])))
+        in_size = cdae_out_size(in_size, fsize, nfilt, old_nfilt, subs)
+        old_nfilt = nfilt
+    return LayerStack(layers), in_size
+
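+# Plain stack of tanh fully-connected layers, one per consecutive pair of
+# sizes in layer_sizes.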
+def mlp(layer_sizes, dtype):
+    layers = []
+    old_size = layer_sizes[0]
+    for size in layer_sizes[1:]:
+        layers.append(SimpleLayer(old_size, size, activation=nlins.tanh,
+                                  dtype=dtype))
+        old_size = size
+    return LayerStack(layers)
+
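+# Complete model: reshape the flat input into images, apply the
+# convolutional stack, flatten the result, then run the MLP and a sigmoid
+# output layer trained with the negative log-likelihood.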
+def scdae_net(in_size, num_in, filter_sizes, num_filts, subsamplings,
+              corruptions, layer_sizes, out_size, dtype, batch_size):
+    rl1 = ReshapeLayer((None,)+in_size)
+    ls, outs = scdae(in_size, num_in, filter_sizes, num_filts, subsamplings, 
+                     corruptions, dtype)
+    outs = numpy.prod(outs)
+    rl2 = ReshapeLayer((None, outs))
+    layer_sizes = [outs]+layer_sizes
+    ls2 = mlp(layer_sizes, dtype)
+    lrl = SimpleLayer(layer_sizes[-1], out_size, activation=nlins.sigmoid,
+                      dtype=dtype)
+    return NNet([rl1, ls, rl2, ls2, lrl], error=errors.nll)
+
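+# Compile every function twice: once with the network built for the nominal
+# batch size ('opt', presumably letting the conv ops specialize on a known
+# image shape) and once with the image shape cleared ('reg') to handle the
+# smaller leftover batch at the end of an epoch.  The select_f helpers
+# dispatch on the actual batch size at call time.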
+def build_funcs(batch_size, img_size, filter_sizes, num_filters, subs,
+                noise, mlp_sizes, out_size, dtype, pretrain_lr, train_lr):
+    
+    n = scdae_net((1,)+img_size, batch_size, filter_sizes, num_filters, subs,
+                  noise, mlp_sizes, out_size, dtype, batch_size)
+    x = T.fmatrix('x')
+    y = T.ivector('y')
+    
+    def pretrainfunc(net, alpha):
+        up = trainers.get_updates(net.params, net.cost, alpha)
+        return theano.function([x], net.cost, updates=up)
+
+    def trainfunc(net, alpha):
+        up = trainers.get_updates(net.params, net.cost, alpha)
+        return theano.function([x, y], net.cost, updates=up)
+
+    n.build(x, y)
+    pretrain_funcs_opt = [pretrainfunc(l, pretrain_lr)
+                          for l in n.layers[1].layers]
+    trainf_opt = trainfunc(n, train_lr)
+    evalf_opt = theano.function([x, y], errors.class_error(n.output, y))
+    
+    clear_imgshape(n)
+    n.build(x, y)
+    pretrain_funcs_reg = [pretrainfunc(l, pretrain_lr)
+                          for l in n.layers[1].layers]
+    trainf_reg = trainfunc(n, train_lr)
+    evalf_reg = theano.function([x, y], errors.class_error(n.output, y))
+    
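+    # Dispatch on the leading dimension: full-size batches go to the
+    # shape-specialized function, anything else to the generic one.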
+    def select_f(f1, f2, bsize):
+        def f(x):
+            if x.shape[0] == bsize:
+                return f1(x)
+            else:
+                return f2(x)
+        return f
+    
+    pretrain_funcs = [select_f(p_opt, p_reg, batch_size)
+                      for p_opt, p_reg in zip(pretrain_funcs_opt,
+                                              pretrain_funcs_reg)]
+    
+    def select_f2(f1, f2, bsize):
+        def f(x, y):
+            if x.shape[0] == bsize:
+                return f1(x, y)
+            else:
+                return f2(x, y)
+        return f
+
+    trainf = select_f2(trainf_opt, trainf_reg, batch_size)
+    evalf = select_f2(evalf_opt, evalf_reg, batch_size)
+    return pretrain_funcs, trainf, evalf
+
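+# Run each layer's pretraining function for the given number of epochs; as
+# wrapped by massage_funcs below, one call to f() is one pass over the
+# training set.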
+def do_pretrain(pretrain_funcs, pretrain_epochs):
+    for f in pretrain_funcs:
+        for i in xrange(pretrain_epochs):
+            f()
+
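+# Bind the compiled functions to the dataset: the pretraining closures do a
+# full epoch per call, train() consumes one minibatch per call, and the
+# eval closures return the example-weighted mean error over a whole split.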
+def massage_funcs(batch_size, dset, pretrain_funcs, trainf, evalf):
+    def pretrain_f(f):
+        def res():
+            for x, y in dset.train(batch_size):
+                print "pretrain:", f(x)
+        return res
+
+    pretrain_fs = map(pretrain_f, pretrain_funcs)
+
+    def train_f(f, dsetf):
+        def dset_it():
+            while True:
+                for x, y in dsetf(batch_size):
+                    yield f(x, y)
+        it = dset_it()
+        return lambda: it.next()
+
+    train = train_f(trainf, dset.train)
+
+    def eval_f(f, dsetf):
+        def res():
+            c = 0
+            i = 0
+            for x, y in dsetf(batch_size):
+                i += x.shape[0]
+                c += f(x, y)*x.shape[0]
+            return c/i
+        return res
+    
+    test = eval_f(evalf, dset.test)
+    valid = eval_f(evalf, dset.valid)
+
+    return pretrain_fs, train, valid, test
+
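+# Experiment entry point with a jobman-style (state, channel) signature;
+# all hyperparameters are read from state.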
+def run_exp(state, channel):
+    from ift6266 import datasets
+    from sgd_opt import sgd_opt
+
+    channel.save()
+
+    # params: bsize, pretrain_lr, train_lr, nfilts1, nfilts2, nfilts3,
+    #         nfilts4, noise, mlp_sz, pretrain_rounds
+
+    dset = datasets.nist_all()
+
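+    # Collect the per-layer filter counts; a count of 0 disables that layer
+    # and every layer after it.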
+    nfilts = []
+    for nf in [state.nfilts1, state.nfilts2, state.nfilts3, state.nfilts4]:
+        if nf == 0:
+            break
+        nfilts.append(nf)
+
+    fsizes = [(5,5)]*len(nfilts)
+    subs = [(2,2)]*len(nfilts)
+    noise = [state.noise]*len(nfilts)
+
+    pretrain_funcs, trainf, evalf = build_funcs(
+        img_size=(32, 32),
+        batch_size=state.bsize,
+        filter_sizes=fsizes,
+        num_filters=nfilts,
+        subs=subs,
+        noise=noise,
+        mlp_sizes=[state.mlp_sz],
+        out_size=62,
+        dtype=numpy.float32,
+        pretrain_lr=state.pretrain_lr,
+        train_lr=state.train_lr)
+
+    pretrain_fs, train, valid, test = massage_funcs(
+        state.bsize, dset, pretrain_funcs, trainf, evalf)
+
+    do_pretrain(pretrain_fs, state.pretrain_rounds)
+
+    sgd_opt(train, valid, test, training_epochs=100000, patience=10000,
+            patience_increase=2., improvement_threshold=0.995,
+            validation_frequency=2500)
+
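+# Small-scale MNIST run for quick manual testing.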
+if __name__ == '__main__':
+    from ift6266 import datasets
+    from sgd_opt import sgd_opt
+    import sys, time
+    
+    batch_size = 100
+    dset = datasets.mnist(200)
+
+    pretrain_funcs, trainf, evalf = build_funcs(
+        img_size=(28, 28),
+        batch_size=batch_size, filter_sizes=[(5,5), (5,5)],
+        num_filters=[4, 3], subs=[(2,2), (2,2)], noise=[0.2, 0.2],
+        mlp_sizes=[500], out_size=10, dtype=numpy.float32,
+        pretrain_lr=0.01, train_lr=0.1)
+
+    pretrain_fs, train, valid, test = massage_funcs(
+        batch_size, dset, pretrain_funcs, trainf, evalf)
+
+    print "pretraining ...",
+    sys.stdout.flush()
+    start = time.time()
+    do_pretrain(pretrain_fs, 0)
+    end = time.time()
+    print "done (in", end-start, "s)"
+    
+    sgd_opt(train, valid, test, training_epochs=1000, patience=1000,
+            patience_increase=2., improvement_threshold=0.995,
+            validation_frequency=500)