// AshanGimhana's picture
// Upload folder using huggingface_hub
// 9375c9a verified
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_GRAPH_LaBELER_Hh_
#define DLIB_GRAPH_LaBELER_Hh_
#include "graph_labeler_abstract.h"
#include "../matrix.h"
#include "../string.h"
#include <vector>
#include "find_max_factor_graph_potts.h"
#include "../svm/sparse_vector.h"
#include "../graph.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
class graph_labeler
{
public:
typedef std::vector<bool> label_type;
typedef label_type result_type;
graph_labeler()
{
}
graph_labeler(
const vector_type& edge_weights_,
const vector_type& node_weights_
) :
edge_weights(edge_weights_),
node_weights(node_weights_)
{
// make sure requires clause is not broken
DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0,
"\t graph_labeler::graph_labeler()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t min(edge_weights): " << min(edge_weights)
<< "\n\t this: " << this
);
}
const vector_type& get_edge_weights (
) const { return edge_weights; }
const vector_type& get_node_weights (
) const { return node_weights; }
template <typename graph_type>
void operator() (
const graph_type& sample,
std::vector<bool>& labels
) const
{
// make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false,
"\t void graph_labeler::operator()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size()
<< "\n\t get_node_weights().size(): " << get_node_weights().size()
<< "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample)
<< "\n\t this: " << this
);
for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
{
if (is_matrix<vector_type>::value &&
is_matrix<typename graph_type::type>::value)
{
// check that dot() is legal.
DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(),
"\t void graph_labeler::operator()"
<< "\n\t The size of the node weight vector must match the one in the node."
<< "\n\t get_node_weights().size(): " << get_node_weights().size()
<< "\n\t sample.node(i).data.size(): " << sample.node(i).data.size()
<< "\n\t i: " << i
<< "\n\t this: " << this
);
}
for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
{
if (is_matrix<vector_type>::value &&
is_matrix<typename graph_type::edge_type>::value)
{
// check that dot() is legal.
DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(),
"\t void graph_labeler::operator()"
<< "\n\t The size of the edge weight vector must match the one in graph's edge."
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size()
<< "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size()
<< "\n\t i: " << i
<< "\n\t this: " << this
);
}
DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0,
"\t void graph_labeler::operator()"
<< "\n\t No edge vectors are allowed to have negative elements."
<< "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n))
<< "\n\t i: " << i
<< "\n\t n: " << n
<< "\n\t this: " << this
);
}
}
#endif
graph<double,double>::kernel_1a g;
copy_graph_structure(sample, g);
for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
{
g.node(i).data = dot(node_weights, sample.node(i).data);
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
{
const unsigned long j = g.node(i).neighbor(n).index();
// Don't compute an edge weight more than once.
if (i < j)
{
g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n));
}
}
}
labels.clear();
std::vector<node_label> temp;
find_max_factor_graph_potts(g, temp);
for (unsigned long i = 0; i < temp.size(); ++i)
{
if (temp[i] != 0)
labels.push_back(true);
else
labels.push_back(false);
}
}
template <typename graph_type>
std::vector<bool> operator() (
const graph_type& sample
) const
{
std::vector<bool> temp;
(*this)(sample, temp);
return temp;
}
private:
vector_type edge_weights;
vector_type node_weights;
};
// ----------------------------------------------------------------------------------------
// Writes a graph_labeler to out: a format version number (1), then the
// edge weights, then the node weights.
template <typename vector_type>
void serialize (const graph_labeler<vector_type>& item, std::ostream& out)
{
    int version = 1;
    serialize(version, out);

    const vector_type& stored_edge_weights = item.get_edge_weights();
    const vector_type& stored_node_weights = item.get_node_weights();
    serialize(stored_edge_weights, out);
    serialize(stored_node_weights, out);
}
// ----------------------------------------------------------------------------------------
// Reads a graph_labeler from in, in the order written by serialize():
// version number, edge weights, node weights.
// Throws dlib::serialization_error when the stored version number is not 1.
template <typename vector_type>
void deserialize (graph_labeler<vector_type>& item, std::istream& in)
{
    int version = 0;
    deserialize(version, in);

    const bool version_supported = (version == 1);
    if (!version_supported)
    {
        throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " +
                                        cast_to_string(version) + ".");
    }

    vector_type loaded_edge_weights;
    vector_type loaded_node_weights;
    deserialize(loaded_edge_weights, in);
    deserialize(loaded_node_weights, in);

    // Rebuild through the two-argument constructor and assign the result to item.
    item = graph_labeler<vector_type>(loaded_edge_weights, loaded_node_weights);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_GRAPH_LaBELER_Hh_