File size: 5,783 Bytes
9375c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "opaque_types.h"
#include <dlib/python.h>
#include <dlib/matrix.h>
#include <dlib/svm.h>
#include "testing_results.h"
#include <pybind11/stl_bind.h>
using namespace dlib;
using namespace std;
namespace py = pybind11;
typedef matrix<double,0,1> sample_type;
// ----------------------------------------------------------------------------------------
namespace dlib
{
    // Equality between ranking_pair objects is deliberately unsupported.
    // NOTE(review): presumably this exists only so that py::bind_vector (used
    // for the ranking_pairs containers below) can instantiate code that needs
    // operator==. Any actual comparison is reported as an error through
    // pyassert; confirm against dlib/python.h.
    template <typename T>
    bool operator== (const ranking_pair<T>& /*lhs*/, const ranking_pair<T>& /*rhs*/)
    {
        pyassert(false, "It is illegal to compare ranking pair objects for equality.");
        // Not reached when pyassert(false, ...) aborts the call; kept so every
        // path returns a value.
        return false;
    }
}
// Free-function adapter so a container's resize() can be bound as a Python
// method: forwards new_size straight to container.resize().
template <typename container_type>
void resize(container_type& container, unsigned long new_size)
{
    container.resize(new_size);
}
// ----------------------------------------------------------------------------------------
template <typename trainer_type>
typename trainer_type::trained_function_type train1 (
const trainer_type& trainer,
const ranking_pair<typename trainer_type::sample_type>& sample
)
{
typedef ranking_pair<typename trainer_type::sample_type> st;
pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs");
return trainer.train(sample);
}
template <typename trainer_type>
typename trainer_type::trained_function_type train2 (
const trainer_type& trainer,
const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples
)
{
pyassert(is_ranking_problem(samples), "Invalid inputs");
return trainer.train(samples);
}
// Property setter for "epsilon": only strictly positive values are forwarded
// to the trainer.
template <typename trainer_type>
void set_epsilon (
    trainer_type& trainer,
    double eps
)
{
    // `eps > 0` is false for zero, negatives and NaN, so all of those are rejected.
    const bool is_positive = eps > 0;
    pyassert(is_positive, "epsilon must be > 0");
    trainer.set_epsilon(eps);
}
// Property getter for "epsilon": reports the trainer's current value.
template <typename trainer_type>
double get_epsilon (
    const trainer_type& trainer
)
{
    const double eps = trainer.get_epsilon();
    return eps;
}
// Property setter for "c": only strictly positive values are forwarded to
// the trainer.
template <typename trainer_type>
void set_c (
    trainer_type& trainer,
    double C
)
{
    // `C > 0` is false for zero, negatives and NaN, so all of those are rejected.
    const bool is_positive = C > 0;
    pyassert(is_positive, "C must be > 0");
    trainer.set_c(C);
}
// Property getter for "c": reports the trainer's current value.
template <typename trainer_type>
double get_c (const trainer_type& trainer) { return trainer.get_c(); }
// Registers `trainer` as Python class `name` in module `m`. It exposes the
// trainer's settings as properties, two `train` overloads (single pair /
// vector of pairs), prior handling, and verbosity control.
template <typename trainer>
void add_ranker (
    py::module& m,
    const char* name
)
{
    typedef py::class_<trainer> binding_type;
    binding_type cls(m, name);

    cls.def(py::init());

    // Settings; epsilon and c go through the validating free-function wrappers.
    cls.def_property("epsilon", get_epsilon<trainer>, set_epsilon<trainer>);
    cls.def_property("c", get_c<trainer>, set_c<trainer>);
    cls.def_property("max_iterations", &trainer::get_max_iterations, &trainer::set_max_iterations);
    cls.def_property("force_last_weight_to_1", &trainer::forces_last_weight_to_1, &trainer::force_last_weight_to_1);
    cls.def_property("learns_nonnegative_weights", &trainer::learns_nonnegative_weights, &trainer::set_learns_nonnegative_weights);
    cls.def_property_readonly("has_prior", &trainer::has_prior);

    // Overloads are registered in this order: single ranking_pair first,
    // then a vector of ranking_pairs.
    cls.def("train", train1<trainer>);
    cls.def("train", train2<trainer>);

    cls.def("set_prior", &trainer::set_prior);
    cls.def("be_verbose", &trainer::be_verbose);
    cls.def("be_quiet", &trainer::be_quiet);
}
// ----------------------------------------------------------------------------------------
// Python-facing wrapper around cross_validate_ranking_trainer().
//
// Validates the inputs with pyassert before delegating:
//   - samples must form a valid ranking problem (is_ranking_problem), and
//   - folds must satisfy 1 < folds <= samples.size().
//
// trainer: the ranking trainer to cross-validate (not modified).
// samples: the ranking pairs to split into folds.
// folds:   number of cross-validation folds.
// Returns the ranking_test produced from cross_validate_ranking_trainer()'s result.
//
// The return type is a plain value: a top-level const on a by-value return
// has no meaning for callers and only prevents the result from being moved.
template <
    typename trainer_type,
    typename T
    >
ranking_test _cross_ranking_validate_trainer (
    const trainer_type& trainer,
    const std::vector<ranking_pair<T> >& samples,
    const unsigned long folds
)
{
    pyassert(is_ranking_problem(samples), "Training data does not make a valid training set.");
    pyassert(1 < folds && folds <= samples.size(), "Invalid number of folds given.");
    return cross_validate_ranking_trainer(trainer, samples, folds);
}
// ----------------------------------------------------------------------------------------
// Registers the ranking bindings in module `m`:
//   - dense and sparse ranking_pair classes (picklable),
//   - their vector containers with clear/resize/extend helpers (picklable),
//   - dense and sparse svm_rank_trainer classes, and
//   - two cross_validate_ranking_trainer overloads (dense first, then sparse).
void bind_svm_rank_trainer(py::module& m)
{
    typedef ranking_pair<sample_type> dense_pair;
    typedef ranking_pair<sparse_vect> sparse_pair;
    typedef svm_rank_trainer<linear_kernel<sample_type> > dense_ranker;
    typedef svm_rank_trainer<sparse_linear_kernel<sparse_vect> > sparse_ranker;

    // Pair types: each exposes its `relevant` and `nonrelevant` sample lists.
    py::class_<dense_pair> dense_pair_class(m, "ranking_pair");
    dense_pair_class.def(py::init());
    dense_pair_class.def_readwrite("relevant", &dense_pair::relevant);
    dense_pair_class.def_readwrite("nonrelevant", &dense_pair::nonrelevant);
    dense_pair_class.def(py::pickle(&getstate<dense_pair>, &setstate<dense_pair>));

    py::class_<sparse_pair> sparse_pair_class(m, "sparse_ranking_pair");
    sparse_pair_class.def(py::init());
    sparse_pair_class.def_readwrite("relevant", &sparse_pair::relevant);
    sparse_pair_class.def_readwrite("nonrelevant", &sparse_pair::nonrelevant);
    sparse_pair_class.def(py::pickle(&getstate<sparse_pair>, &setstate<sparse_pair>));

    // Containers of pairs, with extra list-like helpers on top of bind_vector.
    auto dense_pairs_class = py::bind_vector<ranking_pairs>(m, "ranking_pairs");
    dense_pairs_class.def("clear", &ranking_pairs::clear);
    dense_pairs_class.def("resize", resize<ranking_pairs>);
    dense_pairs_class.def("extend", extend_vector_with_python_list<dense_pair>);
    dense_pairs_class.def(py::pickle(&getstate<ranking_pairs>, &setstate<ranking_pairs>));

    auto sparse_pairs_class = py::bind_vector<sparse_ranking_pairs>(m, "sparse_ranking_pairs");
    sparse_pairs_class.def("clear", &sparse_ranking_pairs::clear);
    sparse_pairs_class.def("resize", resize<sparse_ranking_pairs>);
    sparse_pairs_class.def("extend", extend_vector_with_python_list<sparse_pair>);
    sparse_pairs_class.def(py::pickle(&getstate<sparse_ranking_pairs>, &setstate<sparse_ranking_pairs>));

    // Trainers.
    add_ranker<dense_ranker>(m, "svm_rank_trainer");
    add_ranker<sparse_ranker>(m, "svm_rank_trainer_sparse");

    // Cross-validation overloads; registration order (dense, then sparse) is kept.
    m.def("cross_validate_ranking_trainer",
          &_cross_ranking_validate_trainer<dense_ranker, sample_type>,
          py::arg("trainer"), py::arg("samples"), py::arg("folds"));
    m.def("cross_validate_ranking_trainer",
          &_cross_ranking_validate_trainer<sparse_ranker, sparse_vect>,
          py::arg("trainer"), py::arg("samples"), py::arg("folds"));
}
|