Mercurial > pylearn
comparison external/wrap_libsvm.py @ 517:716c04512dbe
init
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Wed, 12 Nov 2008 10:54:38 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
516:2b0e10ac6929 | 517:716c04512dbe |
---|---|
1 """Run an experiment using libsvm. | |
2 """ | |
3 import numpy | |
4 from ..datasets import dataset_from_descr | |
5 | |
6 # libsvm currently has no python installation instructions/convention. | |
7 # | |
8 # This module uses a specific convention for libsvm's installation. | |
9 # I base this on installing libsvm-2.88. | |
10 # To install libsvm's python module, do three things: | |
11 # 1. Build libsvm (run make in both the root dir and the python subdir). | |
12 # 2. touch a '__init__.py' file in the python subdir | |
13 # 3. add a symbolic link to a PYTHONPATH location that looks like this: | |
14 # libsvm -> <your root path>/libsvm-2.88/python/ | |
15 # | |
16 # That is the sort of thing that this module expects from 'import libsvm' | |
17 | |
18 import libsvm | |
19 | |
def score_01(x, y, model):
    """Return the 0-1 error rate of `model` on the examples `x` with labels `y`.

    :param x: indexable sequence of inputs; each `x[i]` is passed to `model.predict`
    :param y: indexable sequence of labels, same length as `x`
    :param model: any object exposing a `predict(example)` method

    :returns: number of mispredicted examples divided by `len(x)`, as a float

    The lengths of `x` and `y` are checked with an assert, and an empty `x`
    leads to a division by zero.
    """
    n_examples = len(x)
    assert n_examples == len(y)

    def is_miss(idx):
        # Ask the model first, then compare its answer against the label.
        guess = model.predict(x[idx])
        return y[idx] != guess

    n_misses = sum(1 for idx in range(n_examples) if is_miss(idx))
    return float(n_misses) / n_examples
30 | |
31 #this is the dbdict experiment interface... if you happen to use dbdict | |
class State(object):
    """Parameters (and, after a run, results) of an SVM experiment.

    The class attributes below are the defaults.  Keyword arguments given to
    the constructor override them on the instance; each value is converted to
    the type of the corresponding default (e.g. ``C=100`` is stored as
    ``100.0``).  A keyword that has no class-level default raises
    ``AttributeError``.
    """
    #TODO: parametrize to get all the kernel types, not hardcode for RBF
    dataset = 'MNIST_1k'
    C = 10.0
    kernel = 'RBF'
    # rel_gamma is related to the procedure Jerome used.  He mentioned why in
    # quadratic_neurons/neuropaper/draft3.pdf.
    rel_gamma = 1.0

    def __init__(self, **kwargs):
        # Iterate over (name, value) pairs; iterating the dict itself would
        # yield only the keys and break the two-name unpacking.
        for k, v in kwargs.items():
            # Coerce to the type of the existing default so that values that
            # arrive as strings or ints end up with the expected type.
            setattr(self, k, type(getattr(self, k))(v))
44 | |
45 | |
def dbdict_run_svm_experiment(state, channel=lambda *args, **kwargs:None):
    """Parameters are described in state, and returned in state.

    :param state: object instance to store parameters and return values.
        The attributes `dataset`, `C`, `kernel` and `rel_gamma` are read;
        `train_01`, `valid_01`, `test_01`, `n_train`, `n_valid` and `n_test`
        are written.
    :param channel: not used

    :returns: None

    This is the kind of function that dbdict-run can use.

    """
    # The rest of this function refers to the `svm` module, which a plain
    # `import libsvm` does not bind.  With the installation convention
    # described at the top of this file (the `libsvm` package is libsvm's
    # python subdir), that module is `libsvm.svm`.
    from libsvm import svm

    ((train_x, train_y), (valid_x, valid_y), (test_x, test_y)) = dataset_from_descr(state.dataset)

    #libsvm needs stuff in int32 on a 32bit machine
    #TODO: test this on a 64bit machine
    train_y = numpy.asarray(train_y, dtype='int32')
    valid_y = numpy.asarray(valid_y, dtype='int32')
    test_y = numpy.asarray(test_y, dtype='int32')
    problem = svm.svm_problem(train_y, train_x)

    # Reference kernel width: 0.5 divided by the sum of the per-column
    # variances of the training inputs.  state.rel_gamma scales it.
    gamma0 = 0.5 / numpy.sum(numpy.var(train_x, axis=0))

    param = svm.svm_parameter(C=state.C,
            kernel_type=getattr(svm, state.kernel),  # e.g. 'RBF' -> svm.RBF
            gamma=state.rel_gamma * gamma0)

    model = svm.svm_model(problem, param) #this is the expensive part

    # 0-1 error rates on each split
    state.train_01 = score_01(train_x, train_y, model)
    state.valid_01 = score_01(valid_x, valid_y, model)
    state.test_01 = score_01(test_x, test_y, model)

    # number of examples in each split
    state.n_train = len(train_y)
    state.n_valid = len(valid_y)
    state.n_test = len(test_y)
81 | |
def run_svm_experiment(**kwargs):
    """Python-friendly interface to dbdict_run_svm_experiment

    Parameters are used to construct a `State` instance, which is returned after running
    `dbdict_run_svm_experiment` on it.

    .. code-block:: python

        results = run_svm_experiment(dataset='MNIST_1k', C=100.0, rel_gamma=0.01)
        print results.n_train
        # 1000
        print results.valid_01, results.test_01
        # 0.14, 0.10 #.. or something...

    """
    state = State(**kwargs)
    # Run the experiment in place: results are stored as attributes of `state`.
    dbdict_run_svm_experiment(state)
    return state
99 |