|
|
|
|
|
#ifndef DLIB_BAYES_UTILs_ |
|
#define DLIB_BAYES_UTILs_ |
|
|
|
#include "bayes_utils_abstract.h" |
|
|
|
#include <algorithm> |
|
#include <ctime> |
|
#include <memory> |
|
#include <vector> |
|
|
|
#include "../string.h" |
|
#include "../map.h" |
|
#include "../matrix.h" |
|
#include "../rand.h" |
|
#include "../array.h" |
|
#include "../set.h" |
|
#include "../algs.h" |
|
#include "../noncopyable.h" |
|
#include "../graph.h" |
|
|
|
namespace dlib |
|
{ |
|
|
|
|
|
|
|
class assignment |
|
{ |
|
public: |
|
|
|
assignment() |
|
{ |
|
} |
|
|
|
assignment( |
|
const assignment& a |
|
) |
|
{ |
|
a.reset(); |
|
while (a.move_next()) |
|
{ |
|
unsigned long idx = a.element().key(); |
|
unsigned long value = a.element().value(); |
|
vals.add(idx,value); |
|
} |
|
} |
|
|
|
assignment& operator = ( |
|
const assignment& rhs |
|
) |
|
{ |
|
if (this == &rhs) |
|
return *this; |
|
|
|
assignment(rhs).swap(*this); |
|
return *this; |
|
} |
|
|
|
void clear() |
|
{ |
|
vals.clear(); |
|
} |
|
|
|
bool operator < ( |
|
const assignment& item |
|
) const |
|
{ |
|
if (size() < item.size()) |
|
return true; |
|
else if (size() > item.size()) |
|
return false; |
|
|
|
reset(); |
|
item.reset(); |
|
while (move_next()) |
|
{ |
|
item.move_next(); |
|
if (element().key() < item.element().key()) |
|
return true; |
|
else if (element().key() > item.element().key()) |
|
return false; |
|
else if (element().value() < item.element().value()) |
|
return true; |
|
else if (element().value() > item.element().value()) |
|
return false; |
|
} |
|
|
|
return false; |
|
} |
|
|
|
bool has_index ( |
|
unsigned long idx |
|
) const |
|
{ |
|
return vals.is_in_domain(idx); |
|
} |
|
|
|
void add ( |
|
unsigned long idx, |
|
unsigned long value = 0 |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( has_index(idx) == false , |
|
"\tvoid assignment::add(idx)" |
|
<< "\n\tYou can't add the same index to an assignment object more than once" |
|
<< "\n\tidx: " << idx |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
vals.add(idx, value); |
|
} |
|
|
|
unsigned long& operator[] ( |
|
const long idx |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( has_index(idx) == true , |
|
"\tunsigned long assignment::operator[](idx)" |
|
<< "\n\tYou can't access an index value if it isn't already in the object" |
|
<< "\n\tidx: " << idx |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return vals[idx]; |
|
} |
|
|
|
const unsigned long& operator[] ( |
|
const long idx |
|
) const |
|
{ |
|
|
|
DLIB_ASSERT( has_index(idx) == true , |
|
"\tunsigned long assignment::operator[](idx)" |
|
<< "\n\tYou can't access an index value if it isn't already in the object" |
|
<< "\n\tidx: " << idx |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return vals[idx]; |
|
} |
|
|
|
void swap ( |
|
assignment& item |
|
) |
|
{ |
|
vals.swap(item.vals); |
|
} |
|
|
|
void remove ( |
|
unsigned long idx |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( has_index(idx) == true , |
|
"\tunsigned long assignment::remove(idx)" |
|
<< "\n\tYou can't remove an index value if it isn't already in the object" |
|
<< "\n\tidx: " << idx |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
vals.destroy(idx); |
|
} |
|
|
|
unsigned long size() const { return vals.size(); } |
|
|
|
void reset() const { vals.reset(); } |
|
|
|
bool move_next() const { return vals.move_next(); } |
|
|
|
map_pair<unsigned long, unsigned long>& element() |
|
{ |
|
|
|
DLIB_ASSERT(current_element_valid() == true, |
|
"\tmap_pair<unsigned long,unsigned long>& assignment::element()" |
|
<< "\n\tyou can't access the current element if it doesn't exist" |
|
<< "\n\tthis: " << this |
|
); |
|
return vals.element(); |
|
} |
|
|
|
const map_pair<unsigned long, unsigned long>& element() const |
|
{ |
|
|
|
DLIB_ASSERT(current_element_valid() == true, |
|
"\tconst map_pair<unsigned long,unsigned long>& assignment::element() const" |
|
<< "\n\tyou can't access the current element if it doesn't exist" |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return vals.element(); |
|
} |
|
|
|
bool at_start() const { return vals.at_start(); } |
|
|
|
bool current_element_valid() const { return vals.current_element_valid(); } |
|
|
|
friend inline void serialize ( |
|
const assignment& item, |
|
std::ostream& out |
|
) |
|
{ |
|
serialize(item.vals, out); |
|
} |
|
|
|
friend inline void deserialize ( |
|
assignment& item, |
|
std::istream& in |
|
) |
|
{ |
|
deserialize(item.vals, in); |
|
} |
|
|
|
private: |
|
mutable dlib::map<unsigned long, unsigned long>::kernel_1b_c vals; |
|
}; |
|
|
|
inline std::ostream& operator << ( |
|
std::ostream& out, |
|
const assignment& a |
|
) |
|
{ |
|
a.reset(); |
|
out << "("; |
|
if (a.move_next()) |
|
out << a.element().key() << ":" << a.element().value(); |
|
|
|
while (a.move_next()) |
|
{ |
|
out << ", " << a.element().key() << ":" << a.element().value(); |
|
} |
|
|
|
out << ")"; |
|
return out; |
|
} |
|
|
|
|
|
inline void swap ( |
|
assignment& a, |
|
assignment& b |
|
) |
|
{ |
|
a.swap(b); |
|
} |
|
|
|
|
|
|
|
|
|
class joint_probability_table |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public: |
|
|
|
joint_probability_table ( |
|
const joint_probability_table& t |
|
) |
|
{ |
|
t.reset(); |
|
while (t.move_next()) |
|
{ |
|
assignment a = t.element().key(); |
|
double p = t.element().value(); |
|
set_probability(a,p); |
|
} |
|
} |
|
|
|
joint_probability_table() {} |
|
|
|
joint_probability_table& operator= ( |
|
const joint_probability_table& rhs |
|
) |
|
{ |
|
if (this == &rhs) |
|
return *this; |
|
joint_probability_table(rhs).swap(*this); |
|
return *this; |
|
} |
|
|
|
void set_probability ( |
|
const assignment& a, |
|
double p |
|
) |
|
{ |
|
|
|
DLIB_ASSERT(0.0 <= p && p <= 1.0, |
|
"\tvoid& joint_probability_table::set_probability(a,p)" |
|
<< "\n\tyou have given an invalid probability value" |
|
<< "\n\tp: " << p |
|
<< "\n\ta: " << a |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
if (table.is_in_domain(a)) |
|
{ |
|
table[a] = p; |
|
} |
|
else |
|
{ |
|
assignment temp(a); |
|
table.add(temp,p); |
|
} |
|
} |
|
|
|
bool has_entry_for ( |
|
const assignment& a |
|
) const |
|
{ |
|
return table.is_in_domain(a); |
|
} |
|
|
|
void add_probability ( |
|
const assignment& a, |
|
double p |
|
) |
|
{ |
|
|
|
DLIB_ASSERT(0.0 <= p && p <= 1.0, |
|
"\tvoid& joint_probability_table::add_probability(a,p)" |
|
<< "\n\tyou have given an invalid probability value" |
|
<< "\n\tp: " << p |
|
<< "\n\ta: " << a |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
if (table.is_in_domain(a)) |
|
{ |
|
table[a] += p; |
|
if (table[a] > 1.0) |
|
table[a] = 1.0; |
|
} |
|
else |
|
{ |
|
assignment temp(a); |
|
table.add(temp,p); |
|
} |
|
} |
|
|
|
double probability ( |
|
const assignment& a |
|
) const |
|
{ |
|
return table[a]; |
|
} |
|
|
|
void clear() |
|
{ |
|
table.clear(); |
|
} |
|
|
|
size_t size () const { return table.size(); } |
|
bool move_next() const { return table.move_next(); } |
|
void reset() const { table.reset(); } |
|
map_pair<assignment,double>& element() |
|
{ |
|
|
|
DLIB_ASSERT(current_element_valid() == true, |
|
"\tmap_pair<assignment,double>& joint_probability_table::element()" |
|
<< "\n\tyou can't access the current element if it doesn't exist" |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return table.element(); |
|
} |
|
|
|
const map_pair<assignment,double>& element() const |
|
{ |
|
|
|
DLIB_ASSERT(current_element_valid() == true, |
|
"\tconst map_pair<assignment,double>& joint_probability_table::element() const" |
|
<< "\n\tyou can't access the current element if it doesn't exist" |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return table.element(); |
|
} |
|
|
|
bool at_start() const { return table.at_start(); } |
|
|
|
bool current_element_valid() const { return table.current_element_valid(); } |
|
|
|
|
|
template <typename T> |
|
void marginalize ( |
|
const T& vars, |
|
joint_probability_table& out |
|
) const |
|
{ |
|
out.clear(); |
|
double p; |
|
reset(); |
|
while (move_next()) |
|
{ |
|
assignment a; |
|
const assignment& asrc = element().key(); |
|
p = element().value(); |
|
|
|
asrc.reset(); |
|
while (asrc.move_next()) |
|
{ |
|
if (vars.is_member(asrc.element().key())) |
|
a.add(asrc.element().key(), asrc.element().value()); |
|
} |
|
|
|
out.add_probability(a,p); |
|
} |
|
} |
|
|
|
void marginalize ( |
|
const unsigned long var, |
|
joint_probability_table& out |
|
) const |
|
{ |
|
out.clear(); |
|
double p; |
|
reset(); |
|
while (move_next()) |
|
{ |
|
assignment a; |
|
const assignment& asrc = element().key(); |
|
p = element().value(); |
|
|
|
asrc.reset(); |
|
while (asrc.move_next()) |
|
{ |
|
if (var == asrc.element().key()) |
|
a.add(asrc.element().key(), asrc.element().value()); |
|
} |
|
|
|
out.add_probability(a,p); |
|
} |
|
} |
|
|
|
void normalize ( |
|
) |
|
{ |
|
double sum = 0; |
|
|
|
reset(); |
|
while (move_next()) |
|
sum += element().value(); |
|
|
|
reset(); |
|
while (move_next()) |
|
element().value() /= sum; |
|
} |
|
|
|
void swap ( |
|
joint_probability_table& item |
|
) |
|
{ |
|
table.swap(item.table); |
|
} |
|
|
|
friend inline void serialize ( |
|
const joint_probability_table& item, |
|
std::ostream& out |
|
) |
|
{ |
|
serialize(item.table, out); |
|
} |
|
|
|
friend inline void deserialize ( |
|
joint_probability_table& item, |
|
std::istream& in |
|
) |
|
{ |
|
deserialize(item.table, in); |
|
} |
|
|
|
private: |
|
|
|
dlib::map<assignment, double >::kernel_1b_c table; |
|
}; |
|
|
|
inline void swap ( |
|
joint_probability_table& a, |
|
joint_probability_table& b |
|
) { a.swap(b); } |
|
|
|
|
|
|
|
class conditional_probability_table : noncopyable |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
public: |
|
|
|
conditional_probability_table() |
|
{ |
|
clear(); |
|
} |
|
|
|
void set_num_values ( |
|
unsigned long num |
|
) |
|
{ |
|
num_vals = num; |
|
table.clear(); |
|
} |
|
|
|
bool has_entry_for ( |
|
unsigned long value, |
|
const assignment& ps |
|
) const |
|
{ |
|
if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) |
|
return true; |
|
else |
|
return false; |
|
} |
|
|
|
unsigned long num_values ( |
|
) const { return num_vals; } |
|
|
|
void set_probability ( |
|
unsigned long value, |
|
const assignment& ps, |
|
double p |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 , |
|
"\tvoid conditional_probability_table::set_probability()" |
|
<< "\n\tinvalid arguments to set_probability" |
|
<< "\n\tvalue: " << value |
|
<< "\n\tnum_values(): " << num_values() |
|
<< "\n\tp: " << p |
|
<< "\n\tps: " << ps |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
if (table.is_in_domain(ps)) |
|
{ |
|
table[ps](value) = p; |
|
} |
|
else |
|
{ |
|
matrix<double,1> dist(num_vals); |
|
set_all_elements(dist,-1); |
|
dist(value) = p; |
|
assignment temp(ps); |
|
table.add(temp,dist); |
|
} |
|
} |
|
|
|
double probability( |
|
unsigned long value, |
|
const assignment& ps |
|
) const |
|
{ |
|
|
|
DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) , |
|
"\tvoid conditional_probability_table::probability()" |
|
<< "\n\tinvalid arguments to probability" |
|
<< "\n\tvalue: " << value |
|
<< "\n\tnum_values(): " << num_values() |
|
<< "\n\tps: " << ps |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return table[ps](value); |
|
} |
|
|
|
void clear() |
|
{ |
|
table.clear(); |
|
num_vals = 0; |
|
} |
|
|
|
void empty_table () |
|
{ |
|
table.clear(); |
|
} |
|
|
|
void swap ( |
|
conditional_probability_table& item |
|
) |
|
{ |
|
exchange(num_vals, item.num_vals); |
|
table.swap(item.table); |
|
} |
|
|
|
friend inline void serialize ( |
|
const conditional_probability_table& item, |
|
std::ostream& out |
|
) |
|
{ |
|
serialize(item.table, out); |
|
serialize(item.num_vals, out); |
|
} |
|
|
|
friend inline void deserialize ( |
|
conditional_probability_table& item, |
|
std::istream& in |
|
) |
|
{ |
|
deserialize(item.table, in); |
|
deserialize(item.num_vals, in); |
|
} |
|
|
|
private: |
|
dlib::map<assignment, matrix<double,1> >::kernel_1b_c table; |
|
unsigned long num_vals; |
|
}; |
|
|
|
inline void swap ( |
|
conditional_probability_table& a, |
|
conditional_probability_table& b |
|
) { a.swap(b); } |
|
|
|
|
|
|
|
class bayes_node : noncopyable |
|
{ |
|
public: |
|
bayes_node () |
|
{ |
|
is_instantiated = false; |
|
value_ = 0; |
|
} |
|
|
|
unsigned long value ( |
|
) const { return value_;} |
|
|
|
void set_value ( |
|
unsigned long new_value |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( new_value < table().num_values(), |
|
"\tvoid bayes_node::set_value(new_value)" |
|
<< "\n\tnew_value must be less than the number of possible values for this node" |
|
<< "\n\tnew_value: " << new_value |
|
<< "\n\ttable().num_values(): " << table().num_values() |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
value_ = new_value; |
|
} |
|
|
|
conditional_probability_table& table ( |
|
) { return table_; } |
|
|
|
const conditional_probability_table& table ( |
|
) const { return table_; } |
|
|
|
bool is_evidence ( |
|
) const { return is_instantiated; } |
|
|
|
void set_as_nonevidence ( |
|
) { is_instantiated = false; } |
|
|
|
void set_as_evidence ( |
|
) { is_instantiated = true; } |
|
|
|
void swap ( |
|
bayes_node& item |
|
) |
|
{ |
|
exchange(value_, item.value_); |
|
exchange(is_instantiated, item.is_instantiated); |
|
table_.swap(item.table_); |
|
} |
|
|
|
friend inline void serialize ( |
|
const bayes_node& item, |
|
std::ostream& out |
|
) |
|
{ |
|
serialize(item.value_, out); |
|
serialize(item.is_instantiated, out); |
|
serialize(item.table_, out); |
|
} |
|
|
|
friend inline void deserialize ( |
|
bayes_node& item, |
|
std::istream& in |
|
) |
|
{ |
|
deserialize(item.value_, in); |
|
deserialize(item.is_instantiated, in); |
|
deserialize(item.table_, in); |
|
} |
|
|
|
private: |
|
|
|
unsigned long value_; |
|
bool is_instantiated; |
|
conditional_probability_table table_; |
|
}; |
|
|
|
inline void swap ( |
|
bayes_node& a, |
|
bayes_node& b |
|
) { a.swap(b); } |
|
|
|
|
|
|
|
namespace bayes_node_utils |
|
{ |
|
|
|
template <typename T> |
|
unsigned long node_num_values ( |
|
const T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tvoid bayes_node_utils::node_num_values(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
return bn.node(n).data.table().num_values(); |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
void set_node_value ( |
|
T& bn, |
|
unsigned long n, |
|
unsigned long val |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n), |
|
"\tvoid bayes_node_utils::set_node_value(bn, n, val)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tval: " << val |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
<< "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) |
|
); |
|
|
|
bn.node(n).data.set_value(val); |
|
} |
|
|
|
|
|
template <typename T> |
|
unsigned long node_value ( |
|
const T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tunsigned long bayes_node_utils::node_value(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
return bn.node(n).data.value(); |
|
} |
|
|
|
|
|
template <typename T> |
|
bool node_is_evidence ( |
|
const T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tbool bayes_node_utils::node_is_evidence(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
return bn.node(n).data.is_evidence(); |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
void set_node_as_evidence ( |
|
T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tvoid bayes_node_utils::set_node_as_evidence(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
bn.node(n).data.set_as_evidence(); |
|
} |
|
|
|
|
|
template <typename T> |
|
void set_node_as_nonevidence ( |
|
T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
bn.node(n).data.set_as_nonevidence(); |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
void set_node_num_values ( |
|
T& bn, |
|
unsigned long n, |
|
unsigned long num |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tvoid bayes_node_utils::set_node_num_values(bn, n, num)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
bn.node(n).data.table().set_num_values(num); |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
double node_probability ( |
|
const T& bn, |
|
unsigned long n, |
|
unsigned long value, |
|
const assignment& parents |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), |
|
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tvalue: " << value |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
<< "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) |
|
); |
|
|
|
DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), |
|
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tparents.size(): " << parents.size() |
|
<< "\n\tb.node(n).number_of_parents(): " << bn.node(n).number_of_parents() |
|
); |
|
|
|
#ifdef ENABLE_ASSERTS |
|
parents.reset(); |
|
while (parents.move_next()) |
|
{ |
|
const unsigned long x = parents.element().key(); |
|
DLIB_ASSERT( bn.has_edge(x, n), |
|
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tx: " << x |
|
); |
|
DLIB_ASSERT( parents[x] < node_num_values(bn,x), |
|
"\tdouble bayes_node_utils::node_probability(bn, n, value, parents)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tx: " << x |
|
<< "\n\tparents[x]: " << parents[x] |
|
<< "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) |
|
); |
|
} |
|
#endif |
|
|
|
return bn.node(n).data.table().probability(value, parents); |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
void set_node_probability ( |
|
T& bn, |
|
unsigned long n, |
|
unsigned long value, |
|
const assignment& parents, |
|
double p |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n), |
|
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tp: " << p |
|
<< "\n\tvalue: " << value |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
<< "\n\tnode_num_values(bn,n): " << node_num_values(bn,n) |
|
); |
|
|
|
DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(), |
|
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tp: " << p |
|
<< "\n\tparents.size(): " << parents.size() |
|
<< "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() |
|
); |
|
|
|
DLIB_ASSERT( 0.0 <= p && p <= 1.0, |
|
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tp: " << p |
|
); |
|
|
|
#ifdef ENABLE_ASSERTS |
|
parents.reset(); |
|
while (parents.move_next()) |
|
{ |
|
const unsigned long x = parents.element().key(); |
|
DLIB_ASSERT( bn.has_edge(x, n), |
|
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tx: " << x |
|
); |
|
DLIB_ASSERT( parents[x] < node_num_values(bn,x), |
|
"\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tx: " << x |
|
<< "\n\tparents[x]: " << parents[x] |
|
<< "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) |
|
); |
|
} |
|
#endif |
|
|
|
bn.node(n).data.table().set_probability(value,parents,p); |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
const assignment node_first_parent_assignment ( |
|
const T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
); |
|
|
|
assignment a; |
|
const unsigned long num_parents = bn.node(n).number_of_parents(); |
|
for (unsigned long i = 0; i < num_parents; ++i) |
|
{ |
|
a.add(bn.node(n).parent(i).index(), 0); |
|
} |
|
return a; |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
bool node_next_parent_assignment ( |
|
const T& bn, |
|
unsigned long n, |
|
assignment& a |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
); |
|
|
|
DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(), |
|
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\ta.size(): " << a.size() |
|
<< "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents() |
|
); |
|
|
|
#ifdef ENABLE_ASSERTS |
|
a.reset(); |
|
while (a.move_next()) |
|
{ |
|
const unsigned long x = a.element().key(); |
|
DLIB_ASSERT( bn.has_edge(x, n), |
|
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tx: " << x |
|
); |
|
DLIB_ASSERT( a[x] < node_num_values(bn,x), |
|
"\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tx: " << x |
|
<< "\n\ta[x]: " << a[x] |
|
<< "\n\tnode_num_values(bn,x): " << node_num_values(bn,x) |
|
); |
|
} |
|
#endif |
|
|
|
|
|
|
|
for (unsigned long p = 0; p < a.size(); ++p) |
|
{ |
|
const unsigned long pindex = bn.node(n).parent(p).index(); |
|
a[pindex] += 1; |
|
|
|
|
|
if (a[pindex] >= node_num_values(bn,pindex)) |
|
{ |
|
a[pindex] = 0; |
|
} |
|
else |
|
{ |
|
|
|
return true; |
|
} |
|
} |
|
|
|
|
|
|
|
return false; |
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
bool node_cpt_filled_out ( |
|
const T& bn, |
|
unsigned long n |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( n < bn.number_of_nodes(), |
|
"\tbool bayes_node_utils::node_cpt_filled_out(bn, n)" |
|
<< "\n\tInvalid arguments to this function" |
|
<< "\n\tn: " << n |
|
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes() |
|
); |
|
|
|
const unsigned long num_values = node_num_values(bn,n); |
|
|
|
|
|
const conditional_probability_table& table = bn.node(n).data.table(); |
|
|
|
|
|
assignment a(node_first_parent_assignment(bn,n)); |
|
do |
|
{ |
|
double sum = 0; |
|
|
|
for (unsigned long value = 0; value < num_values; ++value) |
|
{ |
|
if (table.has_entry_for(value,a) == false) |
|
return false; |
|
else |
|
sum += table.probability(value,a); |
|
} |
|
|
|
|
|
if (std::abs(sum-1.0) > 1e-5) |
|
return false; |
|
} while (node_next_parent_assignment(bn,n,a)); |
|
|
|
return true; |
|
} |
|
|
|
} |
|
|
|
|
|
|
|
class bayesian_network_gibbs_sampler : noncopyable |
|
{ |
|
public: |
|
|
|
bayesian_network_gibbs_sampler () |
|
{ |
|
rnd.set_seed(cast_to_string(std::time(0))); |
|
} |
|
|
|
|
|
template < |
|
typename T |
|
> |
|
void sample_graph ( |
|
T& bn |
|
) |
|
{ |
|
using namespace bayes_node_utils; |
|
for (unsigned long n = 0; n < bn.number_of_nodes(); ++n) |
|
{ |
|
if (node_is_evidence(bn, n)) |
|
continue; |
|
|
|
samples.set_size(node_num_values(bn,n)); |
|
|
|
for (long i = 0; i < samples.nc(); ++i) |
|
{ |
|
set_node_value(bn, n, i); |
|
samples(i) = node_probability(bn, n); |
|
|
|
for (unsigned long j = 0; j < bn.node(n).number_of_children(); ++j) |
|
samples(i) *= node_probability(bn, bn.node(n).child(j).index()); |
|
} |
|
|
|
|
|
samples /= sum(samples); |
|
|
|
|
|
|
|
double prob = rnd.get_random_double(); |
|
|
|
|
|
long j; |
|
for (j = 0; j < samples.nc()-1; ++j) |
|
{ |
|
if (prob <= samples(j)) |
|
break; |
|
else |
|
prob -= samples(j); |
|
} |
|
|
|
set_node_value(bn, n, j); |
|
} |
|
} |
|
|
|
|
|
private: |
|
|
|
template < |
|
typename T |
|
> |
|
double node_probability ( |
|
const T& bn, |
|
unsigned long n |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{ |
|
v.clear(); |
|
for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i) |
|
{ |
|
v.add(bn.node(n).parent(i).index(), bn.node(n).parent(i).data.value()); |
|
} |
|
return bn.node(n).data.table().probability(bn.node(n).data.value(), v); |
|
} |
|
|
|
assignment v; |
|
|
|
dlib::rand rnd; |
|
matrix<double,1> samples; |
|
}; |
|
|
|
|
|
|
|
namespace bayesian_network_join_tree_helpers |
|
{ |
|
class bnjt |
|
{ |
|
|
|
|
|
|
|
public: |
|
virtual ~bnjt() {} |
|
|
|
virtual const matrix<double,1> probability( |
|
unsigned long idx |
|
) const = 0; |
|
}; |
|
|
|
template <typename T, typename U> |
|
class bnjt_impl : public bnjt |
|
{ |
|
|
|
|
|
|
|
|
|
public: |
|
|
|
bnjt_impl ( |
|
const T& bn, |
|
const U& join_tree |
|
) |
|
{ |
|
create_bayesian_network_join_tree(bn, join_tree, join_tree_values); |
|
|
|
cliques.resize(bn.number_of_nodes()); |
|
|
|
|
|
for (unsigned long i = 0; i < cliques.size(); ++i) |
|
{ |
|
|
|
unsigned long smallest_clique = 0; |
|
unsigned long size = std::numeric_limits<unsigned long>::max(); |
|
|
|
for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n) |
|
{ |
|
if (join_tree.node(n).data.is_member(i) && join_tree.node(n).data.size() < size) |
|
{ |
|
size = join_tree.node(n).data.size(); |
|
smallest_clique = n; |
|
} |
|
} |
|
|
|
cliques[i] = smallest_clique; |
|
} |
|
} |
|
|
|
virtual const matrix<double,1> probability( |
|
unsigned long idx |
|
) const |
|
{ |
|
join_tree_values.node(cliques[idx]).data.marginalize(idx, table); |
|
table.normalize(); |
|
var.clear(); |
|
var.add(idx); |
|
dist.set_size(table.size()); |
|
|
|
|
|
for (unsigned long i = 0; i < table.size(); ++i) |
|
{ |
|
var[idx] = i; |
|
dist(i) = table.probability(var); |
|
} |
|
|
|
return dist; |
|
} |
|
|
|
private: |
|
|
|
graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values; |
|
array<unsigned long> cliques; |
|
mutable joint_probability_table table; |
|
mutable assignment var; |
|
mutable matrix<double,1> dist; |
|
|
|
|
|
|
|
|
|
template <typename set_type, typename node_type> |
|
bool set_contains_all_parents_of_node ( |
|
const set_type& set, |
|
const node_type& node |
|
) |
|
{ |
|
for (unsigned long i = 0; i < node.number_of_parents(); ++i) |
|
{ |
|
if (set.is_member(node.parent(i).index()) == false) |
|
return false; |
|
} |
|
return true; |
|
} |
|
|
|
|
|
|
|
template < |
|
typename V |
|
> |
|
void pass_join_tree_message ( |
|
const U& join_tree, |
|
V& bn_join_tree , |
|
unsigned long from, |
|
unsigned long to |
|
) |
|
{ |
|
using namespace bayes_node_utils; |
|
const typename U::edge_type& e = edge(join_tree, from, to); |
|
typename V::edge_type& old_s = edge(bn_join_tree, from, to); |
|
|
|
typedef typename V::edge_type joint_prob_table; |
|
|
|
joint_prob_table new_s; |
|
bn_join_tree.node(from).data.marginalize(e, new_s); |
|
|
|
joint_probability_table temp(new_s); |
|
|
|
|
|
|
|
if (old_s.size() > 0) |
|
{ |
|
temp.reset(); |
|
old_s.reset(); |
|
while (temp.move_next()) |
|
{ |
|
old_s.move_next(); |
|
if (old_s.element().value() != 0) |
|
temp.element().value() /= old_s.element().value(); |
|
} |
|
} |
|
|
|
|
|
joint_probability_table& d = bn_join_tree.node(to).data; |
|
d.reset(); |
|
while (d.move_next()) |
|
{ |
|
assignment a; |
|
const assignment& asrc = d.element().key(); |
|
asrc.reset(); |
|
while (asrc.move_next()) |
|
{ |
|
if (e.is_member(asrc.element().key())) |
|
a.add(asrc.element().key(), asrc.element().value()); |
|
} |
|
|
|
d.element().value() *= temp.probability(a); |
|
|
|
} |
|
|
|
|
|
new_s.swap(old_s); |
|
|
|
} |
|
|
|
|
|
|
|
template < |
|
typename V |
|
> |
|
void create_bayesian_network_join_tree ( |
|
const T& bn, |
|
const U& join_tree, |
|
V& bn_join_tree |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
{ |
|
using namespace bayes_node_utils; |
|
bn_join_tree.clear(); |
|
copy_graph_structure(join_tree, bn_join_tree); |
|
|
|
|
|
|
|
|
|
|
|
std::vector<unsigned long> node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes()); |
|
|
|
|
|
dlib::map<unsigned long, unsigned long>::kernel_1b_c evidence; |
|
for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) |
|
{ |
|
if (node_is_evidence(bn, i)) |
|
{ |
|
unsigned long idx = i; |
|
unsigned long value = node_value(bn, i); |
|
evidence.add(idx,value); |
|
} |
|
} |
|
|
|
|
|
|
|
for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i) |
|
{ |
|
bool contains_evidence = false; |
|
std::vector<unsigned long> indices; |
|
assignment value; |
|
|
|
|
|
|
|
join_tree.node(i).data.reset(); |
|
while (join_tree.node(i).data.move_next()) |
|
{ |
|
const unsigned long idx = join_tree.node(i).data.element(); |
|
indices.push_back(idx); |
|
value.add(idx); |
|
|
|
if (evidence.is_in_domain(join_tree.node(i).data.element())) |
|
contains_evidence = true; |
|
} |
|
|
|
|
|
|
|
|
|
bool more_assignments = true; |
|
while (more_assignments) |
|
{ |
|
bn_join_tree.node(i).data.set_probability(value,1); |
|
|
|
|
|
if (contains_evidence) |
|
{ |
|
|
|
for (unsigned long j = 0; j < indices.size(); ++j) |
|
{ |
|
|
|
if (evidence.is_in_domain(indices[j])) |
|
{ |
|
const unsigned long idx = indices[j]; |
|
const unsigned long evidence_value = evidence[idx]; |
|
if (value[idx] != evidence_value) |
|
bn_join_tree.node(i).data.set_probability(value , 0); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
join_tree.node(i).data.reset(); |
|
while (join_tree.node(i).data.move_next()) |
|
{ |
|
const unsigned long idx = join_tree.node(i).data.element(); |
|
|
|
|
|
if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) && |
|
(i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) ) |
|
{ |
|
|
|
node_assigned_to[idx] = i; |
|
|
|
assignment parent_values; |
|
for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j) |
|
{ |
|
const unsigned long pidx = bn.node(idx).parent(j).index(); |
|
parent_values.add(pidx, value[pidx]); |
|
} |
|
|
|
double temp = bn_join_tree.node(i).data.probability(value); |
|
bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values)); |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
more_assignments = false; |
|
value.reset(); |
|
while (value.move_next()) |
|
{ |
|
value.element().value() += 1; |
|
|
|
if (value.element().value() == node_num_values(bn, value.element().key())) |
|
{ |
|
value.element().value() = 0; |
|
} |
|
else |
|
{ |
|
more_assignments = true; |
|
break; |
|
} |
|
} |
|
|
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_send; |
|
dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_receive; |
|
remaining_msg_to_receive.resize(join_tree.number_of_nodes()); |
|
remaining_msg_to_send.resize(join_tree.number_of_nodes()); |
|
for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i) |
|
{ |
|
for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j) |
|
{ |
|
const unsigned long idx = join_tree.node(i).neighbor(j).index(); |
|
unsigned long temp; |
|
temp = idx; remaining_msg_to_receive[i].add(temp); |
|
temp = idx; remaining_msg_to_send[i].add(temp); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
bool message_sent = true; |
|
std::vector<unsigned long>::iterator iter; |
|
while (message_sent) |
|
{ |
|
message_sent = false; |
|
for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i) |
|
{ |
|
|
|
|
|
if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1) |
|
{ |
|
unsigned long to; |
|
|
|
remaining_msg_to_receive[i].remove_any(to); |
|
|
|
|
|
pass_join_tree_message(join_tree, bn_join_tree, i, to); |
|
|
|
|
|
remaining_msg_to_send[i].destroy(to); |
|
remaining_msg_to_receive[to].destroy(i); |
|
|
|
|
|
remaining_msg_to_receive[i].add(to); |
|
message_sent = true; |
|
} |
|
else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0) |
|
{ |
|
unsigned long to; |
|
remaining_msg_to_send[i].remove_any(to); |
|
remaining_msg_to_receive[to].destroy(i); |
|
pass_join_tree_message(join_tree, bn_join_tree, i, to); |
|
message_sent = true; |
|
} |
|
} |
|
|
|
if (remaining_msg_to_receive[0].size() == 0) |
|
{ |
|
|
|
while (remaining_msg_to_send[0].size() > 0) |
|
{ |
|
unsigned long to; |
|
remaining_msg_to_send[0].remove_any(to); |
|
remaining_msg_to_receive[to].destroy(0); |
|
pass_join_tree_message(join_tree, bn_join_tree, 0, to); |
|
message_sent = true; |
|
} |
|
} |
|
|
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
} |
|
|
|
class bayesian_network_join_tree : noncopyable |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
public: |
|
|
|
template < |
|
typename T, |
|
typename U |
|
> |
|
bayesian_network_join_tree ( |
|
const T& bn, |
|
const U& join_tree |
|
) |
|
{ |
|
|
|
DLIB_ASSERT( bn.number_of_nodes() > 0 , |
|
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" |
|
<< "\n\tYou have given an invalid bayesian network" |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
DLIB_ASSERT( is_join_tree(bn, join_tree) == true , |
|
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" |
|
<< "\n\tYou have given an invalid join tree for the supplied bayesian network" |
|
<< "\n\tthis: " << this |
|
); |
|
DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false, |
|
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" |
|
<< "\n\tYou have given an invalid bayesian network" |
|
<< "\n\tthis: " << this |
|
); |
|
DLIB_ASSERT( graph_is_connected(bn) == true, |
|
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" |
|
<< "\n\tYou have given an invalid bayesian network" |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
#ifdef ENABLE_ASSERTS |
|
for (unsigned long i = 0; i < bn.number_of_nodes(); ++i) |
|
{ |
|
DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true, |
|
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)" |
|
<< "\n\tYou have given an invalid bayesian network. " |
|
<< "\n\tYou must finish filling out the conditional_probability_table of node " << i |
|
<< "\n\tthis: " << this |
|
); |
|
} |
|
#endif |
|
|
|
impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl<T,U>(bn, join_tree)); |
|
num_nodes = bn.number_of_nodes(); |
|
} |
|
|
|
const matrix<double,1> probability( |
|
unsigned long idx |
|
) const |
|
{ |
|
|
|
DLIB_ASSERT( idx < number_of_nodes() , |
|
"\tconst matrix<double,1> bayesian_network_join_tree::probability(idx)" |
|
<< "\n\tYou have specified an invalid node index" |
|
<< "\n\tidx: " << idx |
|
<< "\n\tnumber_of_nodes(): " << number_of_nodes() |
|
<< "\n\tthis: " << this |
|
); |
|
|
|
return impl->probability(idx); |
|
} |
|
|
|
unsigned long number_of_nodes ( |
|
) const { return num_nodes; } |
|
|
|
void swap ( |
|
bayesian_network_join_tree& item |
|
) |
|
{ |
|
exchange(num_nodes, item.num_nodes); |
|
impl.swap(item.impl); |
|
} |
|
|
|
private: |
|
|
|
std::unique_ptr<bayesian_network_join_tree_helpers::bnjt> impl; |
|
unsigned long num_nodes; |
|
|
|
}; |
|
|
|
inline void swap ( |
|
bayesian_network_join_tree& a, |
|
bayesian_network_join_tree& b |
|
) { a.swap(b); } |
|
|
|
} |
|
|
|
|
|
|
|
#endif |
|
|
|
|