// Copyright (C) 2012  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
namespace dlib
{

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type
        >
    class graph_labeler
    {
        /*!
            REQUIREMENTS ON vector_type
                - vector_type is a dlib::matrix capable of representing column
                  vectors or it is a sparse vector type as defined in
                  dlib/svm/sparse_vector_abstract.h.

            WHAT THIS OBJECT REPRESENTS
                This object is a tool for labeling each node in a graph with a value
                of true or false, subject to a labeling consistency constraint between
                nodes that share an edge.  In particular, this object is useful for
                representing a graph labeling model learned via some machine learning
                method.

                To elaborate, suppose we have a graph we want to label.  Moreover,
                suppose we can assign a score to each node which represents how much
                we want to label the node as true, and we also have scores for each
                edge which represent how much we want the nodes sharing the edge to
                have the same label.  Given these scores, we can find the optimal
                labeling using the find_max_factor_graph_potts() routine.  Therefore,
                the graph_labeler is just an object which contains the necessary data
                to compute these score functions and then call
                find_max_factor_graph_potts().  Additionally, this object uses linear
                functions to represent these score functions.

            THREAD SAFETY
                It is always safe to use distinct instances of this object in different
                threads.  However, when a single instance is shared between threads then
                the following rules apply:
                    It is safe to call the const members of this object from multiple
                    threads.  This is because the const members are purely read-only
                    operations.  However, any operation that modifies a graph_labeler is
                    not threadsafe.
        !*/

    public:

        typedef std::vector<bool> label_type;
        typedef label_type result_type;

        graph_labeler(
        );
        /*!
            ensures
                - this object is properly initialized
                - #get_node_weights() == an initial value of type vector_type.
                - #get_edge_weights() == an initial value of type vector_type.
        !*/

        graph_labeler(
            const vector_type& edge_weights,
            const vector_type& node_weights
        );
        /*!
            requires
                - min(edge_weights) >= 0
                  (Combined with the requirement min(edge(sample,i,j)) >= 0 placed on
                  operator(), this guarantees that every edge score
                  dot(get_edge_weights(), edge(sample,i,j)) is non-negative, which is
                  what find_max_factor_graph_potts() needs.)
            ensures
                - #get_edge_weights() == edge_weights
                - #get_node_weights() == node_weights
                - Note that the edge weights are the FIRST argument and the node
                  weights the SECOND.
        !*/

        const vector_type& get_edge_weights (
        ) const;
        /*!
            ensures
                - Recall that the score function for an edge is a linear function of
                  the vector stored at that edge.  This means there is some vector, E,
                  which we dot product with the vector in the graph to compute the
                  score.  Therefore, this function returns that E vector which defines
                  the edge score function.
        !*/

        const vector_type& get_node_weights (
        ) const;
        /*!
            ensures
                - Recall that the score function for a node is a linear function of
                  the vector stored in that node.  This means there is some vector, W,
                  which we dot product with the vector in the graph to compute the
                  score.  Therefore, this function returns that W vector which defines
                  the node score function.
        !*/

        template <typename graph_type>
        void operator() (
            const graph_type& sample,
            std::vector<bool>& labels
        ) const;
        /*!
            requires
                - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
                - graph_type::type and graph_type::edge_type must be either matrix objects
                  capable of representing column vectors or some kind of sparse vector
                  type as defined in dlib/svm/sparse_vector_abstract.h.
                - graph_contains_length_one_cycle(sample) == false
                - for all valid i:
                    - it must be legal to call dot(sample.node(i).data, get_node_weights())
                - for all valid i and j:
                    - min(edge(sample,i,j)) >= 0
                    - it must be legal to call dot(edge(sample,i,j), get_edge_weights())
            ensures
                - Computes a labeling for each node in the given graph and stores the
                  result in #labels.
                - #labels.size() == sample.number_of_nodes()
                - for all valid i:
                    - #labels[i] == the label of the node sample.node(i).
                - The labels are computed by creating a graph, G, with scalar values on
                  each node and edge.  The scalar values are calculated according to the
                  following:
                    - for all valid i:
                        - G.node(i).data == dot(get_node_weights(), sample.node(i).data)
                    - for all valid i and j:
                        - edge(G,i,j) == dot(get_edge_weights(), edge(sample,i,j))
                  Then the labels are computed by calling
                  find_max_factor_graph_potts(G,#labels).
        !*/

        template <typename graph_type>
        std::vector<bool> operator() (
            const graph_type& sample
        ) const;
        /*!
            requires
                - graph_type is an implementation of dlib/graph/graph_kernel_abstract.h
                - graph_type::type and graph_type::edge_type must be either matrix objects
                  capable of representing column vectors or some kind of sparse vector
                  type as defined in dlib/svm/sparse_vector_abstract.h.
                - graph_contains_length_one_cycle(sample) == false
                - for all valid i:
                    - it must be legal to call dot(sample.node(i).data, get_node_weights())
                - for all valid i and j:
                    - min(edge(sample,i,j)) >= 0
                    - it must be legal to call dot(edge(sample,i,j), get_edge_weights())
                  (These are exactly the requirements of the two-argument operator()
                  above, since this overload simply forwards to it.)
            ensures
                - Performs (*this)(sample, labels); return labels;
                  (i.e. This is just another version of the above operator() routine
                  but instead of returning the labels via the second argument, it
                  returns them as the normal return value).
        !*/
    };

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type
        >
    void serialize (
        const graph_labeler<vector_type>& item,
        std::ostream& out
    );
    /*!
        provides serialization support (writes item to the stream out)
    !*/

// ----------------------------------------------------------------------------------------

    template <
        typename vector_type
        >
    void deserialize (
        graph_labeler<vector_type>& item,
        std::istream& in
    );
    /*!
        provides deserialization support (reads item back from the stream in;
        presumably the inverse of serialize() -- confirm against the implementation)
    !*/

// ----------------------------------------------------------------------------------------

}