Mercurial > pylearn
diff external/wrap_libsvm.py @ 517:716c04512dbe
init
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 12 Nov 2008 10:54:38 -0500 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/external/wrap_libsvm.py Wed Nov 12 10:54:38 2008 -0500 @@ -0,0 +1,99 @@ +"""Run an experiment using libsvm. +""" +import numpy +from ..datasets import dataset_from_descr + +# libsvm currently has no python installation instructions/convention. +# +# This module uses a specific convention for libsvm's installation. +# I base this on installing libsvm-2.88. +# To install libsvm's python module, do three things: +# 1. Build libsvm (run make in both the root dir and the python subdir). +# 2. touch a '__init__.py' file in the python subdir +# 3. add a symbolic link to a PYTHONPATH location that looks like this: +# libsvm -> <your root path>/libsvm-2.88/python/ +# +# That is the sort of thing that this module expects from 'import libsvm' + +import libsvm + +def score_01(x, y, model): + assert len(x) == len(y) + size = len(x) + errors = 0 + for i in range(size): + prediction = model.predict(x[i]) + #probability = model.predict_probability + if (y[i] != prediction): + errors = errors + 1 + return float(errors)/size + +#this is the dbdict experiment interface... if you happen to use dbdict +class State(object): + #TODO: parametrize to get all the kernel types, not hardcode for RBF + dataset = 'MNIST_1k' + C = 10.0 + kernel = 'RBF' + # rel_gamma is related to the procedure Jerome used. He mentioned why in + # quadratic_neurons/neuropaper/draft3.pdf. + rel_gamma = 1.0 + + def __init__(self, **kwargs): + for k, v in kwargs: + setattr(self, k, type(getattr(self, k))(v)) + + +def dbdict_run_svm_experiment(state, channel=lambda *args, **kwargs:None): + """Parameters are described in state, and returned in state. + + :param state: object instance to store parameters and return values + :param channel: not used + + :returns: None + + This is the kind of function that dbdict-run can use. + + """ + ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)) = dataset_from_descr(state.dataset) + + #libsvm needs stuff in int32 on a 32bit machine + #TODO: test this on a 64bit machine + train_y = numpy.asarray(train_y, dtype='int32') + valid_y = numpy.asarray(valid_y, dtype='int32') + test_y = numpy.asarray(test_y, dtype='int32') + problem = svm.svm_problem(train_y, train_x); + + gamma0 = 0.5 / numpy.sum(numpy.var(train_x, axis=0)) + + param = svm.svm_parameter(C=state.C, + kernel_type=getattr(svm, state.kernel), + gamma=state.rel_gamma * gamma0) + + model = svm.svm_model(problem, param) #this is the expensive part + + state.train_01 = score_01(train_x, train_y, model) + state.valid_01 = score_01(valid_x, valid_y, model) + state.test_01 = score_01(test_x, test_y, model) + + state.n_train = len(train_y) + state.n_valid = len(valid_y) + state.n_test = len(test_y) + +def run_svm_experiment(**kwargs): + """Python-friendly interface to dbdict_run_svm_experiment + + Parameters are used to construct a `State` instance, which is returned after running + `dbdict_run_svm_experiment` on it. + + .. code-block:: python + results = run_svm_experiment(dataset='MNIST_1k', C=100.0, rel_gamma=0.01) + print results.n_train + # 1000 + print results.valid_01, results.test_01 + # 0.14, 0.10 #.. or something... + + """ + state = State(**kwargs) + state_run_svm_experiment(state) + return state +