|
|
|
|
|
#ifndef DLIB_GRAPH_LaBELER_Hh_ |
|
#define DLIB_GRAPH_LaBELER_Hh_ |
|
|
|
#include "graph_labeler_abstract.h" |
|
#include "../matrix.h" |
|
#include "../string.h" |
|
#include <vector> |
|
#include "find_max_factor_graph_potts.h" |
|
#include "../svm/sparse_vector.h" |
|
#include "../graph.h" |
|
|
|
namespace dlib |
|
{ |
|
|
|
|
|
|
|
template < |
|
typename vector_type |
|
> |
|
class graph_labeler |
|
{ |
|
|
|
public: |
|
|
|
typedef std::vector<bool> label_type; |
|
typedef label_type result_type; |
|
|
|
graph_labeler() |
|
{ |
|
} |
|
|
|
graph_labeler( |
|
const vector_type& edge_weights_, |
|
const vector_type& node_weights_ |
|
) : |
|
edge_weights(edge_weights_), |
|
node_weights(node_weights_) |
|
{ |
|
|
|
DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0, |
|
"\t graph_labeler::graph_labeler()" |
|
<< "\n\t Invalid inputs were given to this function." |
|
<< "\n\t min(edge_weights): " << min(edge_weights) |
|
<< "\n\t this: " << this |
|
); |
|
} |
|
|
|
const vector_type& get_edge_weights ( |
|
) const { return edge_weights; } |
|
|
|
const vector_type& get_node_weights ( |
|
) const { return node_weights; } |
|
|
|
template <typename graph_type> |
|
void operator() ( |
|
const graph_type& sample, |
|
std::vector<bool>& labels |
|
) const |
|
{ |
|
|
|
#ifdef ENABLE_ASSERTS |
|
DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false, |
|
"\t void graph_labeler::operator()" |
|
<< "\n\t Invalid inputs were given to this function." |
|
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size() |
|
<< "\n\t get_node_weights().size(): " << get_node_weights().size() |
|
<< "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample) |
|
<< "\n\t this: " << this |
|
); |
|
for (unsigned long i = 0; i < sample.number_of_nodes(); ++i) |
|
{ |
|
if (is_matrix<vector_type>::value && |
|
is_matrix<typename graph_type::type>::value) |
|
{ |
|
|
|
DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(), |
|
"\t void graph_labeler::operator()" |
|
<< "\n\t The size of the node weight vector must match the one in the node." |
|
<< "\n\t get_node_weights().size(): " << get_node_weights().size() |
|
<< "\n\t sample.node(i).data.size(): " << sample.node(i).data.size() |
|
<< "\n\t i: " << i |
|
<< "\n\t this: " << this |
|
); |
|
} |
|
|
|
for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n) |
|
{ |
|
if (is_matrix<vector_type>::value && |
|
is_matrix<typename graph_type::edge_type>::value) |
|
{ |
|
|
|
DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(), |
|
"\t void graph_labeler::operator()" |
|
<< "\n\t The size of the edge weight vector must match the one in graph's edge." |
|
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size() |
|
<< "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size() |
|
<< "\n\t i: " << i |
|
<< "\n\t this: " << this |
|
); |
|
} |
|
|
|
DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0, |
|
"\t void graph_labeler::operator()" |
|
<< "\n\t No edge vectors are allowed to have negative elements." |
|
<< "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n)) |
|
<< "\n\t i: " << i |
|
<< "\n\t n: " << n |
|
<< "\n\t this: " << this |
|
); |
|
} |
|
} |
|
#endif |
|
|
|
|
|
graph<double,double>::kernel_1a g; |
|
copy_graph_structure(sample, g); |
|
for (unsigned long i = 0; i < g.number_of_nodes(); ++i) |
|
{ |
|
g.node(i).data = dot(node_weights, sample.node(i).data); |
|
|
|
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n) |
|
{ |
|
const unsigned long j = g.node(i).neighbor(n).index(); |
|
|
|
if (i < j) |
|
{ |
|
g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n)); |
|
} |
|
} |
|
|
|
} |
|
|
|
labels.clear(); |
|
std::vector<node_label> temp; |
|
find_max_factor_graph_potts(g, temp); |
|
for (unsigned long i = 0; i < temp.size(); ++i) |
|
{ |
|
if (temp[i] != 0) |
|
labels.push_back(true); |
|
else |
|
labels.push_back(false); |
|
} |
|
} |
|
|
|
template <typename graph_type> |
|
std::vector<bool> operator() ( |
|
const graph_type& sample |
|
) const |
|
{ |
|
std::vector<bool> temp; |
|
(*this)(sample, temp); |
|
return temp; |
|
} |
|
|
|
private: |
|
|
|
vector_type edge_weights; |
|
vector_type node_weights; |
|
}; |
|
|
|
|
|
|
|
|
|
template < |
|
typename vector_type |
|
> |
|
void serialize ( |
|
const graph_labeler<vector_type>& item, |
|
std::ostream& out |
|
) |
|
{ |
|
int version = 1; |
|
serialize(version, out); |
|
serialize(item.get_edge_weights(), out); |
|
serialize(item.get_node_weights(), out); |
|
} |
|
|
|
|
|
|
|
template < |
|
typename vector_type |
|
> |
|
void deserialize ( |
|
graph_labeler<vector_type>& item, |
|
std::istream& in |
|
) |
|
{ |
|
int version = 0; |
|
deserialize(version, in); |
|
if (version != 1) |
|
{ |
|
throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " + |
|
cast_to_string(version) + "."); |
|
} |
|
|
|
vector_type edge_weights, node_weights; |
|
deserialize(edge_weights, in); |
|
deserialize(node_weights, in); |
|
|
|
item = graph_labeler<vector_type>(edge_weights, node_weights); |
|
} |
|
|
|
|
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|