File size: 7,672 Bytes
9375c9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
// Copyright (C) 2012 Davis E. King ([email protected])
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_GRAPH_LaBELER_Hh_
#define DLIB_GRAPH_LaBELER_Hh_
#include "graph_labeler_abstract.h"
#include "../matrix.h"
#include "../string.h"
#include <vector>
#include "find_max_factor_graph_potts.h"
#include "../svm/sparse_vector.h"
#include "../graph.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
typename vector_type
>
class graph_labeler
{
public:
typedef std::vector<bool> label_type;
typedef label_type result_type;
graph_labeler()
{
}
graph_labeler(
const vector_type& edge_weights_,
const vector_type& node_weights_
) :
edge_weights(edge_weights_),
node_weights(node_weights_)
{
// make sure requires clause is not broken
DLIB_ASSERT(edge_weights.size() == 0 || min(edge_weights) >= 0,
"\t graph_labeler::graph_labeler()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t min(edge_weights): " << min(edge_weights)
<< "\n\t this: " << this
);
}
const vector_type& get_edge_weights (
) const { return edge_weights; }
const vector_type& get_node_weights (
) const { return node_weights; }
template <typename graph_type>
void operator() (
const graph_type& sample,
std::vector<bool>& labels
) const
{
// make sure requires clause is not broken
#ifdef ENABLE_ASSERTS
DLIB_ASSERT(graph_contains_length_one_cycle(sample) == false,
"\t void graph_labeler::operator()"
<< "\n\t Invalid inputs were given to this function."
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size()
<< "\n\t get_node_weights().size(): " << get_node_weights().size()
<< "\n\t graph_contains_length_one_cycle(sample): " << graph_contains_length_one_cycle(sample)
<< "\n\t this: " << this
);
for (unsigned long i = 0; i < sample.number_of_nodes(); ++i)
{
if (is_matrix<vector_type>::value &&
is_matrix<typename graph_type::type>::value)
{
// check that dot() is legal.
DLIB_ASSERT((unsigned long)get_node_weights().size() == (unsigned long)sample.node(i).data.size(),
"\t void graph_labeler::operator()"
<< "\n\t The size of the node weight vector must match the one in the node."
<< "\n\t get_node_weights().size(): " << get_node_weights().size()
<< "\n\t sample.node(i).data.size(): " << sample.node(i).data.size()
<< "\n\t i: " << i
<< "\n\t this: " << this
);
}
for (unsigned long n = 0; n < sample.node(i).number_of_neighbors(); ++n)
{
if (is_matrix<vector_type>::value &&
is_matrix<typename graph_type::edge_type>::value)
{
// check that dot() is legal.
DLIB_ASSERT((unsigned long)get_edge_weights().size() == (unsigned long)sample.node(i).edge(n).size(),
"\t void graph_labeler::operator()"
<< "\n\t The size of the edge weight vector must match the one in graph's edge."
<< "\n\t get_edge_weights().size(): " << get_edge_weights().size()
<< "\n\t sample.node(i).edge(n).size(): " << sample.node(i).edge(n).size()
<< "\n\t i: " << i
<< "\n\t this: " << this
);
}
DLIB_ASSERT(sample.node(i).edge(n).size() == 0 || min(sample.node(i).edge(n)) >= 0,
"\t void graph_labeler::operator()"
<< "\n\t No edge vectors are allowed to have negative elements."
<< "\n\t min(sample.node(i).edge(n)): " << min(sample.node(i).edge(n))
<< "\n\t i: " << i
<< "\n\t n: " << n
<< "\n\t this: " << this
);
}
}
#endif
graph<double,double>::kernel_1a g;
copy_graph_structure(sample, g);
for (unsigned long i = 0; i < g.number_of_nodes(); ++i)
{
g.node(i).data = dot(node_weights, sample.node(i).data);
for (unsigned long n = 0; n < g.node(i).number_of_neighbors(); ++n)
{
const unsigned long j = g.node(i).neighbor(n).index();
// Don't compute an edge weight more than once.
if (i < j)
{
g.node(i).edge(n) = dot(edge_weights, sample.node(i).edge(n));
}
}
}
labels.clear();
std::vector<node_label> temp;
find_max_factor_graph_potts(g, temp);
for (unsigned long i = 0; i < temp.size(); ++i)
{
if (temp[i] != 0)
labels.push_back(true);
else
labels.push_back(false);
}
}
template <typename graph_type>
std::vector<bool> operator() (
const graph_type& sample
) const
{
std::vector<bool> temp;
(*this)(sample, temp);
return temp;
}
private:
vector_type edge_weights;
vector_type node_weights;
};
// ----------------------------------------------------------------------------------------
template <
    typename vector_type
    >
void serialize (
    const graph_labeler<vector_type>& item,
    std::ostream& out
)
{
    // Stream layout: an int format tag (currently 1), followed by the edge
    // weights and then the node weights.  deserialize() reads the same order.
    const int format_version = 1;
    serialize(format_version, out);

    const vector_type& edge_part = item.get_edge_weights();
    serialize(edge_part, out);

    const vector_type& node_part = item.get_node_weights();
    serialize(node_part, out);
}
// ----------------------------------------------------------------------------------------
template <
    typename vector_type
    >
void deserialize (
    graph_labeler<vector_type>& item,
    std::istream& in
)
{
    // Read the int format tag; only version 1 is understood.
    int version = 0;
    deserialize(version, in);
    if (version == 1)
    {
        vector_type edge_weights;
        vector_type node_weights;
        deserialize(edge_weights, in);
        deserialize(node_weights, in);
        // Go through the public constructor (so its requires-clause assertion
        // applies when asserts are enabled) and only overwrite item once both
        // vectors have been read without throwing.
        item = graph_labeler<vector_type>(edge_weights, node_weights);
        return;
    }

    throw dlib::serialization_error("While deserializing graph_labeler, found unexpected version number of " +
                                    cast_to_string(version) + ".");
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_GRAPH_LaBELER_Hh_
|