AshanGimhana's picture
Upload folder using huggingface_hub
9375c9a verified
raw
history blame
30.8 kB
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_BsP_Hh_
#define DLIB_BsP_Hh_
#include "bsp_abstract.h"
#include <memory>
#include <queue>
#include <vector>
#include "../sockets.h"
#include "../array.h"
#include "../sockstreambuf.h"
#include "../string.h"
#include "../serialize.h"
#include "../map.h"
#include "../ref.h"
#include "../vectorstream.h"
namespace dlib
{
// ----------------------------------------------------------------------------------------
namespace impl1
{
// A port notification callback that ignores the port it is given.  The
// bsp_listen() overloads pass this to bsp_listen_dynamic_port() since they
// have no use for the notification.
inline void null_notify(unsigned short /*listening_port*/)
{
}
struct bsp_con
{
bsp_con(
const network_address& dest
) :
con(connect(dest)),
buf(con),
stream(&buf),
terminated(false)
{
con->disable_nagle();
}
bsp_con(
std::unique_ptr<connection>& conptr
) :
buf(conptr),
stream(&buf),
terminated(false)
{
// make sure we own the connection
conptr.swap(con);
con->disable_nagle();
}
std::unique_ptr<connection> con;
sockstreambuf buf;
std::iostream stream;
bool terminated;
};
// Associates a node id with the uniquely owned bsp_con used to talk to that node.
typedef dlib::map<unsigned long, std::unique_ptr<bsp_con> >::kernel_1a_c map_id_to_con;
void connect_all (
map_id_to_con& cons,
const std::vector<network_address>& hosts,
unsigned long node_id
);
/*!
ensures
- creates connections to all the given hosts and stores them into cons
- NOTE(review): only declared in this header, so the exact use of node_id is
  not visible here.  The bsp_connect() overloads always pass 0 for it --
  presumably it is the id of the calling node; confirm against the definition.
!*/
// Only declared in this header.  The bsp_connect() overloads call it right after
// connect_all(), with the same cons and hosts.
// NOTE(review): presumably this tells each connected host its node id and which
// other hosts it must connect to (the data listen_and_connect_all() reads from
// its first connection) -- confirm against the definition.
void send_out_connection_orders (
map_id_to_con& cons,
const std::vector<network_address>& hosts
);
// ------------------------------------------------------------------------------------
struct hostinfo
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            Pairs the network address of a node with the id of that node.
            Objects of this type are exchanged between nodes via the
            serialize()/deserialize() overloads for hostinfo.
    !*/

    // Default construction yields a fully initialized object (node_id == 0) so
    // that a hostinfo which has not been deserialized into yet never exposes an
    // indeterminate id.
    hostinfo() : node_id(0) {}

    hostinfo (
        const network_address& addr_,
        unsigned long node_id_
    ) :
        addr(addr_),
        node_id(node_id_)
    {
    }

    // Where the node can be reached.
    network_address addr;
    // The id of the node found at addr.
    unsigned long node_id;
};
inline void serialize (
const hostinfo& item,
std::ostream& out
)
{
dlib::serialize(item.addr, out);
dlib::serialize(item.node_id, out);
}
// Reads a hostinfo from in: the address first, then the node id, matching the
// order used by serialize() for hostinfo.
inline void deserialize (
    hostinfo& item,
    std::istream& in
)
{
    network_address& address = item.addr;
    unsigned long& id = item.node_id;

    dlib::deserialize(address, in);
    dlib::deserialize(id, in);
}
// ------------------------------------------------------------------------------------
// Only declared in this header.  listen_and_connect_all() runs it on a separate
// thread and, when it is done, moves every connection found in cons into its own
// connection map.  A non-empty error_string is treated by that caller as a failure
// message and turned into a socket_error.
// NOTE(review): presumably connects to each entry of hosts and stores the result
// in cons keyed by the entry's node_id, reporting failures through error_string
// instead of throwing -- confirm against the definition.
void connect_all_hostinfo (
map_id_to_con& cons,
const std::vector<hostinfo>& hosts,
unsigned long node_id,
std::string& error_string
);
// ------------------------------------------------------------------------------------
/*!
    ensures
        - Empties cons, then listens on the given port (the listener decides
          what port 0 means) and calls port_notify_function() with the port
          that is actually being listened on.
        - Accepts one initial connection and reads from it, in this order: the
          id of the remote node, the id assigned to this node (stored into
          #node_id), the list of hosts this node must connect to, and the
          number of additional incoming connections to expect.
        - Connects to all the listed hosts on a separate thread while accepting
          the expected incoming connections on the calling thread.  Each
          incoming connection identifies itself by sending its node id.
        - #cons contains every connection created, keyed by remote node id.
    throws
        - socket_error
            The listening port could not be created, accepting a connection
            failed or timed out, or the connecting thread reported an error.
!*/
template <
    typename port_notify_function_type
    >
void listen_and_connect_all(
    unsigned long& node_id,
    map_id_to_con& cons,
    unsigned short port,
    port_notify_function_type port_notify_function
)
{
    cons.clear();

    std::unique_ptr<listener> list;
    const int status = create_listener(list, port);
    if (status == PORTINUSE)
    {
        throw socket_error("Unable to create listening port " + cast_to_string(port) +
                           ". The port is already in use");
    }
    else if (status != 0)
    {
        throw socket_error("Unable to create listening port " + cast_to_string(port) );
    }

    port_notify_function(list->get_listening_port());

    // The first connection carries our instructions.
    std::unique_ptr<connection> con;
    if (list->accept(con))
    {
        throw socket_error("Error occurred while accepting new connection");
    }

    std::unique_ptr<bsp_con> temp(new bsp_con(con));

    unsigned long remote_node_id;
    dlib::deserialize(remote_node_id, temp->stream);
    dlib::deserialize(node_id, temp->stream);
    std::vector<hostinfo> targets;
    dlib::deserialize(targets, temp->stream);
    unsigned long num_incoming_connections;
    dlib::deserialize(num_incoming_connections, temp->stream);

    cons.add(remote_node_id,temp);

    // make a thread that will connect to all the targets.  cons2 and
    // error_string belong to that thread until thread.wait() returns, so they
    // must not be touched before then.  They are declared before the thread
    // object so they outlive it if an exception unwinds this function.
    map_id_to_con cons2;
    std::string error_string;
    thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string));

    // If it takes more than 10 seconds for the other nodes to connect to us
    // then something has gone horribly wrong and it almost certainly will
    // never connect at all.  So just give up if that happens.
    const unsigned long timeout_milliseconds = 10000;

    // accept any incoming connections
    for (unsigned long i = 0; i < num_incoming_connections; ++i)
    {
        if (list->accept(con, timeout_milliseconds))
        {
            throw socket_error("Error occurred while accepting new connection");
        }

        temp.reset(new bsp_con(con));
        dlib::deserialize(remote_node_id, temp->stream);
        cons.add(remote_node_id,temp);
    }

    // Only once the thread has finished is it both safe and meaningful to look
    // at what it produced.  Checking error_string any earlier would race with
    // the thread and would miss errors it had not reported yet.
    thread.wait();
    if (error_string.size() != 0)
        throw socket_error(error_string);

    // put all the connections created by the thread into cons
    while (cons2.size() > 0)
    {
        unsigned long id;
        std::unique_ptr<bsp_con> outgoing;
        cons2.remove_any(id,outgoing);
        cons.add(id,outgoing);
    }
}
// ------------------------------------------------------------------------------------
struct msg_data
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            One message as stored in a thread_safe_message_queue: an optional
            byte payload plus the id of the sender, a message type tag, and the
            epoch used by the queue to order messages.
    !*/

    // The payload.  May be null.
    std::shared_ptr<std::vector<char> > data;
    unsigned long sender_id;
    char msg_type;
    dlib::uint64 epoch;

    // A default constructed message has no payload, sender_id == 0xFFFFFFFF,
    // msg_type == -1, and epoch == 0.
    msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {}

    // Returns the payload bytes as a string, or an empty string when there is
    // no payload or the payload is empty.
    std::string data_to_string() const
    {
        if (!data || data->empty())
            return std::string();

        return std::string(data->begin(), data->end());
    }
};
// ------------------------------------------------------------------------------------
class thread_safe_message_queue : noncopyable
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            This is a simple message queue for msg_data objects.  Note that it
            has the special property that, while messages will generally leave
            the queue in the order they are inserted, any message with a smaller
            epoch value will always be popped out first.  But for all messages
            with equal epoch values the queue functions as a normal FIFO queue.

            Every public member function locks the queue's mutex, and the pop()
            functions block until a suitable message is available or the queue
            has been disabled.
    !*/
private:
    struct msg_wrap
    {
        msg_wrap(
            const msg_data& data_,
            const dlib::uint64& sequence_number_
        ) : data(data_), sequence_number(sequence_number_) {}

        msg_wrap() : sequence_number(0){}

        msg_data data;
        dlib::uint64 sequence_number;

        // Make it so that when msg_wrap objects are in a std::priority_queue,
        // messages with a smaller epoch number always come first.  Then, within an
        // epoch, messages are ordered by their sequence number (so smaller first
        // there as well).
        //
        // std::priority_queue pops its largest element, hence the inverted
        // comparisons.  The comparisons are strict so that an object never
        // compares less than itself, which is required of any ordering handed
        // to std::priority_queue.
        bool operator<(const msg_wrap& item) const
        {
            if (data.epoch != item.data.epoch)
                return data.epoch > item.data.epoch;

            return sequence_number > item.sequence_number;
        }
    };

public:
    thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {}

    ~thread_safe_message_queue()
    {
        disable();
    }

    void disable()
    /*!
        ensures
            - marks this object as disabled and wakes up every thread blocked
              in a call to pop().  Those calls, and all later ones, return false.
    !*/
    {
        auto_mutex lock(class_mutex);
        disabled = true;
        sig.broadcast();
    }

    unsigned long size() const
    /*!
        ensures
            - returns the number of messages currently in the queue
    !*/
    {
        auto_mutex lock(class_mutex);
        return data.size();
    }

    void push_and_consume( msg_data& item)
    /*!
        ensures
            - adds a copy of item to the queue, stamped with the next sequence
              number
            - #item.data is null.  All other fields of item are left unchanged.
            - wakes up one thread waiting in pop(), if there is one
    !*/
    {
        auto_mutex lock(class_mutex);
        data.push(msg_wrap(item, next_seq_num++));
        // do this here so that we don't have to worry about different threads touching the shared_ptr.
        item.data.reset();
        sig.signal();
    }

    bool pop (
        msg_data& item
    )
    /*!
        ensures
            - if (this function returns true) then
                - #item == the next thing from the queue
            - else
                - this object is disabled
    !*/
    {
        auto_mutex lock(class_mutex);
        while (data.size() == 0 && !disabled)
            sig.wait();

        // Once disabled, nothing is handed out even if messages remain queued.
        if (disabled)
            return false;

        item = data.top().data;
        data.pop();
        return true;
    }

    bool pop (
        msg_data& item,
        const dlib::uint64& max_epoch
    )
    /*!
        ensures
            - if (this function returns true) then
                - #item == the next thing from the queue that has an epoch <= max_epoch
            - else
                - this object is disabled
    !*/
    {
        auto_mutex lock(class_mutex);
        // The top of the queue always has the smallest epoch, so if it is too
        // new then everything else in the queue is as well.
        while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled)
            sig.wait();

        if (disabled)
            return false;

        item = data.top().data;
        data.pop();
        return true;
    }

private:
    std::priority_queue<msg_wrap> data;
    dlib::mutex class_mutex;
    dlib::signaler sig;
    bool disabled;
    // The sequence number given to the next pushed message.  Starts at 1 and
    // increases by one with every push.
    dlib::uint64 next_seq_num;
};
}
// ----------------------------------------------------------------------------------------
class bsp_context : noncopyable
{
public:
template <typename T>
void send(
const T& item,
unsigned long target_node_id
)
{
// make sure requires clause is not broken
DLIB_CASSERT(target_node_id < number_of_nodes() &&
target_node_id != node_id(),
"\t void bsp_context::send()"
<< "\n\t Invalid arguments were given to this function."
<< "\n\t target_node_id: " << target_node_id
<< "\n\t node_id(): " << node_id()
<< "\n\t number_of_nodes(): " << number_of_nodes()
<< "\n\t this: " << this
);
std::vector<char> buf;
vectorstream sout(buf);
serialize(item, sout);
send_data(buf, target_node_id);
}
template <typename T>
void broadcast (
const T& item
)
{
std::vector<char> buf;
vectorstream sout(buf);
serialize(item, sout);
for (unsigned long i = 0; i < number_of_nodes(); ++i)
{
// Don't send to yourself.
if (i == node_id())
continue;
send_data(buf, i);
}
}
unsigned long node_id (
) const { return _node_id; }
unsigned long number_of_nodes (
) const { return _cons.size()+1; }
void receive (
)
{
unsigned long id;
std::shared_ptr<std::vector<char> > temp;
if (receive_data(temp,id))
throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message.");
}
template <typename T>
void receive (
T& item
)
{
if(!try_receive(item))
throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked.");
}
template <typename T>
bool try_receive (
T& item
)
{
unsigned long sending_node_id;
return try_receive(item, sending_node_id);
}
template <typename T>
void receive (
T& item,
unsigned long& sending_node_id
)
{
if(!try_receive(item, sending_node_id))
throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked.");
}
template <typename T>
bool try_receive (
T& item,
unsigned long& sending_node_id
)
{
std::shared_ptr<std::vector<char> > temp;
if (receive_data(temp, sending_node_id))
{
vectorstream sin(*temp);
deserialize(item, sin);
if (sin.peek() != EOF)
throw serialization_error("deserialize() did not consume all bytes produced by serialize(). "
"This probably means you are calling a receive method with a different type "
"of object than the one which was sent.");
return true;
}
else
{
return false;
}
}
~bsp_context();
private:
bsp_context();
bsp_context(
unsigned long node_id_,
impl1::map_id_to_con& cons_
);
void close_all_connections_gracefully();
/*!
ensures
- closes all the connections to other nodes and lets them know that
we are terminating normally rather than as the result of some kind
of error.
!*/
bool receive_data (
std::shared_ptr<std::vector<char> >& item,
unsigned long& sending_node_id
);
void notify_control_node (
char val
);
void broadcast_byte (
char val
);
void send_data(
const std::vector<char>& item,
unsigned long target_node_id
);
/*!
requires
- target_node_id < number_of_nodes()
- target_node_id != node_id()
ensures
- sends a copy of item to the node with the given id.
!*/
unsigned long outstanding_messages;
unsigned long num_waiting_nodes;
unsigned long num_terminated_nodes;
dlib::uint64 current_epoch;
impl1::thread_safe_message_queue msg_buffer;
impl1::map_id_to_con& _cons;
const unsigned long _node_id;
array<std::unique_ptr<thread_function> > threads;
// -----------------------------------
template <
typename funct_type
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct
);
template <
typename funct_type,
typename ARG1
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1
);
template <
typename funct_type,
typename ARG1,
typename ARG2
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
friend void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
// -----------------------------------
template <
typename port_notify_function_type,
typename funct_type
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
);
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
friend void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
);
// -----------------------------------
};
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename funct_type
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_connect (
const std::vector<network_address>& hosts,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
)
{
impl1::map_id_to_con cons;
const unsigned long node_id = 0;
connect_all(cons, hosts, node_id);
send_out_connection_orders(cons, hosts);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3,arg4);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// Same as bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct)
// except that listening_port is required to be non-zero.
template <typename funct_type>
void bsp_listen (
    unsigned short listening_port,
    funct_type funct
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The port is fixed by the caller, so nobody needs to be told about it.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct);
}
// ----------------------------------------------------------------------------------------
// Same as bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1)
// except that listening_port is required to be non-zero.
template <typename funct_type, typename ARG1>
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The port is fixed by the caller, so nobody needs to be told about it.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1);
}
// ----------------------------------------------------------------------------------------
// Same as bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct,
// arg1, arg2) except that listening_port is required to be non-zero.
template <typename funct_type, typename ARG1, typename ARG2>
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1,
    ARG2 arg2
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The port is fixed by the caller, so nobody needs to be told about it.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2);
}
// ----------------------------------------------------------------------------------------
// Same as bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct,
// arg1, arg2, arg3) except that listening_port is required to be non-zero.
template <typename funct_type, typename ARG1, typename ARG2, typename ARG3>
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1,
    ARG2 arg2,
    ARG3 arg3
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The port is fixed by the caller, so nobody needs to be told about it.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3);
}
// ----------------------------------------------------------------------------------------
// Same as bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct,
// arg1, arg2, arg3, arg4) except that listening_port is required to be non-zero.
template <typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4>
void bsp_listen (
    unsigned short listening_port,
    funct_type funct,
    ARG1 arg1,
    ARG2 arg2,
    ARG3 arg3,
    ARG4 arg4
)
{
    // make sure requires clause is not broken
    DLIB_CASSERT(listening_port != 0,
        "\t void bsp_listen()"
        << "\n\t Invalid arguments were given to this function."
        );

    // The port is fixed by the caller, so nobody needs to be told about it.
    bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4);
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
template <
typename port_notify_function_type,
typename funct_type,
typename ARG1,
typename ARG2,
typename ARG3,
typename ARG4
>
void bsp_listen_dynamic_port (
unsigned short listening_port,
port_notify_function_type port_notify_function,
funct_type funct,
ARG1 arg1,
ARG2 arg2,
ARG3 arg3,
ARG4 arg4
)
{
impl1::map_id_to_con cons;
unsigned long node_id;
listen_and_connect_all(node_id, cons, listening_port, port_notify_function);
bsp_context obj(node_id, cons);
funct(obj,arg1,arg2,arg3,arg4);
obj.close_all_connections_gracefully();
}
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------------
}
#ifdef NO_MAKEFILE
#include "bsp.cpp"
#endif
#endif // DLIB_BsP_Hh_