changeset 476:8fcd0f3d9a17

added a few algorithms
author Olivier Breuleux <breuleuo@iro.umontreal.ca>
date Mon, 27 Oct 2008 17:26:00 -0400
parents 11e0357f06f4
children 1babf35fcef5 b15dad843c8c
files algorithms/__init__.py algorithms/aa.py algorithms/daa.py algorithms/logistic_regression.py algorithms/regressor.py algorithms/stacker.py algorithms/tests/test_aa.py algorithms/tests/test_regressor.py algorithms/tests/test_stacker.py
diffstat 9 files changed, 578 insertions(+), 1 deletions(-)
--- a/algorithms/__init__.py	Thu Oct 23 18:06:21 2008 -0400
+++ b/algorithms/__init__.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,5 @@
+
+from regressor import Regressor, BinRegressor
+from aa import AutoEncoder, SigmoidXEAutoEncoder
+from daa import DenoisingAA, SigmoidXEDenoisingAA
+from stacker import Stacker
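With these re-exports in place, client code can pull every model from the package root rather than from the individual submodules; for instance (a sketch, assuming the package is importable as algorithms on sys.path):

    from algorithms import SigmoidXEAutoEncoder, DenoisingAA, BinRegressor, Stacker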
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/aa.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,108 @@
+
+import theano
+from theano import tensor as T
+from theano.tensor import nnet as NN
+import numpy as N
+
+class AutoEncoder(theano.FancyModule):
+
+    def __init__(self, input = None, regularize = True, tie_weights = True):
+        super(AutoEncoder, self).__init__()
+
+        # MODEL CONFIGURATION
+        self.regularize = regularize
+        self.tie_weights = tie_weights
+
+        # ACQUIRE/MAKE INPUT
+        if input is None:
+            input = T.matrix('input')
+        self.input = theano.External(input)
+
+        # HYPER-PARAMETERS
+        self.lr = theano.Member(T.scalar())
+
+        # PARAMETERS
+        self.w1 = theano.Member(T.matrix())
+        if not tie_weights:
+            self.w2 = theano.Member(T.matrix())
+        else:
+            self.w2 = self.w1.T
+        self.b1 = theano.Member(T.vector())
+        self.b2 = theano.Member(T.vector())
+
+        # HIDDEN LAYER
+        self.hidden_activation = T.dot(input, self.w1) + self.b1
+        self.hidden = self.build_hidden()
+
+        # RECONSTRUCTION LAYER
+        self.output_activation = T.dot(self.hidden, self.w2) + self.b2
+        self.output = self.build_output()
+
+        # RECONSTRUCTION COST
+        self.reconstruction_cost = self.build_reconstruction_cost()
+
+        # REGULARIZATION COST
+        self.regularization = self.build_regularization()
+
+        # TOTAL COST
+        self.cost = self.reconstruction_cost
+        if self.regularize:
+            self.cost = self.cost + self.regularization
+
+        # GRADIENTS AND UPDATES
+        if self.tie_weights:
+            self.params = self.w1, self.b1, self.b2
+        else:
+            self.params = self.w1, self.w2, self.b1, self.b2
+        gradients = T.grad(self.cost, self.params)
+        updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients))
+
+        # INTERFACE METHODS
+        self.update = theano.Method(input, self.cost, updates)
+        self.reconstruction = theano.Method(input, self.output)
+        self.representation = theano.Method(input, self.hidden)
+
+    def _instance_initialize(self, obj, input_size = None, hidden_size = None, seed = None, **init):
+        if (input_size is None) ^ (hidden_size is None):
+            raise ValueError("Must specify input_size and hidden_size or neither.")
+        super(AutoEncoder, self)._instance_initialize(obj, **init)
+        if seed is not None:
+            R = N.random.RandomState(seed)
+        else:
+            R = N.random
+        if input_size is not None:
+            sz = (input_size, hidden_size)
+            scale = 1/N.sqrt(input_size)
+            obj.w1 = R.uniform(size = sz, low = -scale, high = scale)
+            if not self.tie_weights:
+                obj.w2 = R.uniform(size = list(reversed(sz)), low = -scale, high = scale)
+            obj.b1 = N.zeros(hidden_size)
+            obj.b2 = N.zeros(input_size)
+
+    def build_regularization(self):
+        return T.zero() # no regularization!
+
+
+class SigmoidXEAutoEncoder(AutoEncoder):
+
+    def build_hidden(self):
+        return NN.sigmoid(self.hidden_activation)
+
+    def build_output(self):
+        return NN.sigmoid(self.output_activation)
+
+    def build_reconstruction_cost(self):
+        self.reconstruction_cost_matrix = self.input * T.log(self.output) + (1.0 - self.input) * T.log(1.0 - self.output)
+        self.reconstruction_costs = -T.sum(self.reconstruction_cost_matrix, axis=1)
+        return T.sum(self.reconstruction_costs)
+
+    def build_regularization(self):
+        self.l2_coef = theano.Member(T.scalar())
+        if self.tie_weights:
+            return self.l2_coef * T.sum(self.w1 * self.w1)
+        else:
+            return self.l2_coef * (T.sum(self.w1 * self.w1) + T.sum(self.w2 * self.w2))
+
+    def _instance_initialize(self, obj, input_size = None, hidden_size = None, **init):
+        init.setdefault('l2_coef', 0)
+        super(SigmoidXEAutoEncoder, self)._instance_initialize(obj, input_size, hidden_size, **init)
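For readers who want the arithmetic that the module graph encodes without the Module machinery, here is a minimal NumPy sketch of one tied-weight SigmoidXEAutoEncoder update; every name in it is hypothetical, and the hand-derived gradients stand in for what T.grad computes symbolically:

    import numpy as np

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    rng = np.random.RandomState(0)
    n_in, n_hid, lr = 4, 3, 0.1
    scale = 1.0 / np.sqrt(n_in)
    w1 = rng.uniform(-scale, scale, size=(n_in, n_hid))   # w2 is w1.T (tied)
    b1, b2 = np.zeros(n_hid), np.zeros(n_in)

    x = rng.binomial(1, 0.5, size=(2, n_in)).astype(float)
    h = sigmoid(x.dot(w1) + b1)                   # hidden representation
    xr = sigmoid(h.dot(w1.T) + b2)                # reconstruction
    cost = -np.sum(x * np.log(xr) + (1 - x) * np.log(1 - xr))

    # For a sigmoid output with cross-entropy, d(cost)/d(activation) = xr - x.
    d_out = xr - x
    d_hid = d_out.dot(w1) * h * (1 - h)
    w1 -= lr * (x.T.dot(d_hid) + d_out.T.dot(h))  # both uses of the tied weight
    b1 -= lr * d_hid.sum(axis=0)
    b2 -= lr * d_out.sum(axis=0)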
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/daa.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,147 @@
+
+import theano
+from theano import tensor as T
+from theano.tensor import nnet as NN
+import numpy as N
+
+class DenoisingAA(T.RModule):
+
+    def __init__(self, input = None, regularize = True, tie_weights = True):
+        super(DenoisingAA, self).__init__()
+
+        # MODEL CONFIGURATION
+        self.regularize = regularize
+        self.tie_weights = tie_weights
+
+        # ACQUIRE/MAKE INPUT
+        if input is None:
+            input = T.matrix('input')
+        self.input = theano.External(input)
+
+        # HYPER-PARAMETERS
+        self.lr = theano.Member(T.scalar())
+
+        # PARAMETERS
+        self.w1 = theano.Member(T.matrix())
+        if not tie_weights:
+            self.w2 = theano.Member(T.matrix())
+        else:
+            self.w2 = self.w1.T
+        self.b1 = theano.Member(T.vector())
+        self.b2 = theano.Member(T.vector())
+
+
+        # REGULARIZATION COST
+        self.regularization = self.build_regularization()
+
+
+        ### NOISELESS ###
+
+        # HIDDEN LAYER
+        self.hidden_activation = T.dot(self.input, self.w1) + self.b1
+        self.hidden = self.hid_activation_function(self.hidden_activation)
+
+        # RECONSTRUCTION LAYER
+        self.output_activation = T.dot(self.hidden, self.w2) + self.b2
+        self.output = self.out_activation_function(self.output_activation)
+
+        # RECONSTRUCTION COST
+        self.reconstruction_costs = self.build_reconstruction_costs(self.output)
+        self.reconstruction_cost = T.mean(self.reconstruction_costs)
+
+        # TOTAL COST
+        self.cost = self.reconstruction_cost
+        if self.regularize:
+            self.cost = self.cost + self.regularization
+
+
+        ### WITH NOISE ###
+        self.corrupted_input = self.build_corrupted_input()
+
+        # HIDDEN LAYER
+        self.nhidden_activation = T.dot(self.corrupted_input, self.w1) + self.b1
+        self.nhidden = self.hid_activation_function(self.nhidden_activation)
+
+        # RECONSTRUCTION LAYER
+        self.noutput_activation = T.dot(self.nhidden, self.w2) + self.b2
+        self.noutput = self.out_activation_function(self.noutput_activation)
+
+        # RECONSTRUCTION COST
+        self.nreconstruction_costs = self.build_reconstruction_costs(self.noutput)
+        self.nreconstruction_cost = T.mean(self.nreconstruction_costs)
+
+        # TOTAL COST
+        self.ncost = self.nreconstruction_cost
+        if self.regularize:
+            self.ncost = self.ncost + self.regularization
+
+
+        # GRADIENTS AND UPDATES
+        if self.tie_weights:
+            self.params = self.w1, self.b1, self.b2
+        else:
+            self.params = self.w1, self.w2, self.b1, self.b2
+        gradients = T.grad(self.ncost, self.params)
+        updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients))
+
+        # INTERFACE METHODS
+        self.update = theano.Method(self.input, self.ncost, updates)
+        self.compute_cost = theano.Method(self.input, self.cost)
+        self.noisify = theano.Method(self.input, self.corrupted_input)
+        self.reconstruction = theano.Method(self.input, self.output)
+        self.representation = theano.Method(self.input, self.hidden)
+        self.reconstruction_through_noise = theano.Method(self.input, [self.corrupted_input, self.noutput])
+
+    def _instance_initialize(self, obj, input_size = None, hidden_size = None, seed = None, **init):
+        if (input_size is None) ^ (hidden_size is None):
+            raise ValueError("Must specify input_size and hidden_size or neither.")
+        super(DenoisingAA, self)._instance_initialize(obj, **init)
+        if seed is not None:
+            R = N.random.RandomState(seed)
+        else:
+            R = N.random
+        if input_size is not None:
+            sz = (input_size, hidden_size)
+            inf = 1/N.sqrt(input_size)
+            hif = 1/N.sqrt(hidden_size)
+            obj.w1 = R.uniform(size = sz, low = -inf, high = inf)
+            if not self.tie_weights:
+                obj.w2 = R.uniform(size = list(reversed(sz)), low = -hif, high = hif)
+            obj.b1 = N.zeros(hidden_size)
+            obj.b2 = N.zeros(input_size)
+        if seed is not None:
+            obj.seed(seed)
+        obj.__hide__ = ['params']
+
+    def build_regularization(self):
+        return T.zero() # no regularization!
+
+
+class SigmoidXEDenoisingAA(DenoisingAA):
+
+    def build_corrupted_input(self):
+        self.noise_level = theano.Member(T.scalar())
+        return self.random.binomial(T.shape(self.input), 1, 1 - self.noise_level) * self.input
+
+    def hid_activation_function(self, activation):
+        return NN.sigmoid(activation)
+
+    def out_activation_function(self, activation):
+        return NN.sigmoid(activation)
+
+    def build_reconstruction_costs(self, output):
+        reconstruction_cost_matrix = -(self.input * T.log(output) + (1 - self.input) * T.log(1 - output))
+        return T.sum(reconstruction_cost_matrix, axis=1)
+
+    def build_regularization(self):
+        self.l2_coef = theano.Member(T.scalar())
+        if self.tie_weights:
+            return self.l2_coef * T.sum(self.w1 * self.w1)
+        else:
+            return self.l2_coef * (T.sum(self.w1 * self.w1) + T.sum(self.w2 * self.w2))
+
+    def _instance_initialize(self, obj, input_size = None, hidden_size = None, seed = None, **init):
+        init.setdefault('noise_level', 0)
+        init.setdefault('l2_coef', 0)
+        super(SigmoidXEDenoisingAA, self)._instance_initialize(obj, input_size, hidden_size, seed, **init)
+
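The denoising variant differs from the plain autoencoder only in what feeds the hidden layer during training: a binomial mask zeroes each input with probability noise_level, and the cross-entropy target stays the clean input. A hypothetical NumPy sketch of the corruption path follows (the noiseless cost runs the same weights on the uncorrupted input, for monitoring):

    import numpy as np

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    rng = np.random.RandomState(0)
    noise_level = 0.3
    x = rng.binomial(1, 0.5, size=(2, 4)).astype(float)

    # corrupted_input: keep each entry with probability 1 - noise_level
    mask = rng.binomial(1, 1 - noise_level, size=x.shape)
    x_tilde = mask * x

    w1 = rng.uniform(-0.5, 0.5, size=(4, 3))
    b1, b2 = np.zeros(3), np.zeros(4)
    nh = sigmoid(x_tilde.dot(w1) + b1)            # hidden layer sees the noise
    nxr = sigmoid(nh.dot(w1.T) + b2)
    # mean per-example cross-entropy against the *clean* x, as in ncost
    ncost = np.mean(-np.sum(x * np.log(nxr) + (1 - x) * np.log(1 - nxr), axis=1))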
--- a/algorithms/logistic_regression.py	Thu Oct 23 18:06:21 2008 -0400
+++ b/algorithms/logistic_regression.py	Mon Oct 27 17:26:00 2008 -0400
@@ -10,7 +10,7 @@
 
 class Module_Nclass(module.FancyModule):
     class InstanceType(module.FancyModuleInstance):
-        def initialize(self, n_in, n_out):
+        def initialize(self, n_in, n_out, rng=N.random):
             #self.component is the LogisticRegressionTemplate instance that built this guy.
 
             self.w = N.zeros((n_in, n_out))
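This hunk only threads an rng through initialize; the zero initialization shown in the context does not use it yet. The point of injecting the rng is that a caller can make any randomized initialization reproducible while N.random preserves the old default. A hypothetical sketch of the pattern (init_weights is illustrative, not code from this changeset):

    import numpy as N

    def init_weights(n_in, n_out, rng=N.random):
        s = 1.0 / N.sqrt(n_in)
        return rng.uniform(-s, s, size=(n_in, n_out))

    w_fixed = init_weights(3, 2, rng=N.random.RandomState(42))  # reproducible
    w_other = init_weights(3, 2)                                 # old default behaviour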
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/regressor.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,103 @@
+
+import theano
+from theano import tensor as T
+from theano.tensor import nnet as NN
+import numpy as N
+
+class Regressor(theano.FancyModule):
+
+    def __init__(self, input = None, target = None, regularize = True):
+        super(Regressor, self).__init__()
+
+        # MODEL CONFIGURATION
+        self.regularize = regularize
+
+        # ACQUIRE/MAKE INPUT AND TARGET
+        self.input = theano.External(input) if input is not None else T.matrix('input')
+        self.target = theano.External(target) if target is not None else T.matrix('target')
+
+        # HYPER-PARAMETERS
+        self.lr = theano.Member(T.scalar())
+
+        # PARAMETERS
+        self.w = theano.Member(T.matrix())
+        self.b = theano.Member(T.vector())
+
+        # OUTPUT
+        self.output_activation = T.dot(self.input, self.w) + self.b
+        self.output = self.build_output()
+
+        # REGRESSION COST
+        self.regression_cost = self.build_regression_cost()
+
+        # REGULARIZATION COST
+        self.regularization = self.build_regularization()
+
+        # TOTAL COST
+        self.cost = self.regression_cost
+        if self.regularize:
+            self.cost = self.cost + self.regularization
+
+        # GRADIENTS AND UPDATES
+        self.params = self.w, self.b
+        gradients = T.grad(self.cost, self.params)
+        updates = dict((p, p - self.lr * g) for p, g in zip(self.params, gradients))
+
+        # INTERFACE METHODS
+        self.update = theano.Method([self.input, self.target], self.cost, updates)
+        self.predict = theano.Method(self.input, self.output)
+
+        self.build_extensions()
+
+    def _instance_initialize(self, obj, input_size = None, output_size = None, seed = None, **init):
+        if seed is not None:
+            R = N.random.RandomState(seed)
+        else:
+            R = N.random
+        if (input_size is None) ^ (output_size is None):
+            raise ValueError("Must specify input_size and output_size or neither.")
+        super(Regressor, self)._instance_initialize(obj, **init)
+        if input_size is not None:
+            sz = (input_size, output_size)
+            scale = 1/N.sqrt(input_size)
+            obj.w = R.uniform(size = sz, low = -scale, high = scale)
+            obj.b = N.zeros(output_size)
+        obj.__hide__ = ['params']
+
+    def _instance_flops_approx(self, obj):
+        return obj.w.size
+
+    def build_extensions(self):
+        pass
+
+    def build_output(self):
+        raise NotImplementedError('override in subclass')
+
+    def build_regression_cost(self):
+        raise NotImplementedError('override in subclass')
+
+    def build_regularization(self):
+        return T.zero() # no regularization!
+
+
+class BinRegressor(Regressor):
+
+    def build_extensions(self):
+        self.classes = T.iround(self.output)
+        self.classify = theano.Method(self.input, self.classes)
+
+    def build_output(self):
+        return NN.sigmoid(self.output_activation)
+
+    def build_regression_cost(self):
+        self.regression_cost_matrix = self.target * T.log(self.output) + (1.0 - self.target) * T.log(1.0 - self.output)
+        self.regression_costs = -T.sum(self.regression_cost_matrix, axis=1)
+        return T.mean(self.regression_costs)
+
+    def build_regularization(self):
+        self.l2_coef = theano.Member(T.scalar())
+        return self.l2_coef * T.sum(self.w * self.w)
+
+    def _instance_initialize(self, obj, input_size = None, output_size = 1, seed = None, **init):
+        init.setdefault('l2_coef', 0)
+        super(BinRegressor, self)._instance_initialize(obj, input_size, output_size, seed, **init)
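BinRegressor is ordinary logistic regression: a sigmoid output, mean cross-entropy, and classify rounding the output to {0, 1}. A hypothetical NumPy sketch of one update on the same toy task the tests use below (predict column 6 of a random binary matrix):

    import numpy as np

    def sigmoid(a):
        return 1.0 / (1.0 + np.exp(-a))

    rng = np.random.RandomState(10)
    n_in, lr = 100, 0.01
    s = 1.0 / np.sqrt(n_in)
    w, b = rng.uniform(-s, s, size=(n_in, 1)), np.zeros(1)

    x = rng.randint(0, 2, size=(10, n_in)).astype(float)
    t = x[:, 6].reshape((10, 1))

    p = sigmoid(x.dot(w) + b)                     # output
    cost = np.mean(-np.sum(t * np.log(p) + (1 - t) * np.log(1 - p), axis=1))
    classes = np.round(p)                         # what classify/iround returns

    d = (p - t) / len(x)                          # grad of the mean cost wrt activation
    w -= lr * x.T.dot(d)
    b -= lr * d.sum(axis=0)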
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/stacker.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,83 @@
+
+import theano
+from theano import tensor as T
+import sys
+import numpy as N
+
+class Stacker(T.RModule):
+
+    def __init__(self, submodules, input = None, regularize = False):
+        super(Stacker, self).__init__()
+
+        current = input
+        layers = []
+        for i, (submodule, outname) in enumerate(submodules):
+            layer = submodule(current, regularize = regularize)
+            layers.append(layer)
+            current = layer[outname]
+        self.layers = layers
+
+        self.input = self.layers[0].input
+        self.output = current
+
+        local_update = []
+        global_update = []
+        to_update = []
+        all_kits = []
+        for layer in layers:
+            u = layer.update
+            u.resolve_all()
+            to_update += u.updates.keys()
+            all_kits += u.kits
+            # the input is the whole deep model's input instead of the layer's own
+            # input (which is previous_layer[outname])
+            inputs = [self.input] + u.inputs[1:]
+            method = theano.Method(inputs, u.outputs, u.updates, u.kits)
+            local_update.append(method)
+            global_update.append(
+                theano.Method(inputs,
+                              u.outputs,
+                              # we update the params of the previous layers too but wrt
+                              # this layer's cost
+                              dict((param, param - layer.lr * T.grad(layer.cost, param))
+                                   for param in to_update),
+                              list(all_kits)))
+
+        self.local_update = local_update
+        self.global_update = global_update
+        self.update = self.global_update[-1]
+        self.compute = theano.Method(self.input, self.output)
+        ll = self.layers[-1]
+        for name, method in ll.components_map():
+            if isinstance(method, theano.Method) and not hasattr(self, name):
+                m = method.dup()
+                m.resolve_all()
+                m.inputs = [self.input if x is ll.input else x for x in m.inputs]
+                setattr(self, name, m)
+
+    def _instance_initialize(self, obj, nunits = None, lr = 0.01, seed = None, **kwargs):
+        super(Stacker, self)._instance_initialize(obj, **kwargs)
+        if seed is not None:
+            R = N.random.RandomState(seed)
+        else:
+            R = N.random
+        for layer in obj.layers:
+            if layer.lr is None:
+                layer.lr = lr
+        if nunits:
+            if len(nunits) != len(obj.layers) + 1:
+                raise ValueError('You should give exactly one more unit count than there are layers.')
+            for ni, no, layer in zip(nunits[:-1], nunits[1:], obj.layers):
+                if seed is not None:
+                    layer.initialize(ni, no, seed = R.random_integers(sys.maxint - 1))
+                else:
+                    layer.initialize(ni, no)
+        if seed is not None:
+            obj.seed(seed)
+
+    def _instance_flops_approx(self, obj):
+        rval = 0
+        for layer in obj.layers:
+            rval += layer.flops_approx()
+        return rval
+
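Two conventions in Stacker are easy to miss: nunits must carry exactly one more entry than there are layers, with consecutive pairs giving each layer's input/output sizes, and update is the last entry of global_update, so training the stack optimizes the top layer's cost with respect to every parameter accumulated in to_update. A small sketch of the bookkeeping (names hypothetical):

    nunits = [100, 200, 1]                        # as in the stacker test below
    shapes = zip(nunits[:-1], nunits[1:])         # [(100, 200), (200, 1)]

    # global_update[i] adjusts the params of layers 0..i wrt layer i's cost,
    # mirroring how to_update grows across the construction loop
    layer_params = [['w0', 'b0'], ['w1', 'b1']]
    to_update = []
    for i, params in enumerate(layer_params):
        to_update += params
        # here Stacker builds a theano.Method updating everything in to_update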
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/tests/test_aa.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,42 @@
+
+import algorithms as models
+import theano
+import numpy
+import time
+
+
+def test_train(mode = theano.Mode('c|py', 'fast_run')):
+
+    aa = models.SigmoidXEAutoEncoder(regularize = False)
+#     print aa.update.pretty(mode = theano.Mode('py', 'fast_run').excluding('inplace'))
+
+    model = aa.make(lr = 0.01,
+                    input_size = 100,
+                    hidden_size = 1000,
+                    mode = mode)
+
+    data = [[0, 1, 0, 0, 1, 1, 1, 0, 1, 0]*10]*10
+    #data = numpy.random.rand(10, 100)
+
+    t1 = time.time()
+    for i in xrange(1001):
+        cost = model.update(data)
+        if i % 100 == 0:
+            print i, cost
+    t2 = time.time()
+    return t2 - t1
+
+if __name__ == '__main__':
+    numpy.random.seed(10)
+    print 'optimized:'
+    t1 = test_train(theano.Mode('c|py', 'fast_run'))
+    print 'time:',t1
+    print
+
+    numpy.random.seed(10)
+    print 'not optimized:'
+    t2 = test_train(theano.Mode('c|py', 'fast_compile'))
+    print 'time:',t2
+
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/tests/test_regressor.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,46 @@
+
+
+import algorithms as models
+import theano
+import numpy
+import time
+
+
+def test_train(mode = theano.Mode('c|py', 'fast_run')):
+
+    reg = models.BinRegressor(regularize = False)
+
+    model = reg.make(lr = 0.01,
+                     input_size = 100,
+                     mode = mode,
+                     seed = 10)
+
+#     data = [[0, 1, 0, 0, 1, 1, 1, 0, 1, 0]*10]*10
+#     targets = [[1]]*10
+    #data = numpy.random.rand(10, 100)
+
+    R = numpy.random.RandomState(100)
+    t1 = time.time()
+    for i in xrange(1001):
+        data = R.random_integers(0, 1, size = (10, 100))
+        targets = data[:, 6].reshape((10, 1))
+        cost = model.update(data, targets)
+        if i % 100 == 0:
+            print i, '\t', cost, '\t', 1*(targets.T == model.classify(data).T)
+    t2 = time.time()
+    return t2 - t1
+
+if __name__ == '__main__':
+    print 'optimized:'
+    t1 = test_train(theano.Mode('c|py', 'fast_run'))
+    print 'time:',t1
+    print
+
+    print 'not optimized:'
+    t2 = test_train(theano.Mode('c|py', 'fast_compile'))
+    print 'time:',t2
+
+
+
+
+
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/algorithms/tests/test_stacker.py	Mon Oct 27 17:26:00 2008 -0400
@@ -0,0 +1,43 @@
+
+import algorithms as models
+import theano
+import numpy
+import time
+
+
+def test_train(mode = theano.Mode('c|py', 'fast_run')):
+
+    reg = models.Stacker([(models.BinRegressor, 'output'), (models.BinRegressor, 'output')],
+                         regularize = False)
+    #print reg.global_update[1].pretty(mode = mode.excluding('inplace'))
+
+    model = reg.make([100, 200, 1],
+                     lr = 0.01,
+                     mode = mode,
+                     seed = 10)
+
+    R = numpy.random.RandomState(100)
+    t1 = time.time()
+    for i in xrange(1001):
+        data = R.random_integers(0, 1, size = (10, 100))
+        targets = data[:, 6].reshape((10, 1))
+        cost = model.update(data, targets)
+        if i % 100 == 0:
+            print i, '\t', cost, '\t', 1*(targets.T == model.classify(data).T)
+    t2 = time.time()
+    return t2 - t1
+
+if __name__ == '__main__':
+    print 'optimized:'
+    t1 = test_train(theano.Mode('c|py', 'fast_run'))
+    print 'time:',t1
+    print
+
+    print 'not optimized:'
+    t2 = test_train(theano.Mode('c|py', 'fast_compile'))
+    print 'time:',t2
+
+
+
+
+