changeset 688:49e531f7b0ba

added picklable svm_model to wrap_libsvm
author James Bergstra <bergstrj@iro.umontreal.ca>
date Thu, 14 May 2009 16:54:59 -0400
parents 4b5e0b5a11e1
children 651eb6506d91
files pylearn/external/wrap_libsvm.py
diffstat 1 files changed, 115 insertions(+), 3 deletions(-) [+]
line wrap: on
line diff
--- a/pylearn/external/wrap_libsvm.py	Thu May 14 16:33:38 2009 -0400
+++ b/pylearn/external/wrap_libsvm.py	Thu May 14 16:54:59 2009 -0400
@@ -7,16 +7,75 @@
 #
 # This module uses a specific convention for libsvm's installation.
 # I base this on installing libsvm-2.88.
-# To install libsvm's python module, do three things:
+# To install libsvm's python module, do the following:
 # 1. Build libsvm (run make in both the root dir and the python subdir).
 # 2. touch a '__init__.py' file in the python subdir
 # 3. add a symbolic link to a PYTHONPATH location that looks like this:
 #    libsvm -> <your root path>/libsvm-2.88/python/
+# 4. modify the svm_model class in python/svm.py to inherit from object
 #
 # That is the sort of thing that this module expects from 'import libsvm'
 
 import libsvm
 
+class svm_model(libsvm.svm_model):
+    """
+    This class is a picklable drop-in replacement for libsvm.svm_model.
+    """
+    def __getstate__(self):
+        return PicklableSVM.svm_to_str(self)
+
+    def __setstate__(self, svm_str):
+        PicklableSVM.str_to_svm(svm_str, self=self)
+
+    @staticmethod
+    def str_to_svm(s, self=None):
+        fname = tempfile.mktemp()
+        f = open(fname,'w')
+        f.write(s)
+        f.close()
+        rval = self
+        try:
+            if self:
+                self.__init__(fname)
+            else:
+                rval = libsvm.svm_model(fname)
+        finally:
+            os.remove(fname)
+        return rval
+
+    @staticmethod
+    def svm_to_str(svm):
+        fname = tempfile.mktemp()
+        svm.save(fname)
+        rval = open(fname, 'r').read()
+        os.remove(fname)
+        return rval
+
+    def predict(self, x):
+        if type(x) != numpy.ndarray:
+            raise TypeError(x)
+        if x.ndim != 1:
+            raise TypeError(x)
+        return libsvm.svm_model.predict(self, numpy.asarray(x, dtype='float64'))
+
+    def predict_probability(self, x):
+        if x.ndim != 1:
+            raise TypeError(x)
+        return libsvm.svm_model.predict_probability(self, numpy.asarray(x, dtype='float64'))
+
+svm_problem = libsvm.svm_problem
+svm_parameter = libsvm.svm_parameter
+RBF = libsvm.svm_RBF
+
+
+####################################
+# Extra stuff that is less essential 
+#
+# TODO: Move stuff below to a file 
+#       in algorithms
+####################################
+
 def score_01(x, y, model):
     assert len(x) == len(y)
     size = len(x)
@@ -42,8 +101,7 @@
         for k, v in kwargs:
             setattr(self, k, type(getattr(self, k))(v))
 
-
-def dbdict_run_svm_experiment(state, channel=lambda *args, **kwargs:None):
+def state_run_svm_experiment(state, channel=lambda *args, **kwargs:None):
     """Parameters are described in state, and returned in state.
 
     :param state: object instance to store parameters and return values
@@ -98,3 +156,57 @@
     state_run_svm_experiment(state=kwargs)
     return kwargs
 
+def train_rbf_model(train_X, train_Y, C, gamma):
+    param = libsvm.svm_parameter(C=C, kernel_type=libsvm.RBF, gamma=gamma)
+    problem = libsvm.svm_problem(train_Y, train_X)
+    model libsvm.svm_model(problem, param)
+
+    #save_filename = state.save_filename
+    #model.save(save_filename)
+
+
+def jobman_train_model(state, channel):
+    """
+    
+    According to the given validation set,
+    What is the best libsvm parameter setting to train on?
+    """
+    (train_X, train_Y) = jobman.tools.make(state.train_set)
+    (valid_X, valid_Y) = jobman.tools.make(state.valid_set)
+
+    C_grid = [1,2,3]
+    gamma_grid = [0.1, 1, 10]
+
+    grid = [dict(
+        train_set=None,
+        svm_param=dict(kernel='RBF', C=C, gamma=g),
+        save_filename='model_RBF_C%f_G%f.libsvm')
+            for C in C_grid,
+            for g in gamma_grid]
+
+    # will return quickly if jobs have already run
+    # and the rootpath is populated with results
+    grid = jobman.map(
+            jobman_train_model_given_all_params, 
+            grid,
+            path=jobman.rootpath(state)+'/gridmap',
+            cleanup=False)
+
+    # evaluate all these sub_state models on our validation_set
+    valid_perf = []
+    for sub_state in grid:
+        # create a file in this state-space called model.tmp
+        # with the same contents as the
+        # save_filename file in the sub_state
+        jobman.link('model.tmp', jobman.rootpath(sub_state)+'/'+sub_state.save_filename)
+        model = svm.model('model.tmp')
+        valid_perf.append((score_01(valid_X, valid_Y, model), sub_state))
+        jobman.unlink('model.tmp')
+
+    # calculate the return value
+    valid_perf.sort() #lowest first
+    state.lowest_valid_err = valid_perf[0][0]
+    state.lowest_valid_svm_param = valid_perf[0][1].svm_param
+
+
+