// NOTE: the next three lines are repository web-page residue, not part of the source:
//   AshanGimhana's picture
//   Upload folder using huggingface_hub
//   9375c9a verified
// Copyright (C) 2007  Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BAYES_UTILs_
#define DLIB_BAYES_UTILs_
#include "bayes_utils_abstract.h"
#include <algorithm>
#include <cmath>
#include <ctime>
#include <limits>
#include <memory>
#include <vector>
#include "../string.h"
#include "../map.h"
#include "../matrix.h"
#include "../rand.h"
#include "../array.h"
#include "../set.h"
#include "../algs.h"
#include "../noncopyable.h"
#include "../graph.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
class assignment
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A collection of (index, value) pairs, both unsigned long, in which
            each index appears at most once.  The pairs are stored in the vals
            map and can be walked with reset()/move_next()/element().

        CONVENTION
            - size() == vals.size()
            - has_index(idx) == vals.is_in_domain(idx)
            - vals is mutable so the enumeration can be driven through const
              references.  Enumerating an assignment therefore changes its
              enumeration state even when the object is const.
    !*/
public:

    assignment()
    {
    }

    // Copy constructor.  Walks a's enumeration and re-adds every pair, so a's
    // enumeration state is modified (it ends up past the last element).
    assignment(
        const assignment& a
    )
    {
        a.reset();
        while (a.move_next())
        {
            // local copies are handed to vals.add()
            unsigned long idx = a.element().key();
            unsigned long value = a.element().value();
            vals.add(idx,value);
        }
    }

    // Copy assignment implemented as copy-and-swap.
    assignment& operator = (
        const assignment& rhs
    )
    {
        if (this == &rhs)
            return *this;

        assignment(rhs).swap(*this);
        return *this;
    }

    // Removes every pair from this object.
    void clear()
    {
        vals.clear();
    }

    // Ordering used when assignments are keys of a map.  Smaller assignments
    // (fewer pairs) come first; assignments of equal size are compared pair by
    // pair in enumeration order, first on the index and then on the value.
    // Both operands have their enumeration state modified.
    bool operator < (
        const assignment& item
    ) const
    {
        // An object never orders before itself.  This guard is needed for
        // correctness: both operands are walked through their own enumerator,
        // so if item were *this the loop below would advance the single shared
        // enumerator twice per iteration, skipping pairs and calling element()
        // after the enumeration had already ended.
        if (this == &item)
            return false;

        if (size() < item.size())
            return true;
        else if (size() > item.size())
            return false;

        // sizes are equal here, so the two enumerations stay in lockstep
        reset();
        item.reset();
        while (move_next())
        {
            item.move_next();
            if (element().key() < item.element().key())
                return true;
            else if (element().key() > item.element().key())
                return false;
            else if (element().value() < item.element().value())
                return true;
            else if (element().value() > item.element().value())
                return false;
        }

        // every pair matched
        return false;
    }

    // Returns true if idx has an entry in this assignment.
    bool has_index (
        unsigned long idx
    ) const
    {
        return vals.is_in_domain(idx);
    }

    // Adds the pair (idx, value).  idx must not already be present.
    void add (
        unsigned long idx,
        unsigned long value = 0
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( has_index(idx) == false ,
            "\tvoid assignment::add(idx)"
            << "\n\tYou can't add the same index to an assignment object more than once"
            << "\n\tidx: " << idx
            << "\n\tthis: " << this
            );

        vals.add(idx, value);
    }

    // Returns a writable reference to the value stored for idx.  idx must
    // already be present.
    unsigned long& operator[] (
        const long idx
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( has_index(idx) == true ,
            "\tunsigned long assignment::operator[](idx)"
            << "\n\tYou can't access an index value if it isn't already in the object"
            << "\n\tidx: " << idx
            << "\n\tthis: " << this
            );

        return vals[idx];
    }

    // Read-only counterpart of the operator[] above.
    const unsigned long& operator[] (
        const long idx
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( has_index(idx) == true ,
            "\tunsigned long assignment::operator[](idx)"
            << "\n\tYou can't access an index value if it isn't already in the object"
            << "\n\tidx: " << idx
            << "\n\tthis: " << this
            );

        return vals[idx];
    }

    // Exchanges the contents of *this and item.
    void swap (
        assignment& item
    )
    {
        vals.swap(item.vals);
    }

    // Removes the pair stored for idx.  idx must already be present.
    void remove (
        unsigned long idx
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( has_index(idx) == true ,
            "\tvoid assignment::remove(idx)"
            << "\n\tYou can't remove an index value if it isn't already in the object"
            << "\n\tidx: " << idx
            << "\n\tthis: " << this
            );

        vals.destroy(idx);
    }

    // Number of pairs in this assignment.
    unsigned long size() const { return vals.size(); }

    // Enumeration interface, forwarded to vals.
    void reset() const { vals.reset(); }

    bool move_next() const { return vals.move_next(); }

    // Current pair of the enumeration; only valid while
    // current_element_valid() is true.
    map_pair<unsigned long, unsigned long>& element()
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_element_valid() == true,
            "\tmap_pair<unsigned long,unsigned long>& assignment::element()"
            << "\n\tyou can't access the current element if it doesn't exist"
            << "\n\tthis: " << this
            );
        return vals.element();
    }

    const map_pair<unsigned long, unsigned long>& element() const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_element_valid() == true,
            "\tconst map_pair<unsigned long,unsigned long>& assignment::element() const"
            << "\n\tyou can't access the current element if it doesn't exist"
            << "\n\tthis: " << this
            );

        return vals.element();
    }

    bool at_start() const { return vals.at_start(); }

    bool current_element_valid() const { return vals.current_element_valid(); }

    // Serialization simply forwards to the underlying map.
    friend inline void serialize (
        const assignment& item,
        std::ostream& out
    )
    {
        serialize(item.vals, out);
    }

    friend inline void deserialize (
        assignment& item,
        std::istream& in
    )
    {
        deserialize(item.vals, in);
    }

private:
    // mutable: see CONVENTION above
    mutable dlib::map<unsigned long, unsigned long>::kernel_1b_c vals;
};
// Streams a in the form "(idx:value, idx:value, ...)"; an assignment with no
// pairs is written as "()".  The enumeration of a is restarted and run to its
// end, so a's enumeration state is modified.
inline std::ostream& operator << (std::ostream& out, const assignment& a)
{
    a.reset();
    out << "(";

    bool need_separator = false;
    while (a.move_next())
    {
        if (need_separator)
            out << ", ";
        need_separator = true;

        out << a.element().key() << ":" << a.element().value();
    }

    out << ")";
    return out;
}
// Non-member swap for assignment objects; forwards to assignment::swap().
inline void swap (assignment& a, assignment& b)
{
    a.swap(b);
}
// ------------------------------------------------------------------------
class joint_probability_table
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A table that maps assignment objects to probabilities (doubles).
            Entries can be walked with reset()/move_next()/element().

        INITIAL VALUE
            - table.size() == 0

        CONVENTION
            - size() == table.size()
            - probability(a) == table[a]
    !*/
public:

    // Copy constructor.  Re-inserts every entry of t through set_probability().
    // Walking t modifies its enumeration state.
    joint_probability_table (
        const joint_probability_table& t
    )
    {
        t.reset();
        while (t.move_next())
        {
            assignment a = t.element().key();
            double p = t.element().value();
            set_probability(a,p);
        }
    }

    joint_probability_table() {}

    // Copy assignment implemented as copy-and-swap.
    joint_probability_table& operator= (
        const joint_probability_table& rhs
    )
    {
        if (this == &rhs)
            return *this;
        joint_probability_table(rhs).swap(*this);
        return *this;
    }

    // Sets the probability stored for a to p, creating the entry if needed.
    // Requires 0 <= p <= 1.
    void set_probability (
        const assignment& a,
        double p
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0.0 <= p && p <= 1.0,
            "\tvoid joint_probability_table::set_probability(a,p)"
            << "\n\tyou have given an invalid probability value"
            << "\n\tp: " << p
            << "\n\ta: " << a
            << "\n\tthis: " << this
            );

        if (table.is_in_domain(a))
        {
            table[a] = p;
        }
        else
        {
            // table.add() is given a local copy rather than the caller's object
            assignment temp(a);
            table.add(temp,p);
        }
    }

    // Returns true if the table holds an entry for a.
    bool has_entry_for (
        const assignment& a
    ) const
    {
        return table.is_in_domain(a);
    }

    // Adds p to the probability stored for a (creating the entry if needed).
    // The stored value is capped at 1.0.  Requires 0 <= p <= 1.
    void add_probability (
        const assignment& a,
        double p
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0.0 <= p && p <= 1.0,
            "\tvoid joint_probability_table::add_probability(a,p)"
            << "\n\tyou have given an invalid probability value"
            << "\n\tp: " << p
            << "\n\ta: " << a
            << "\n\tthis: " << this
            );

        if (table.is_in_domain(a))
        {
            table[a] += p;
            if (table[a] > 1.0)
                table[a] = 1.0;
        }
        else
        {
            assignment temp(a);
            table.add(temp,p);
        }
    }

    // Returns the probability stored for a.
    double probability (
        const assignment& a
    ) const
    {
        return table[a];
    }

    // Removes every entry.
    void clear()
    {
        table.clear();
    }

    size_t size () const { return table.size(); }

    // Enumeration interface, forwarded to table.
    bool move_next() const { return table.move_next(); }
    void reset() const { table.reset(); }

    // Current entry of the enumeration; only valid while
    // current_element_valid() is true.
    map_pair<assignment,double>& element()
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_element_valid() == true,
            "\tmap_pair<assignment,double>& joint_probability_table::element()"
            << "\n\tyou can't access the current element if it doesn't exist"
            << "\n\tthis: " << this
            );

        return table.element();
    }

    const map_pair<assignment,double>& element() const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(current_element_valid() == true,
            "\tconst map_pair<assignment,double>& joint_probability_table::element() const"
            << "\n\tyou can't access the current element if it doesn't exist"
            << "\n\tthis: " << this
            );

        return table.element();
    }

    bool at_start() const { return table.at_start(); }

    bool current_element_valid() const { return table.current_element_valid(); }

    // Projects this table onto the indices for which vars.is_member() is true:
    // every entry is reduced to the pairs whose index is in vars and its
    // probability is accumulated into out with add_probability().
    // out is cleared before *this is read, so it must be a different object.
    template <typename T>
    void marginalize (
        const T& vars,
        joint_probability_table& out
    ) const
    {
        out.clear();
        double p;
        reset();
        while (move_next())
        {
            assignment a;
            const assignment& asrc = element().key();
            p = element().value();

            // keep only the pairs of asrc that belong to vars
            asrc.reset();
            while (asrc.move_next())
            {
                if (vars.is_member(asrc.element().key()))
                    a.add(asrc.element().key(), asrc.element().value());
            }

            out.add_probability(a,p);
        }
    }

    // Same as above but projects onto the single index var.
    void marginalize (
        const unsigned long var,
        joint_probability_table& out
    ) const
    {
        out.clear();
        double p;
        reset();
        while (move_next())
        {
            assignment a;
            const assignment& asrc = element().key();
            p = element().value();

            asrc.reset();
            while (asrc.move_next())
            {
                if (var == asrc.element().key())
                    a.add(asrc.element().key(), asrc.element().value());
            }

            out.add_probability(a,p);
        }
    }

    // Divides every probability by the sum of all probabilities in the table.
    // A table whose probabilities sum to zero (including an empty table) is
    // left unchanged.
    void normalize (
    )
    {
        double sum = 0;

        reset();
        while (move_next())
            sum += element().value();

        // Dividing by a zero sum would store NaN in every entry, so there is
        // nothing meaningful to rescale in that case.
        if (sum == 0)
            return;

        reset();
        while (move_next())
            element().value() /= sum;
    }

    // Exchanges the contents of *this and item.
    void swap (
        joint_probability_table& item
    )
    {
        table.swap(item.table);
    }

    // Serialization simply forwards to the underlying map.
    friend inline void serialize (
        const joint_probability_table& item,
        std::ostream& out
    )
    {
        serialize(item.table, out);
    }

    friend inline void deserialize (
        joint_probability_table& item,
        std::istream& in
    )
    {
        deserialize(item.table, in);
    }

private:

    dlib::map<assignment, double >::kernel_1b_c table;
};
// Non-member swap for joint_probability_table objects; forwards to the member swap().
inline void swap (joint_probability_table& a, joint_probability_table& b)
{
    a.swap(b);
}
// ----------------------------------------------------------------------------------------
class conditional_probability_table : noncopyable
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            For each parent assignment ps this table stores a row of num_vals
            probabilities, one per value.  A row is created on the first
            set_probability() call for ps with every slot set to -1; a negative
            slot therefore means "no entry yet".

        INITIAL VALUE
            - table.size() == 0

        CONVENTION
            - if (table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0) then
                - has_entry_for(value,ps) == true
                - probability(value,ps) == table[ps](value)
            - else
                - has_entry_for(value,ps) == false

            - num_values() == num_vals
    !*/
public:

    conditional_probability_table()
    {
        clear();
    }

    // Sets the number of values per row.  All existing rows are discarded.
    void set_num_values (
        unsigned long num
    )
    {
        num_vals = num;
        table.clear();
    }

    // Returns true if a probability has been stored for (value, ps).
    bool has_entry_for (
        unsigned long value,
        const assignment& ps
    ) const
    {
        return table.is_in_domain(ps) && value < num_vals && table[ps](value) >= 0;
    }

    unsigned long num_values (
    ) const { return num_vals; }

    // Stores p as the probability of value given the parent assignment ps.
    // Requires value < num_values() and 0 <= p <= 1.
    void set_probability (
        unsigned long value,
        const assignment& ps,
        double p
    )
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( value < num_values() && 0.0 <= p && p <= 1.0 ,
            "\tvoid conditional_probability_table::set_probability()"
            << "\n\tinvalid arguments to set_probability"
            << "\n\tvalue: " << value
            << "\n\tnum_values(): " << num_values()
            << "\n\tp: " << p
            << "\n\tps: " << ps
            << "\n\tthis: " << this
            );

        if (table.is_in_domain(ps))
        {
            table[ps](value) = p;
        }
        else
        {
            // first entry for ps: build a row with every slot marked "unset" (-1)
            matrix<double,1> dist(num_vals);
            set_all_elements(dist,-1);
            dist(value) = p;
            assignment temp(ps);
            table.add(temp,dist);
        }
    }

    // Returns the probability stored for (value, ps).
    // Requires has_entry_for(value,ps) == true.
    double probability(
        unsigned long value,
        const assignment& ps
    ) const
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( value < num_values() && has_entry_for(value,ps) ,
            "\tdouble conditional_probability_table::probability()"
            << "\n\tinvalid arguments to probability"
            << "\n\tvalue: " << value
            << "\n\tnum_values(): " << num_values()
            << "\n\tps: " << ps
            << "\n\tthis: " << this
            );

        return table[ps](value);
    }

    // Returns the object to its initial state (no rows, num_values() == 0).
    void clear()
    {
        table.clear();
        num_vals = 0;
    }

    // Discards all rows but keeps num_values().
    void empty_table ()
    {
        table.clear();
    }

    // Exchanges the contents of *this and item.
    void swap (
        conditional_probability_table& item
    )
    {
        exchange(num_vals, item.num_vals);
        table.swap(item.table);
    }

    // Serialized layout: table first, then num_vals.
    friend inline void serialize (
        const conditional_probability_table& item,
        std::ostream& out
    )
    {
        serialize(item.table, out);
        serialize(item.num_vals, out);
    }

    friend inline void deserialize (
        conditional_probability_table& item,
        std::istream& in
    )
    {
        deserialize(item.table, in);
        deserialize(item.num_vals, in);
    }

private:
    dlib::map<assignment, matrix<double,1> >::kernel_1b_c table;
    unsigned long num_vals;
};
// Non-member swap for conditional_probability_table objects; forwards to the member swap().
inline void swap (conditional_probability_table& a, conditional_probability_table& b)
{
    a.swap(b);
}
// ------------------------------------------------------------------------
class bayes_node : noncopyable
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            The data attached to one node: its current value, a flag saying
            whether it has been marked as evidence, and its
            conditional_probability_table.

        INITIAL VALUE
            - value() == 0
            - is_evidence() == false
    !*/
public:

    bayes_node () : value_(0), is_instantiated(false)
    {
    }

    // Current value of the node.
    unsigned long value () const
    {
        return value_;
    }

    // Changes the node's value.  Requires new_value < table().num_values().
    void set_value (unsigned long new_value)
    {
        // make sure requires clause is not broken
        DLIB_ASSERT( new_value < table().num_values(),
            "\tvoid bayes_node::set_value(new_value)"
            << "\n\tnew_value must be less than the number of possible values for this node"
            << "\n\tnew_value: " << new_value
            << "\n\ttable().num_values(): " << table().num_values()
            << "\n\tthis: " << this
            );

        value_ = new_value;
    }

    // Access to this node's conditional probability table.
    conditional_probability_table& table ()
    {
        return table_;
    }

    const conditional_probability_table& table () const
    {
        return table_;
    }

    // Evidence flag accessors.
    bool is_evidence () const
    {
        return is_instantiated;
    }

    void set_as_nonevidence ()
    {
        is_instantiated = false;
    }

    void set_as_evidence ()
    {
        is_instantiated = true;
    }

    // Exchanges the complete state of *this and item.
    void swap (bayes_node& item)
    {
        exchange(value_, item.value_);
        exchange(is_instantiated, item.is_instantiated);
        table_.swap(item.table_);
    }

    // Serialized layout: value_, is_instantiated, table_ (in that order).
    friend inline void serialize (const bayes_node& item, std::ostream& out)
    {
        serialize(item.value_, out);
        serialize(item.is_instantiated, out);
        serialize(item.table_, out);
    }

    friend inline void deserialize (bayes_node& item, std::istream& in)
    {
        deserialize(item.value_, in);
        deserialize(item.is_instantiated, in);
        deserialize(item.table_, in);
    }

private:

    unsigned long value_;
    bool is_instantiated;
    conditional_probability_table table_;
};
// Non-member swap for bayes_node objects; forwards to bayes_node::swap().
inline void swap (bayes_node& a, bayes_node& b)
{
    a.swap(b);
}
// ------------------------------------------------------------------------
namespace bayes_node_utils
{
// Returns the number of values node n of bn can take, i.e. the num_values()
// of that node's conditional probability table.
// Requires n < bn.number_of_nodes().
template <typename T>
unsigned long node_num_values (
    const T& bn,
    unsigned long n
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tunsigned long bayes_node_utils::node_num_values(bn, n)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        );

    return bn.node(n).data.table().num_values();
}
// ----------------------------------------------------------------------------------------
// Gives node n of bn the value val.
// Requires n < bn.number_of_nodes() and val < node_num_values(bn,n).
template <typename T>
void set_node_value (T& bn, unsigned long n, unsigned long val)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes() && val < node_num_values(bn,n),
        "\tvoid bayes_node_utils::set_node_value(bn, n, val)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tval: " << val
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
        );

    bn.node(n).data.set_value(val);
}
// ----------------------------------------------------------------------------------------
// Returns the current value of node n of bn.
// Requires n < bn.number_of_nodes().
template <typename T>
unsigned long node_value (const T& bn, unsigned long n)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tunsigned long bayes_node_utils::node_value(bn, n)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        );

    return bn.node(n).data.value();
}
// ----------------------------------------------------------------------------------------
// Returns true if node n of bn is currently marked as evidence.
// Requires n < bn.number_of_nodes().
template <typename T>
bool node_is_evidence (const T& bn, unsigned long n)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tbool bayes_node_utils::node_is_evidence(bn, n)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        );

    return bn.node(n).data.is_evidence();
}
// ----------------------------------------------------------------------------------------
// Marks node n of bn as an evidence node.
// Requires n < bn.number_of_nodes().
template <typename T>
void set_node_as_evidence (T& bn, unsigned long n)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tvoid bayes_node_utils::set_node_as_evidence(bn, n)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        );

    bn.node(n).data.set_as_evidence();
}
// ----------------------------------------------------------------------------------------
// Clears the evidence mark of node n of bn.
// Requires n < bn.number_of_nodes().
template <typename T>
void set_node_as_nonevidence (T& bn, unsigned long n)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tvoid bayes_node_utils::set_node_as_nonevidence(bn, n)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        );

    bn.node(n).data.set_as_nonevidence();
}
// ----------------------------------------------------------------------------------------
// Sets how many values node n of bn can take by calling set_num_values(num)
// on the node's conditional probability table.
// Requires n < bn.number_of_nodes().
template <typename T>
void set_node_num_values (T& bn, unsigned long n, unsigned long num)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tvoid bayes_node_utils::set_node_num_values(bn, n, num)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        );

    bn.node(n).data.table().set_num_values(num);
}
// ----------------------------------------------------------------------------------------
// Returns the probability, stored in node n's conditional probability table,
// of node n taking the given value when its parents take the values in parents.
//
// Requires:
//   - n < bn.number_of_nodes() and value < node_num_values(bn,n)
//   - parents has exactly one entry per parent of node n
//   - every index x in parents satisfies bn.has_edge(x,n) and
//     parents[x] < node_num_values(bn,x)
template <typename T>
double node_probability (
    const T& bn,
    unsigned long n,
    unsigned long value,
    const assignment& parents
)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n),
        "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tvalue: " << value
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
        );

    DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(),
        "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tparents.size(): " << parents.size()
        << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
        );

#ifdef ENABLE_ASSERTS
    // The per-parent checks need a loop, so they are only compiled when
    // ENABLE_ASSERTS is defined.
    parents.reset();
    while (parents.move_next())
    {
        const unsigned long x = parents.element().key();
        DLIB_ASSERT( bn.has_edge(x, n),
            "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
            << "\n\tInvalid arguments to this function"
            << "\n\tn: " << n
            << "\n\tx: " << x
            );
        DLIB_ASSERT( parents[x] < node_num_values(bn,x),
            "\tdouble bayes_node_utils::node_probability(bn, n, value, parents)"
            << "\n\tInvalid arguments to this function"
            << "\n\tn: " << n
            << "\n\tx: " << x
            << "\n\tparents[x]: " << parents[x]
            << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
            );
    }
#endif

    return bn.node(n).data.table().probability(value, parents);
}
// ----------------------------------------------------------------------------------------
// Stores p in node n's conditional probability table as the probability of
// node n taking the given value when its parents take the values in parents.
//
// Requires:
//   - n < bn.number_of_nodes() and value < node_num_values(bn,n)
//   - parents has exactly one entry per parent of node n
//   - 0 <= p <= 1
//   - every index in parents is the source of an edge into n and its value is
//     one that parent can take
template <typename T>
void set_node_probability (
    T& bn,
    unsigned long n,
    unsigned long value,
    const assignment& parents,
    double p
)
{
    // argument checks
    DLIB_ASSERT( n < bn.number_of_nodes() && value < node_num_values(bn,n),
        "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tp: " << p
        << "\n\tvalue: " << value
        << "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
        << "\n\tnode_num_values(bn,n): " << node_num_values(bn,n)
        );

    DLIB_ASSERT( parents.size() == bn.node(n).number_of_parents(),
        "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tp: " << p
        << "\n\tparents.size(): " << parents.size()
        << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
        );

    DLIB_ASSERT( 0.0 <= p && p <= 1.0,
        "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\tp: " << p
        );

#ifdef ENABLE_ASSERTS
    // validate each (parent index, parent value) pair individually
    parents.reset();
    while (parents.move_next())
    {
        const unsigned long x = parents.element().key();

        DLIB_ASSERT( bn.has_edge(x, n),
            "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
            << "\n\tInvalid arguments to this function"
            << "\n\tn: " << n
            << "\n\tx: " << x
            );

        DLIB_ASSERT( parents[x] < node_num_values(bn,x),
            "\tvoid bayes_node_utils::set_node_probability(bn, n, value, parents, p)"
            << "\n\tInvalid arguments to this function"
            << "\n\tn: " << n
            << "\n\tx: " << x
            << "\n\tparents[x]: " << parents[x]
            << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
            );
    }
#endif

    bn.node(n).data.table().set_probability(value,parents,p);
}
// ----------------------------------------------------------------------------------------
// Builds the starting point for enumerating parent assignments of node n:
// an assignment holding the index of every parent of n, each with value 0.
// Requires n < bn.number_of_nodes().
template <typename T>
const assignment node_first_parent_assignment (const T& bn, unsigned long n)
{
    // make sure requires clause is not broken
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tconst assignment bayes_node_utils::node_first_parent_assignment(bn, n)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        );

    assignment first;

    const unsigned long parent_count = bn.node(n).number_of_parents();
    unsigned long which = 0;
    while (which < parent_count)
    {
        first.add(bn.node(n).parent(which).index(), 0);
        ++which;
    }

    return first;
}
// ----------------------------------------------------------------------------------------
// Advances a to the next assignment of values to the parents of node n.
// The parents are treated like the digits of a counter: the first parent is
// incremented and wraps to 0 (carrying into the next parent) once it reaches
// its number of values.
//
// Returns true if a now holds a new assignment, and false if the counter
// wrapped all the way around (in which case every value in a is 0 again).
//
// Requires n < bn.number_of_nodes() and that a holds one valid value for each
// parent of node n.
template <typename T>
bool node_next_parent_assignment (const T& bn, unsigned long n, assignment& a)
{
    // argument checks
    DLIB_ASSERT( n < bn.number_of_nodes(),
        "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        );

    DLIB_ASSERT( a.size() == bn.node(n).number_of_parents(),
        "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
        << "\n\tInvalid arguments to this function"
        << "\n\tn: " << n
        << "\n\ta.size(): " << a.size()
        << "\n\tbn.node(n).number_of_parents(): " << bn.node(n).number_of_parents()
        );

#ifdef ENABLE_ASSERTS
    // validate each (parent index, parent value) pair individually
    a.reset();
    while (a.move_next())
    {
        const unsigned long x = a.element().key();

        DLIB_ASSERT( bn.has_edge(x, n),
            "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
            << "\n\tInvalid arguments to this function"
            << "\n\tn: " << n
            << "\n\tx: " << x
            );

        DLIB_ASSERT( a[x] < node_num_values(bn,x),
            "\tbool bayes_node_utils::node_next_parent_assignment(bn, n, a)"
            << "\n\tInvalid arguments to this function"
            << "\n\tn: " << n
            << "\n\tx: " << x
            << "\n\ta[x]: " << a[x]
            << "\n\tnode_num_values(bn,x): " << node_num_values(bn,x)
            );
    }
#endif

    // add one to the "counter", carrying from parent to parent as needed
    for (unsigned long p = 0; p < a.size(); ++p)
    {
        const unsigned long pindex = bn.node(n).parent(p).index();
        unsigned long& digit = a[pindex];

        ++digit;

        // still a legal value for this parent: no carry, so we are done
        if (digit < node_num_values(bn,pindex))
            return true;

        // wrapped around: reset this parent and carry into the next one
        digit = 0;
    }

    // the carry went past the last parent, so there are no assignments left
    return false;
}
// ----------------------------------------------------------------------------------------
template <typename T>
bool node_cpt_filled_out (
const T& bn,
unsigned long n
)
{
// make sure requires clause is not broken
DLIB_ASSERT( n < bn.number_of_nodes(),
"\tbool bayes_node_utils::node_cpt_filled_out(bn, n)"
<< "\n\tInvalid arguments to this function"
<< "\n\tn: " << n
<< "\n\tbn.number_of_nodes(): " << bn.number_of_nodes()
);
const unsigned long num_values = node_num_values(bn,n);
const conditional_probability_table& table = bn.node(n).data.table();
// now loop over all the possible parent assignments for this node
assignment a(node_first_parent_assignment(bn,n));
do
{
double sum = 0;
// make sure that this assignment has an entry for all the values this node can take one
for (unsigned long value = 0; value < num_values; ++value)
{
if (table.has_entry_for(value,a) == false)
return false;
else
sum += table.probability(value,a);
}
// check if the sum of probabilities equals 1 as it should
if (std::abs(sum-1.0) > 1e-5)
return false;
} while (node_next_parent_assignment(bn,n,a));
return true;
}
}
// ----------------------------------------------------------------------------------------
class bayesian_network_gibbs_sampler : noncopyable
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A sampler that walks over the nodes of a bayesian network and
            draws a new value for each node that is not marked as evidence.

        CONVENTION
            - v and samples are scratch objects reused between calls
            - rnd is seeded from std::time(0) at construction
    !*/
public:

    bayesian_network_gibbs_sampler ()
    {
        rnd.set_seed(cast_to_string(std::time(0)));
    }

    // Visits the nodes of bn in index order.  For each node that is not
    // evidence, weights every value the node can take by the probability of
    // the node given its parents times the probability of each child given
    // the child's parents, normalizes the weights and draws the node's new
    // value from them.
    template <typename T>
    void sample_graph (T& bn)
    {
        using namespace bayes_node_utils;

        for (unsigned long n = 0; n < bn.number_of_nodes(); ++n)
        {
            // evidence nodes keep the value they were given
            if (node_is_evidence(bn, n))
                continue;

            samples.set_size(node_num_values(bn,n));

            // compute the (unnormalized) weight of every candidate value
            for (long candidate = 0; candidate < samples.nc(); ++candidate)
            {
                // the node must hold the candidate value while it is scored
                set_node_value(bn, n, candidate);

                double weight = node_probability(bn, n);
                for (unsigned long c = 0; c < bn.node(n).number_of_children(); ++c)
                    weight *= node_probability(bn, bn.node(n).child(c).index());

                samples(candidate) = weight;
            }

            // turn the weights into a distribution
            samples /= sum(samples);

            // pick a random point and find the value whose slice contains it
            double prob = rnd.get_random_double();
            long chosen = 0;
            // written as !(a <= b) on purpose so the comparison behaves the
            // same way for every input, including NaN
            while (chosen < samples.nc()-1 && !(prob <= samples(chosen)))
            {
                prob -= samples(chosen);
                ++chosen;
            }

            set_node_value(bn, n, chosen);
        }
    }

private:

    template <typename T>
    double node_probability (const T& bn, unsigned long n)
    /*!
        requires
            - n < bn.number_of_nodes()
        ensures
            - returns the probability of node n holding its current value
              given the values currently held by its parents in bn
    !*/
    {
        // collect the current values of all parents of n into v
        v.clear();
        for (unsigned long i = 0; i < bn.node(n).number_of_parents(); ++i)
        {
            const unsigned long parent_index = bn.node(n).parent(i).index();
            const unsigned long parent_value = bn.node(n).parent(i).data.value();
            v.add(parent_index, parent_value);
        }

        return bn.node(n).data.table().probability(bn.node(n).data.value(), v);
    }

    assignment v;
    dlib::rand rnd;
    matrix<double,1> samples;
};
// ----------------------------------------------------------------------------------------
namespace bayesian_network_join_tree_helpers
{
class bnjt
{
    /*!
        Abstract base class of the pimpl idiom used here.  It exposes only
        the probability() query so that users do not depend on the template
        arguments of the implementation class.
    !*/
public:

    virtual ~bnjt() {}

    // Implemented by derived classes; returns a row matrix of probabilities
    // for the node with index idx.
    virtual const matrix<double,1> probability(unsigned long idx) const = 0;
};
template <typename T, typename U>
class bnjt_impl : public bnjt
{
/*!
This object is the implementation in the pimpl idiom
!*/
public:
// Builds join_tree_values from bn and join_tree, then records in cliques[i]
// the index of the smallest join tree node whose set contains node i
// (0 if no join tree node contains it).
bnjt_impl (const T& bn, const U& join_tree)
{
    create_bayesian_network_join_tree(bn, join_tree, join_tree_values);

    cliques.resize(bn.number_of_nodes());

    for (unsigned long i = 0; i < cliques.size(); ++i)
    {
        // search for the smallest clique containing node i
        unsigned long best_clique = 0;
        unsigned long best_size = std::numeric_limits<unsigned long>::max();

        for (unsigned long n = 0; n < join_tree.number_of_nodes(); ++n)
        {
            if (!join_tree.node(n).data.is_member(i))
                continue;

            if (join_tree.node(n).data.size() < best_size)
            {
                best_size = join_tree.node(n).data.size();
                best_clique = n;
            }
        }

        cliques[i] = best_clique;
    }
}
// Returns the distribution of node idx as a row matrix: element i is the
// probability of the node taking value i.  The distribution is obtained by
// marginalizing the table of the clique chosen for idx down to idx alone and
// normalizing the result.  table, var and dist are mutable scratch members.
virtual const matrix<double,1> probability(unsigned long idx) const
{
    const joint_probability_table& clique_table = join_tree_values.node(cliques[idx]).data;
    clique_table.marginalize(idx, table);
    table.normalize();

    // var is used as the lookup key (idx -> value) into the marginal table
    var.clear();
    var.add(idx);

    const size_t num_entries = table.size();
    dist.set_size(num_entries);

    // copy the marginal out of the table and into the row matrix
    for (unsigned long i = 0; i < num_entries; ++i)
    {
        var[idx] = i;
        dist(i) = table.probability(var);
    }

    return dist;
}
private:
graph< joint_probability_table, joint_probability_table >::kernel_1a_c join_tree_values;
array<unsigned long> cliques;
mutable joint_probability_table table;
mutable assignment var;
mutable matrix<double,1> dist;
// ----------------------------------------------------------------------------------------
// Returns true if the index of every parent of node is a member of set.
// A node without parents trivially satisfies this.
template <typename set_type, typename node_type>
bool set_contains_all_parents_of_node (const set_type& set, const node_type& node)
{
    bool all_present = true;

    // stop at the first parent that is missing from the set
    for (unsigned long i = 0; all_present && i < node.number_of_parents(); ++i)
        all_present = set.is_member(node.parent(i).index());

    return all_present;
}
// ----------------------------------------------------------------------------------------
template <
typename V
>
void pass_join_tree_message (
const U& join_tree,
V& bn_join_tree ,
unsigned long from,
unsigned long to
)
/*!
    Sends one message from node `from` to node `to` of bn_join_tree:
        - the table of `from` is marginalized onto the members of the
          join_tree edge between from and to
        - that marginal is divided, entry by entry, by the table currently
          stored on the same edge of bn_join_tree (the result of the previous
          message over this edge, if there was one)
        - every entry of the table of `to` is multiplied by the quotient entry
          that matches its projection onto the edge's members
        - the new marginal replaces the table stored on the bn_join_tree edge
!*/
{
// NOTE(review): nothing from bayes_node_utils appears to be used in this function
using namespace bayes_node_utils;
// e: data on the join_tree edge; it is only used through is_member() below
const typename U::edge_type& e = edge(join_tree, from, to);
// old_s: table stored on the matching edge of bn_join_tree
typename V::edge_type& old_s = edge(bn_join_tree, from, to);
typedef typename V::edge_type joint_prob_table;
joint_prob_table new_s;
// new_s = table of `from` projected onto the members of e
bn_join_tree.node(from).data.marginalize(e, new_s);
joint_probability_table temp(new_s);
// divide new_s by old_s and store the result in temp.
// if old_s is empty then that is the same as if it was all 1s
// so we don't have to do this if that is the case.
if (old_s.size() > 0)
{
// The two tables are walked in lockstep; this pairs up matching entries only
// if both enumerate the same assignments in the same order.
// NOTE(review): presumably guaranteed because both were produced by
// marginalizing onto e - confirm.
temp.reset();
old_s.reset();
while (temp.move_next())
{
old_s.move_next();
// entries whose old value is 0 are left undivided
if (old_s.element().value() != 0)
temp.element().value() /= old_s.element().value();
}
}
// now multiply temp by d and store the results in d
joint_probability_table& d = bn_join_tree.node(to).data;
d.reset();
while (d.move_next())
{
// a = the current key of d restricted to the members of e
assignment a;
const assignment& asrc = d.element().key();
asrc.reset();
while (asrc.move_next())
{
if (e.is_member(asrc.element().key()))
a.add(asrc.element().key(), asrc.element().value());
}
d.element().value() *= temp.probability(a);
}
// store new_s in old_s
new_s.swap(old_s);
}
// ----------------------------------------------------------------------------------------
template <
typename V
>
void create_bayesian_network_join_tree (
const T& bn,
const U& join_tree,
V& bn_join_tree
)
/*!
requires
- bn is a proper bayesian network
- join_tree is the join tree for that bayesian network
ensures
- bn_join_tree == the output of the join tree algorithm for bayesian network inference.
So each node in this graph contains a joint_probability_table for the clique
in the corresponding node in the join_tree graph.
!*/
{
using namespace bayes_node_utils;
bn_join_tree.clear();
copy_graph_structure(join_tree, bn_join_tree);
// we need to keep track of which node is "in" each clique for the purposes of
// initializing the tables in each clique. So this vector will be used to do that
// and a value of join_tree.number_of_nodes() means that the node with
// that index is unassigned.
std::vector<unsigned long> node_assigned_to(bn.number_of_nodes(),join_tree.number_of_nodes());
// populate evidence with all the evidence node indices and their values
dlib::map<unsigned long, unsigned long>::kernel_1b_c evidence;
for (unsigned long i = 0; i < bn.number_of_nodes(); ++i)
{
if (node_is_evidence(bn, i))
{
unsigned long idx = i;
unsigned long value = node_value(bn, i);
evidence.add(idx,value);
}
}
// initialize the bn join tree
for (unsigned long i = 0; i < join_tree.number_of_nodes(); ++i)
{
bool contains_evidence = false;
std::vector<unsigned long> indices;
assignment value;
// loop over all the nodes in this clique in the join tree. In this loop
// we are making an assignment with all the values of the nodes it represents set to 0
join_tree.node(i).data.reset();
while (join_tree.node(i).data.move_next())
{
const unsigned long idx = join_tree.node(i).data.element();
indices.push_back(idx);
value.add(idx);
if (evidence.is_in_domain(join_tree.node(i).data.element()))
contains_evidence = true;
}
// now loop over all possible combinations of values that the nodes this
// clique in the join tree can take on. We do this by counting by one through all
// legal values
bool more_assignments = true;
while (more_assignments)
{
bn_join_tree.node(i).data.set_probability(value,1);
// account for any evidence
if (contains_evidence)
{
// loop over all the nodes in this cluster
for (unsigned long j = 0; j < indices.size(); ++j)
{
// if the current node is an evidence node
if (evidence.is_in_domain(indices[j]))
{
const unsigned long idx = indices[j];
const unsigned long evidence_value = evidence[idx];
if (value[idx] != evidence_value)
bn_join_tree.node(i).data.set_probability(value , 0);
}
}
}
// now check if any of the nodes in this cluster also have their parents in this cluster
join_tree.node(i).data.reset();
while (join_tree.node(i).data.move_next())
{
const unsigned long idx = join_tree.node(i).data.element();
// if this clique contains all the parents of this node and also hasn't
// been assigned to another clique
if (set_contains_all_parents_of_node(join_tree.node(i).data, bn.node(idx)) &&
(i == node_assigned_to[idx] || node_assigned_to[idx] == join_tree.number_of_nodes()) )
{
// note that this node is now assigned to this clique
node_assigned_to[idx] = i;
// node idx has all its parents in the cluster
assignment parent_values;
for (unsigned long j = 0; j < bn.node(idx).number_of_parents(); ++j)
{
const unsigned long pidx = bn.node(idx).parent(j).index();
parent_values.add(pidx, value[pidx]);
}
double temp = bn_join_tree.node(i).data.probability(value);
bn_join_tree.node(i).data.set_probability(value, temp * node_probability(bn, idx, value[idx], parent_values));
}
}
// now advance the value variable to its next possible state if there is one
more_assignments = false;
value.reset();
while (value.move_next())
{
value.element().value() += 1;
// if overflow
if (value.element().value() == node_num_values(bn, value.element().key()))
{
value.element().value() = 0;
}
else
{
more_assignments = true;
break;
}
}
} // end while (more_assignments)
}
// the tree is now initialized. Now all we need to do is perform the propagation and
// we are done
dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_send;
dlib::array<dlib::set<unsigned long>::compare_1b_c> remaining_msg_to_receive;
remaining_msg_to_receive.resize(join_tree.number_of_nodes());
remaining_msg_to_send.resize(join_tree.number_of_nodes());
for (unsigned long i = 0; i < remaining_msg_to_receive.size(); ++i)
{
for (unsigned long j = 0; j < join_tree.node(i).number_of_neighbors(); ++j)
{
const unsigned long idx = join_tree.node(i).neighbor(j).index();
unsigned long temp;
temp = idx; remaining_msg_to_receive[i].add(temp);
temp = idx; remaining_msg_to_send[i].add(temp);
}
}
// now remaining_msg_to_receive[i] contains all the nodes that node i hasn't yet received
// a message from.
// we will consider node 0 to be the root node.
bool message_sent = true;
std::vector<unsigned long>::iterator iter;
while (message_sent)
{
message_sent = false;
for (unsigned long i = 1; i < remaining_msg_to_send.size(); ++i)
{
// if node i hasn't sent any messages but has received all but one then send a message to the one
// node who hasn't sent i a message
if (remaining_msg_to_send[i].size() == join_tree.node(i).number_of_neighbors() && remaining_msg_to_receive[i].size() == 1)
{
unsigned long to;
// get the last remaining thing from this set
remaining_msg_to_receive[i].remove_any(to);
// send the message
pass_join_tree_message(join_tree, bn_join_tree, i, to);
// record that we sent this message
remaining_msg_to_send[i].destroy(to);
remaining_msg_to_receive[to].destroy(i);
// put to back in since we still need to receive it
remaining_msg_to_receive[i].add(to);
message_sent = true;
}
else if (remaining_msg_to_receive[i].size() == 0 && remaining_msg_to_send[i].size() > 0)
{
unsigned long to;
remaining_msg_to_send[i].remove_any(to);
remaining_msg_to_receive[to].destroy(i);
pass_join_tree_message(join_tree, bn_join_tree, i, to);
message_sent = true;
}
}
if (remaining_msg_to_receive[0].size() == 0)
{
// send a message to all of the root node's neighbors unless we have already sent out the messages
while (remaining_msg_to_send[0].size() > 0)
{
unsigned long to;
remaining_msg_to_send[0].remove_any(to);
remaining_msg_to_receive[to].destroy(0);
pass_join_tree_message(join_tree, bn_join_tree, 0, to);
message_sent = true;
}
}
}
}
};
}
class bayesian_network_join_tree : noncopyable
{
/*!
use the pimpl idiom to push the template arguments from the class level to the
constructor level
!*/
public:
template <
typename T,
typename U
>
bayesian_network_join_tree (
const T& bn,
const U& join_tree
)
{
// make sure requires clause is not broken
DLIB_ASSERT( bn.number_of_nodes() > 0 ,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network"
<< "\n\tthis: " << this
);
DLIB_ASSERT( is_join_tree(bn, join_tree) == true ,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid join tree for the supplied bayesian network"
<< "\n\tthis: " << this
);
DLIB_ASSERT( graph_contains_length_one_cycle(bn) == false,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network"
<< "\n\tthis: " << this
);
DLIB_ASSERT( graph_is_connected(bn) == true,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network"
<< "\n\tthis: " << this
);
#ifdef ENABLE_ASSERTS
for (unsigned long i = 0; i < bn.number_of_nodes(); ++i)
{
DLIB_ASSERT(bayes_node_utils::node_cpt_filled_out(bn,i) == true,
"\tbayesian_network_join_tree::bayesian_network_join_tree(bn,join_tree)"
<< "\n\tYou have given an invalid bayesian network. "
<< "\n\tYou must finish filling out the conditional_probability_table of node " << i
<< "\n\tthis: " << this
);
}
#endif
impl.reset(new bayesian_network_join_tree_helpers::bnjt_impl<T,U>(bn, join_tree));
num_nodes = bn.number_of_nodes();
}
const matrix<double,1> probability(
unsigned long idx
) const
{
// make sure requires clause is not broken
DLIB_ASSERT( idx < number_of_nodes() ,
"\tconst matrix<double,1> bayesian_network_join_tree::probability(idx)"
<< "\n\tYou have specified an invalid node index"
<< "\n\tidx: " << idx
<< "\n\tnumber_of_nodes(): " << number_of_nodes()
<< "\n\tthis: " << this
);
return impl->probability(idx);
}
unsigned long number_of_nodes (
) const { return num_nodes; }
void swap (
bayesian_network_join_tree& item
)
{
exchange(num_nodes, item.num_nodes);
impl.swap(item.impl);
}
private:
std::unique_ptr<bayesian_network_join_tree_helpers::bnjt> impl;
unsigned long num_nodes;
};
// Namespace-level swap for bayesian_network_join_tree objects.  It does nothing
// more than forward to the member function a.swap(b).
inline void swap (
    bayesian_network_join_tree& a,
    bayesian_network_join_tree& b
)
{
    a.swap(b);
}
}
// ----------------------------------------------------------------------------------------
#endif // DLIB_BAYES_UTILs_