Mercurial > pylearn
changeset 688:49e531f7b0ba
added picklable svm_model to wrap_libsvm
author | James Bergstra <bergstrj@iro.umontreal.ca> |
---|---|
date | Thu, 14 May 2009 16:54:59 -0400 |
parents | 4b5e0b5a11e1 |
children | 651eb6506d91 |
files | pylearn/external/wrap_libsvm.py |
diffstat | 1 files changed, 115 insertions(+), 3 deletions(-) [+] |
line wrap: on
line diff
--- a/pylearn/external/wrap_libsvm.py Thu May 14 16:33:38 2009 -0400 +++ b/pylearn/external/wrap_libsvm.py Thu May 14 16:54:59 2009 -0400 @@ -7,16 +7,75 @@ # # This module uses a specific convention for libsvm's installation. # I base this on installing libsvm-2.88. -# To install libsvm's python module, do three things: +# To install libsvm's python module, do the following: # 1. Build libsvm (run make in both the root dir and the python subdir). # 2. touch a '__init__.py' file in the python subdir # 3. add a symbolic link to a PYTHONPATH location that looks like this: # libsvm -> <your root path>/libsvm-2.88/python/ +# 4. modify the svm_model class in python/svm.py to inherit from object # # That is the sort of thing that this module expects from 'import libsvm' import libsvm +class svm_model(libsvm.svm_model): + """ + This class is a picklable drop-in replacement for libsvm.svm_model. + """ + def __getstate__(self): + return PicklableSVM.svm_to_str(self) + + def __setstate__(self, svm_str): + PicklableSVM.str_to_svm(svm_str, self=self) + + @staticmethod + def str_to_svm(s, self=None): + fname = tempfile.mktemp() + f = open(fname,'w') + f.write(s) + f.close() + rval = self + try: + if self: + self.__init__(fname) + else: + rval = libsvm.svm_model(fname) + finally: + os.remove(fname) + return rval + + @staticmethod + def svm_to_str(svm): + fname = tempfile.mktemp() + svm.save(fname) + rval = open(fname, 'r').read() + os.remove(fname) + return rval + + def predict(self, x): + if type(x) != numpy.ndarray: + raise TypeError(x) + if x.ndim != 1: + raise TypeError(x) + return libsvm.svm_model.predict(self, numpy.asarray(x, dtype='float64')) + + def predict_probability(self, x): + if x.ndim != 1: + raise TypeError(x) + return libsvm.svm_model.predict_probability(self, numpy.asarray(x, dtype='float64')) + +svm_problem = libsvm.svm_problem +svm_parameter = libsvm.svm_parameter +RBF = libsvm.svm_RBF + + +#################################### +# Extra stuff that is less essential +# +# TODO: Move stuff below to a file +# in algorithms +#################################### + def score_01(x, y, model): assert len(x) == len(y) size = len(x) @@ -42,8 +101,7 @@ for k, v in kwargs: setattr(self, k, type(getattr(self, k))(v)) - -def dbdict_run_svm_experiment(state, channel=lambda *args, **kwargs:None): +def state_run_svm_experiment(state, channel=lambda *args, **kwargs:None): """Parameters are described in state, and returned in state. :param state: object instance to store parameters and return values @@ -98,3 +156,57 @@ state_run_svm_experiment(state=kwargs) return kwargs +def train_rbf_model(train_X, train_Y, C, gamma): + param = libsvm.svm_parameter(C=C, kernel_type=libsvm.RBF, gamma=gamma) + problem = libsvm.svm_problem(train_Y, train_X) + model libsvm.svm_model(problem, param) + + #save_filename = state.save_filename + #model.save(save_filename) + + +def jobman_train_model(state, channel): + """ + + According to the given validation set, + What is the best libsvm parameter setting to train on? + """ + (train_X, train_Y) = jobman.tools.make(state.train_set) + (valid_X, valid_Y) = jobman.tools.make(state.valid_set) + + C_grid = [1,2,3] + gamma_grid = [0.1, 1, 10] + + grid = [dict( + train_set=None, + svm_param=dict(kernel='RBF', C=C, gamma=g), + save_filename='model_RBF_C%f_G%f.libsvm') + for C in C_grid, + for g in gamma_grid] + + # will return quickly if jobs have already run + # and the rootpath is populated with results + grid = jobman.map( + jobman_train_model_given_all_params, + grid, + path=jobman.rootpath(state)+'/gridmap', + cleanup=False) + + # evaluate all these sub_state models on our validation_set + valid_perf = [] + for sub_state in grid: + # create a file in this state-space called model.tmp + # with the same contents as the + # save_filename file in the sub_state + jobman.link('model.tmp', jobman.rootpath(sub_state)+'/'+sub_state.save_filename) + model = svm.model('model.tmp') + valid_perf.append((score_01(valid_X, valid_Y, model), sub_state)) + jobman.unlink('model.tmp') + + # calculate the return value + valid_perf.sort() #lowest first + state.lowest_valid_err = valid_perf[0][0] + state.lowest_valid_svm_param = valid_perf[0][1].svm_param + + +