|
|
|
|
|
#ifndef DLIB_BSP_CPph_ |
|
#define DLIB_BSP_CPph_ |
|
|
|
#include "bsp.h" |
|
#include <memory> |
|
#include <stack> |
|
|
|
|
|
|
|
|
|
namespace dlib |
|
{ |
|
|
|
namespace impl1 |
|
{ |
|
|
|
// Opens one outgoing bsp_con for every address in hosts and stores it in cons.
// The connection to hosts[k] is registered under node id k+1, and this node's
// own id (node_id) is serialized onto every new stream so it is the first item
// the remote end reads.  Any previous contents of cons are discarded.
// Exceptions raised while connecting or serializing are not caught here; they
// propagate to the caller.
void connect_all (
    map_id_to_con& cons,
    const std::vector<network_address>& hosts,
    unsigned long node_id
)
{
    cons.clear();

    unsigned long next_id = 0;
    for (const network_address& addr : hosts)
    {
        ++next_id;
        std::unique_ptr<bsp_con> link(new bsp_con(addr));

        // Identify ourselves to the peer.
        dlib::serialize(node_id, link->stream);

        // add() is handed a fresh copy of the id rather than the loop counter.
        // NOTE(review): map::add presumably swaps its arguments out — confirm.
        unsigned long id = next_id;
        cons.add(id, link);
    }
}
|
|
|
// Connects to every host in hosts and registers each connection in cons under
// that host's own node id (hosts[k].node_id).  This node's id (node_id) is
// serialized and flushed onto each new stream immediately.  Any previous
// contents of cons are discarded.
//
// Failure handling: the first std::exception thrown while connecting to or
// writing to a host stops the process.  error_string is then set to
// "Could not connect to <addr>" and the remaining hosts are not attempted.
// Connections made before the failure stay in cons.  On success error_string
// is not modified.
void connect_all_hostinfo (
    map_id_to_con& cons,
    const std::vector<hostinfo>& hosts,
    unsigned long node_id,
    std::string& error_string
)
{
    cons.clear();

    for (const hostinfo& target : hosts)
    {
        try
        {
            std::unique_ptr<bsp_con> link(new bsp_con(target.addr));
            dlib::serialize(node_id, link->stream);
            // Push our id out right away instead of leaving it buffered.
            link->stream.flush();

            unsigned long id = target.node_id;
            cons.add(id, link);
        }
        catch (std::exception&)
        {
            std::ostringstream report;
            report << "Could not connect to " << target.addr;
            error_string = report.str();
            return;
        }
    }
}
|
|
|
|
|
// Sends the initial connection orders to the nodes in cons, which are keyed
// 1..hosts.size() (hosts[k] being the address of node k+1).
//   1. Every node is told its own node id (its key in cons).  Nothing is
//      flushed yet.
//   2. Node k+1 then receives the hostinfo records of nodes 1..k, followed by
//      the number of nodes that come after it (hosts.size() - (k+1)).  Its
//      stream is flushed after that count.
// NOTE(review): presumably each node connects out to the listed earlier nodes
// and waits for the stated number of incoming connections — confirm in bsp.h.
void send_out_connection_orders (
    map_id_to_con& cons,
    const std::vector<network_address>& hosts
)
{
    // Phase 1: every node learns its id.
    for (cons.reset(); cons.move_next(); )
        dlib::serialize(cons.element().key(), cons.element().value()->stream);

    // Phase 2: every node learns which nodes precede it.
    std::vector<hostinfo> earlier_nodes;
    for (unsigned long idx = 0; idx < hosts.size(); ++idx)
    {
        hostinfo info(hosts[idx], idx + 1);
        bsp_con& node = *cons[info.node_id];

        dlib::serialize(earlier_nodes, node.stream);
        earlier_nodes.push_back(info);

        // How many nodes will connect to this one later on.
        const unsigned long num = hosts.size() - earlier_nodes.size();
        dlib::serialize(num, node.stream);
        node.stream.flush();
    }
}
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
namespace impl2 |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Control bytes that prefix every item exchanged between BSP nodes.  Most of
// them are only exchanged between node 0 (the control node) and the other
// nodes; READ_ERROR never goes over the wire at all.

// Precedes a normal content message.  It is followed on the stream by the
// sender's epoch and the serialized payload (written by send_data(), read back
// by read_thread()).
constexpr char MESSAGE_HEADER = 0;

// Reported to node 0 when a node's receive_data() returns a content message.
constexpr char GOT_MESSAGE = 1;

// Reported to node 0 when a node sends a content message via send_data().
constexpr char SENT_MESSAGE = 2;

// Reported to node 0 when a node enters receive_data().
constexpr char IN_WAITING_STATE = 3;

// Written on every connection by a node that is shutting down
// (close_all_connections_gracefully()).  A reader thread stops after
// forwarding it.
constexpr char NODE_TERMINATE = 5;

// Broadcast by node 0 once no messages are outstanding and every
// non-terminated node is waiting in receive_data().  Receiving it ends the
// current epoch.
constexpr char SEE_ALL_IN_WAITING_STATE = 6;

// Never transmitted.  read_thread() queues it locally when reading from a
// connection throws, carrying a diagnostic text as its payload.
constexpr char READ_ERROR = 7;
|
|
|
|
|
|
|
void read_thread ( |
|
impl1::bsp_con* con, |
|
unsigned long node_id, |
|
unsigned long sender_id, |
|
impl1::thread_safe_message_queue& msg_buffer |
|
) |
|
{ |
|
try |
|
{ |
|
while(true) |
|
{ |
|
impl1::msg_data msg; |
|
deserialize(msg.msg_type, con->stream); |
|
msg.sender_id = sender_id; |
|
|
|
if (msg.msg_type == MESSAGE_HEADER) |
|
{ |
|
msg.data.reset(new std::vector<char>); |
|
deserialize(msg.epoch, con->stream); |
|
deserialize(*msg.data, con->stream); |
|
} |
|
|
|
msg_buffer.push_and_consume(msg); |
|
|
|
if (msg.msg_type == NODE_TERMINATE) |
|
break; |
|
} |
|
} |
|
catch (std::exception& e) |
|
{ |
|
impl1::msg_data msg; |
|
msg.data.reset(new std::vector<char>); |
|
vectorstream sout(*msg.data); |
|
sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; |
|
sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; |
|
sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; |
|
sout << " Receiving processing node id: " << node_id << std::endl; |
|
sout << " Error message in the exception: " << e.what() << std::endl; |
|
|
|
msg.sender_id = sender_id; |
|
msg.msg_type = READ_ERROR; |
|
|
|
msg_buffer.push_and_consume(msg); |
|
} |
|
catch (...) |
|
{ |
|
impl1::msg_data msg; |
|
msg.data.reset(new std::vector<char>); |
|
vectorstream sout(*msg.data); |
|
sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; |
|
sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; |
|
sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; |
|
sout << " Receiving processing node id: " << node_id << std::endl; |
|
|
|
msg.sender_id = sender_id; |
|
msg.msg_type = READ_ERROR; |
|
|
|
msg_buffer.push_and_consume(msg); |
|
} |
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void bsp_context:: |
|
close_all_connections_gracefully( |
|
) |
|
{ |
|
if (node_id() != 0) |
|
{ |
|
_cons.reset(); |
|
while (_cons.move_next()) |
|
{ |
|
|
|
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); |
|
_cons.element().value()->stream.flush(); |
|
} |
|
} |
|
|
|
impl1::msg_data msg; |
|
|
|
while (num_terminated_nodes < _cons.size() ) |
|
{ |
|
if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0) |
|
{ |
|
num_waiting_nodes = 0; |
|
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); |
|
++current_epoch; |
|
} |
|
|
|
if (!msg_buffer.pop(msg)) |
|
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); |
|
|
|
if (msg.msg_type == impl2::NODE_TERMINATE) |
|
{ |
|
++num_terminated_nodes; |
|
_cons[msg.sender_id]->terminated = true; |
|
} |
|
else if (msg.msg_type == impl2::READ_ERROR) |
|
{ |
|
throw dlib::socket_error(msg.data_to_string()); |
|
} |
|
else if (msg.msg_type == impl2::MESSAGE_HEADER) |
|
{ |
|
throw dlib::socket_error("A BSP node received a message after it has terminated."); |
|
} |
|
else if (msg.msg_type == impl2::GOT_MESSAGE) |
|
{ |
|
--num_waiting_nodes; |
|
--outstanding_messages; |
|
} |
|
else if (msg.msg_type == impl2::SENT_MESSAGE) |
|
{ |
|
++outstanding_messages; |
|
} |
|
else if (msg.msg_type == impl2::IN_WAITING_STATE) |
|
{ |
|
++num_waiting_nodes; |
|
} |
|
} |
|
|
|
if (node_id() == 0) |
|
{ |
|
_cons.reset(); |
|
while (_cons.move_next()) |
|
{ |
|
|
|
serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); |
|
_cons.element().value()->stream.flush(); |
|
} |
|
|
|
if (outstanding_messages != 0) |
|
{ |
|
std::ostringstream sout; |
|
sout << "A BSP job was allowed to terminate before all sent messages have been received.\n"; |
|
sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n"; |
|
sout << "have a corresponding call to receive()."; |
|
throw dlib::socket_error(sout.str()); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
// Tears the context down in three steps:
//   1. shut down every connection,
//   2. disable msg_buffer,
//   3. release the reader-thread handles.
// NOTE(review): presumably shutting the sockets down unblocks reader threads
// stuck in deserialize(), disable() unblocks any queue operation, and
// destroying the thread_function objects waits for the threads to finish —
// confirm against dlib's thread_function documentation.
bsp_context::
~bsp_context()
{
    for (_cons.reset(); _cons.move_next(); )
        _cons.element().value()->con->shutdown();

    msg_buffer.disable();

    threads.clear();
}
|
|
|
|
|
|
|
// Builds a context for node node_id_ on top of the already-established
// connections in cons_.  All counters start at zero and the epoch at 1.  One
// thread running impl2::read_thread is started per connection.  Each thread
// feeds that peer's traffic into msg_buffer, and its handle is kept in
// `threads`.
bsp_context::
bsp_context(
    unsigned long node_id_,
    impl1::map_id_to_con& cons_
) :
    outstanding_messages(0),
    num_waiting_nodes(0),
    num_terminated_nodes(0),
    current_epoch(1),
    _cons(cons_),
    _node_id(node_id_)
{
    for (_cons.reset(); _cons.move_next(); )
    {
        impl1::bsp_con* const link = _cons.element().value().get();
        const unsigned long peer_id = _cons.element().key();

        std::unique_ptr<thread_function> reader(
            new thread_function(&impl2::read_thread, link, _node_id, peer_id, ref(msg_buffer)));
        threads.push_back(reader);
    }
}
|
|
|
|
|
|
|
bool bsp_context:: |
|
receive_data ( |
|
std::shared_ptr<std::vector<char> >& item, |
|
unsigned long& sending_node_id |
|
) |
|
{ |
|
notify_control_node(impl2::IN_WAITING_STATE); |
|
|
|
while (true) |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0) |
|
{ |
|
return false; |
|
} |
|
|
|
|
|
if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) |
|
{ |
|
num_waiting_nodes = 0; |
|
broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
++current_epoch; |
|
return false; |
|
} |
|
|
|
impl1::msg_data data; |
|
if (!msg_buffer.pop(data, current_epoch)) |
|
throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); |
|
|
|
|
|
switch(data.msg_type) |
|
{ |
|
case impl2::MESSAGE_HEADER: { |
|
item = data.data; |
|
sending_node_id = data.sender_id; |
|
notify_control_node(impl2::GOT_MESSAGE); |
|
return true; |
|
} break; |
|
|
|
case impl2::IN_WAITING_STATE: { |
|
++num_waiting_nodes; |
|
} break; |
|
|
|
case impl2::GOT_MESSAGE: { |
|
--outstanding_messages; |
|
--num_waiting_nodes; |
|
} break; |
|
|
|
case impl2::SENT_MESSAGE: { |
|
++outstanding_messages; |
|
} break; |
|
|
|
case impl2::NODE_TERMINATE: { |
|
++num_terminated_nodes; |
|
_cons[data.sender_id]->terminated = true; |
|
} break; |
|
|
|
case impl2::SEE_ALL_IN_WAITING_STATE: { |
|
++current_epoch; |
|
return false; |
|
} break; |
|
|
|
case impl2::READ_ERROR: { |
|
throw dlib::socket_error(data.data_to_string()); |
|
} break; |
|
|
|
default: { |
|
throw dlib::socket_error("Unknown message received by dlib::bsp_context"); |
|
} break; |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
void bsp_context:: |
|
notify_control_node ( |
|
char val |
|
) |
|
{ |
|
if (node_id() == 0) |
|
{ |
|
using namespace impl2; |
|
switch(val) |
|
{ |
|
case SENT_MESSAGE: { |
|
++outstanding_messages; |
|
} break; |
|
|
|
case GOT_MESSAGE: { |
|
--outstanding_messages; |
|
} break; |
|
|
|
case IN_WAITING_STATE: { |
|
|
|
} break; |
|
|
|
default: |
|
DLIB_CASSERT(false,"This should never happen"); |
|
} |
|
} |
|
else |
|
{ |
|
serialize(val, _cons[0]->stream); |
|
_cons[0]->stream.flush(); |
|
} |
|
} |
|
|
|
|
|
|
|
void bsp_context:: |
|
broadcast_byte ( |
|
char val |
|
) |
|
{ |
|
for (unsigned long i = 0; i < number_of_nodes(); ++i) |
|
{ |
|
|
|
if (i == node_id() || _cons[i]->terminated) |
|
continue; |
|
|
|
serialize(val, _cons[i]->stream); |
|
_cons[i]->stream.flush(); |
|
} |
|
} |
|
|
|
|
|
|
|
void bsp_context:: |
|
send_data( |
|
const std::vector<char>& item, |
|
unsigned long target_node_id |
|
) |
|
{ |
|
using namespace impl2; |
|
if (_cons[target_node_id]->terminated) |
|
throw socket_error("Attempt to send a message to a node that has terminated."); |
|
|
|
serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); |
|
serialize(current_epoch, _cons[target_node_id]->stream); |
|
serialize(item, _cons[target_node_id]->stream); |
|
_cons[target_node_id]->stream.flush(); |
|
|
|
notify_control_node(SENT_MESSAGE); |
|
} |
|
|
|
|
|
|
|
} |
|
|
|
#endif |
|
|
|
|