File size: 5,783 Bytes
9375c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright (C) 2013  Davis E. King ([email protected])
// License: Boost Software License   See LICENSE.txt for the full license.

#include "opaque_types.h"
#include <dlib/python.h>
#include <dlib/matrix.h>
#include <dlib/svm.h>
#include "testing_results.h"
#include <pybind11/stl_bind.h>

using namespace dlib;
using namespace std;
namespace py = pybind11;

// Dense sample type used by the non-sparse bindings: a dlib column vector of doubles.
using sample_type = matrix<double,0,1>;


// ----------------------------------------------------------------------------------------

namespace dlib
{
    // Equality for ranking_pair that always fails at runtime.
    // It exists so that code requiring operator== on ranking_pair compiles.
    // Any actual comparison raises via pyassert, which carries the message below.
    // If control ever returned, the result would be false.
    template <typename T>
    bool operator== (
        const ranking_pair<T>& /*lhs*/,
        const ranking_pair<T>& /*rhs*/
    )
    {
        const bool comparison_allowed = false;
        pyassert(comparison_allowed, "It is illegal to compare ranking pair objects for equality.");
        return comparison_allowed;
    }
}

// Free-function adapter so a container's member resize() can be bound as a
// Python method; forwards new_size straight to container.resize().
template <typename container_type>
void resize(container_type& container, unsigned long new_size)
{
    container.resize(new_size);
}

// ----------------------------------------------------------------------------------------

template <typename trainer_type>
typename trainer_type::trained_function_type train1 (
    const trainer_type& trainer,
    const ranking_pair<typename trainer_type::sample_type>& sample
)
{
    typedef ranking_pair<typename trainer_type::sample_type> st;
    pyassert(is_ranking_problem(std::vector<st>(1, sample)), "Invalid inputs");
    return trainer.train(sample);
}

template <typename trainer_type>
typename trainer_type::trained_function_type train2 (
    const trainer_type& trainer,
    const std::vector<ranking_pair<typename trainer_type::sample_type> >& samples
)
{
    pyassert(is_ranking_problem(samples), "Invalid inputs");
    return trainer.train(samples);
}

// Python property setter for epsilon. Rejects values that are not > 0,
// including NaN, via pyassert before forwarding to trainer.set_epsilon().
template <typename trainer_type>
void set_epsilon ( trainer_type& trainer, double eps)
{
    const bool positive = eps > 0;
    pyassert(positive, "epsilon must be > 0");
    trainer.set_epsilon(eps);
}

// Python property getter for epsilon; returns trainer.get_epsilon() as a double.
template <typename trainer_type>
double get_epsilon ( const trainer_type& trainer)
{
    const double eps = trainer.get_epsilon();
    return eps;
}

// Python property setter for C. Rejects values that are not > 0, including
// NaN, via pyassert before forwarding to trainer.set_c().
template <typename trainer_type>
void set_c ( trainer_type& trainer, double C)
{
    const bool positive = C > 0;
    pyassert(positive, "C must be > 0");
    trainer.set_c(C);
}

// Python property getter for C; returns trainer.get_c() as a double.
template <typename trainer_type>
double get_c (const trainer_type& trainer) { return trainer.get_c(); }


// Exposes the ranking trainer type `trainer_type` to Python as class `name`
// in module `m`.
// - The epsilon and c properties go through set_epsilon/set_c, which reject
//   non-positive values.
// - The remaining properties and methods bind the trainer's own member
//   functions directly.
// - "train" is registered twice: first for a single ranking pair (train1),
//   then for a collection (train2).
template <typename trainer_type>
void add_ranker (
    py::module& m,
    const char* name
)
{
    py::class_<trainer_type> ranker(m, name);

    ranker.def(py::init());

    // Hyper-parameters.
    ranker.def_property("epsilon", get_epsilon<trainer_type>, set_epsilon<trainer_type>);
    ranker.def_property("c", get_c<trainer_type>, set_c<trainer_type>);
    ranker.def_property("max_iterations",
                        &trainer_type::get_max_iterations,
                        &trainer_type::set_max_iterations);
    ranker.def_property("force_last_weight_to_1",
                        &trainer_type::forces_last_weight_to_1,
                        &trainer_type::force_last_weight_to_1);
    ranker.def_property("learns_nonnegative_weights",
                        &trainer_type::learns_nonnegative_weights,
                        &trainer_type::set_learns_nonnegative_weights);
    ranker.def_property_readonly("has_prior", &trainer_type::has_prior);

    // Overload registration order is kept: single pair first, then collection.
    ranker.def("train", train1<trainer_type>);
    ranker.def("train", train2<trainer_type>);

    ranker.def("set_prior", &trainer_type::set_prior);
    ranker.def("be_verbose", &trainer_type::be_verbose);
    ranker.def("be_quiet", &trainer_type::be_quiet);
}

// ----------------------------------------------------------------------------------------

// Python-facing wrapper around dlib's cross_validate_ranking_trainer().
// Preconditions, each checked via pyassert and in this order:
//   1. `samples` must pass is_ranking_problem().
//   2. `folds` must satisfy 1 < folds <= samples.size().
// Returns the ranking_test produced by cross-validation.
// NOTE(review): a leading underscore at global scope is a reserved
// identifier; renaming would require updating the m.def() call sites.
template <
    typename trainer_type,
    typename T
    >
const ranking_test _cross_ranking_validate_trainer (
    const trainer_type& trainer,
    const std::vector<ranking_pair<T> >& samples,
    const unsigned long folds
)
{
    const bool valid_data = is_ranking_problem(samples);
    pyassert(valid_data, "Training data does not make a valid training set.");

    const bool valid_folds = folds > 1 && folds <= samples.size();
    pyassert(valid_folds, "Invalid number of folds given.");

    return cross_validate_ranking_trainer(trainer, samples, folds);
}

// ----------------------------------------------------------------------------------------

// Registers the svm_rank_trainer-related Python bindings on module `m`:
//   - ranking_pair / sparse_ranking_pair: objects with read/write "relevant"
//     and "nonrelevant" fields, picklable via getstate/setstate.
//   - ranking_pairs / sparse_ranking_pairs: bind_vector containers of those
//     pairs, with extra clear/resize/extend methods, also picklable.
//   - svm_rank_trainer / svm_rank_trainer_sparse: trainer classes built by
//     add_ranker() for the linear and sparse-linear kernels.
//   - cross_validate_ranking_trainer: one overload per trainer type.
void bind_svm_rank_trainer(py::module& m)
{
    typedef svm_rank_trainer<linear_kernel<sample_type> > dense_ranker_type;
    typedef svm_rank_trainer<sparse_linear_kernel<sparse_vect> > sparse_ranker_type;

    py::class_<ranking_pair<sample_type> >(m, "ranking_pair")
        .def(py::init())
        .def_readwrite("relevant", &ranking_pair<sample_type>::relevant)
        .def_readwrite("nonrelevant", &ranking_pair<sample_type>::nonrelevant)
        .def(py::pickle(&getstate<ranking_pair<sample_type>>, &setstate<ranking_pair<sample_type>>));

    py::class_<ranking_pair<sparse_vect> >(m, "sparse_ranking_pair")
        .def(py::init())
        .def_readwrite("relevant", &ranking_pair<sparse_vect>::relevant)
        .def_readwrite("nonrelevant", &ranking_pair<sparse_vect>::nonrelevant)
        .def(py::pickle(&getstate<ranking_pair<sparse_vect>>, &setstate<ranking_pair<sparse_vect>>));

    // clear() is bound through a lambda instead of &container::clear.
    // ranking_pairs and sparse_ranking_pairs are presumably std::vector
    // typedefs (see opaque_types.h). Forming a pointer-to-member to a
    // standard-library member function is unspecified (possibly ill-formed),
    // so the lambda keeps the binding portable.
    py::bind_vector<ranking_pairs>(m, "ranking_pairs")
        .def("clear", [](ranking_pairs& v) { v.clear(); })
        .def("resize", resize<ranking_pairs>)
        .def("extend", extend_vector_with_python_list<ranking_pair<sample_type>>)
        .def(py::pickle(&getstate<ranking_pairs>, &setstate<ranking_pairs>));

    py::bind_vector<sparse_ranking_pairs>(m, "sparse_ranking_pairs")
        .def("clear", [](sparse_ranking_pairs& v) { v.clear(); })
        .def("resize", resize<sparse_ranking_pairs>)
        .def("extend", extend_vector_with_python_list<ranking_pair<sparse_vect>>)
        .def(py::pickle(&getstate<sparse_ranking_pairs>, &setstate<sparse_ranking_pairs>));

    add_ranker<dense_ranker_type>(m, "svm_rank_trainer");
    add_ranker<sparse_ranker_type>(m, "svm_rank_trainer_sparse");

    m.def("cross_validate_ranking_trainer",
          &_cross_ranking_validate_trainer<dense_ranker_type, sample_type>,
          py::arg("trainer"), py::arg("samples"), py::arg("folds"));
    m.def("cross_validate_ranking_trainer",
          &_cross_ranking_validate_trainer<sparse_ranker_type, sparse_vect>,
          py::arg("trainer"), py::arg("samples"), py::arg("folds"));
}