view external/wrap_libsvm.py @ 517:716c04512dbe

init
author James Bergstra <bergstrj@iro.umontreal.ca>
date Wed, 12 Nov 2008 10:54:38 -0500
parents
children
line wrap: on
line source

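"""Run SVM classification experiments using libsvm's Python bindings.

`run_svm_experiment` is the Python-friendly entry point;
`dbdict_run_svm_experiment` is the dbdict-style interface it wraps.
"""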
import numpy
from ..datasets import dataset_from_descr

# libsvm currently has no python installation instructions/convention.
#
# This module relies on a specific convention for libsvm's installation,
# based on installing libsvm-2.88.
# To install libsvm's python module, do three things:
# 1. Build libsvm (run make in both the root dir and the python subdir).
# 2. Create an empty '__init__.py' file in the python subdir (e.g. with touch).
# 3. Add a symbolic link in a directory on your PYTHONPATH, like this:
#    libsvm -> <your root path>/libsvm-2.88/python/
#
# That layout is what this module expects when it runs 'import libsvm':
# libsvm's python subdir is then importable as the package 'libsvm'.

import libsvm

def score_01(x, y, model):
    """Return the 0-1 loss (fraction of misclassified examples) of `model`.

    :param x: sequence of inputs; each element is passed to ``model.predict``
    :param y: sequence of target labels, same length as `x`
    :param model: object with a ``predict(example)`` method (e.g. the libsvm
        ``svm_model`` built in `dbdict_run_svm_experiment`)
    :returns: float in [0, 1]: number of mispredictions divided by ``len(x)``
    :raises ValueError: if `x` is empty (the error rate is undefined)
    """
    assert len(x) == len(y)
    size = len(x)
    if size == 0:
        raise ValueError('cannot compute the 0-1 score of an empty dataset')
    # Count examples whose predicted label differs from the target label.
    errors = sum(1 for xi, yi in zip(x, y) if yi != model.predict(xi))
    return float(errors) / size

# Parameter/result object for the dbdict experiment interface below
# (run_svm_experiment also uses it, so dbdict itself is optional).
class State(object):
    """Parameters (inputs) and results (outputs) of one SVM experiment.

    The class attributes below are the default parameter values.  Keyword
    arguments given to the constructor override them, and each value is
    coerced to the type of its default (e.g. ``C='100'`` becomes ``100.0``).
    Passing a name that has no default raises AttributeError.

    `dbdict_run_svm_experiment` stores its results as additional attributes
    (``train_01``, ``valid_01``, ``test_01``, ``n_train``, ...) on the
    instance.
    """
    #TODO: parametrize to get all the kernel types, not hardcode for RBF
    dataset = 'MNIST_1k'  # description passed to dataset_from_descr
    C = 10.0              # passed as libsvm's C parameter
    kernel = 'RBF'        # name of the kernel-type constant looked up on libsvm's svm module
    # rel_gamma is related to the procedure Jerome used. He mentioned why in
    # quadratic_neurons/neuropaper/draft3.pdf.
    rel_gamma = 1.0       # multiplier applied to the data-derived base gamma

    def __init__(self, **kwargs):
        # Iterate over (name, value) pairs; iterating the dict itself would
        # yield only the keys.
        for name, value in kwargs.items():
            default = getattr(self, name)  # AttributeError for unknown names
            setattr(self, name, type(default)(value))


def dbdict_run_svm_experiment(state, channel=lambda *args, **kwargs:None):
    """Train an SVM with libsvm and record its 0-1 error on each data split.

    Parameters are read from `state`, and results are written back onto it.

    :param state: object with ``dataset``, ``C``, ``kernel`` and ``rel_gamma``
        attributes (e.g. a `State` instance).  On return it additionally has
        ``train_01``, ``valid_01``, ``test_01`` (0-1 error rates computed by
        `score_01`) and ``n_train``, ``n_valid``, ``n_test`` (split sizes).
    :param channel: not used
    :returns: None

    This is the kind of function that dbdict-run can use.
    """
    # The name `svm` is not otherwise bound in this module: `import libsvm`
    # alone does not import the package's `svm` submodule.
    from libsvm import svm

    ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)) = \
        dataset_from_descr(state.dataset)

    #libsvm needs stuff in int32 on a 32bit machine
    #TODO: test this on a 64bit machine
    train_y = numpy.asarray(train_y, dtype='int32')
    valid_y = numpy.asarray(valid_y, dtype='int32')
    test_y = numpy.asarray(test_y, dtype='int32')
    problem = svm.svm_problem(train_y, train_x)

    # Data-derived base gamma: 0.5 / (sum of per-feature variances of the
    # training inputs); state.rel_gamma scales it.
    gamma0 = 0.5 / numpy.sum(numpy.var(train_x, axis=0))

    param = svm.svm_parameter(C=state.C,
            kernel_type=getattr(svm, state.kernel),  # e.g. 'RBF' -> svm.RBF
            gamma=state.rel_gamma * gamma0)

    model = svm.svm_model(problem, param)  # this is the expensive part

    state.train_01 = score_01(train_x, train_y, model)
    state.valid_01 = score_01(valid_x, valid_y, model)
    state.test_01 = score_01(test_x, test_y, model)

    state.n_train = len(train_y)
    state.n_valid = len(valid_y)
    state.n_test = len(test_y)

def run_svm_experiment(**kwargs):
    """Python-friendly interface to dbdict_run_svm_experiment.

    Keyword arguments are used to construct a `State` instance (its class
    attributes give the accepted names and their defaults).  That instance is
    returned after `dbdict_run_svm_experiment` has run on it and stored the
    results as attributes.

    .. code-block:: python

        results = run_svm_experiment(dataset='MNIST_1k', C=100.0, rel_gamma=0.01)
        print results.n_train
        # 1000
        print results.valid_01, results.test_01
        # 0.14, 0.10  #.. or something...

    """
    state = State(**kwargs)
    # Fills in the result attributes (train_01, n_train, ...) on `state`.
    dbdict_run_svm_experiment(state)
    return state