// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BSP_CPph_ #define DLIB_BSP_CPph_ #include "bsp.h" #include <memory> #include <stack> // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- namespace dlib { namespace impl1 { void connect_all ( map_id_to_con& cons, const std::vector<network_address>& hosts, unsigned long node_id ) { cons.clear(); for (unsigned long i = 0; i < hosts.size(); ++i) { std::unique_ptr<bsp_con> con(new bsp_con(hosts[i])); dlib::serialize(node_id, con->stream); // tell the other end our node_id unsigned long id = i+1; cons.add(id, con); } } void connect_all_hostinfo ( map_id_to_con& cons, const std::vector<hostinfo>& hosts, unsigned long node_id, std::string& error_string ) { cons.clear(); for (unsigned long i = 0; i < hosts.size(); ++i) { try { std::unique_ptr<bsp_con> con(new bsp_con(hosts[i].addr)); dlib::serialize(node_id, con->stream); // tell the other end our node_id con->stream.flush(); unsigned long id = hosts[i].node_id; cons.add(id, con); } catch (std::exception&) { std::ostringstream sout; sout << "Could not connect to " << hosts[i].addr; error_string = sout.str(); break; } } } void send_out_connection_orders ( map_id_to_con& cons, const std::vector<network_address>& hosts ) { // tell everyone their node ids cons.reset(); while (cons.move_next()) { dlib::serialize(cons.element().key(), cons.element().value()->stream); } // now tell them who to connect to std::vector<hostinfo> targets; for (unsigned long i = 0; i < hosts.size(); ++i) { hostinfo info(hosts[i], i+1); dlib::serialize(targets, cons[info.node_id]->stream); targets.push_back(info); // let the other host know how many incoming connections to expect const unsigned long num = hosts.size()-targets.size(); dlib::serialize(num, cons[info.node_id]->stream); cons[info.node_id]->stream.flush(); } } // ------------------------------------------------------------------------------------ } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- namespace impl2 { // These control bytes are sent before each message between nodes. Note that many // of these are only sent between the control node (node 0) and the other nodes. // This is because the controller node is responsible for handling the // synchronization that needs to happen when all nodes block on calls to // receive_data() // at the same time. // denotes a normal content message. const static char MESSAGE_HEADER = 0; // sent to the controller node when someone receives a message via receive_data(). const static char GOT_MESSAGE = 1; // sent to the controller node when someone sends a message via send(). const static char SENT_MESSAGE = 2; // sent to the controller node when someone enters a call to receive_data() const static char IN_WAITING_STATE = 3; // broadcast when a node terminates itself. const static char NODE_TERMINATE = 5; // broadcast by the controller node when it determines that all nodes are blocked // on calls to receive_data() and there aren't any messages in flight. This is also // what makes us go to the next epoch. const static char SEE_ALL_IN_WAITING_STATE = 6; // This isn't ever transmitted between nodes. It is used internally to indicate // that an error occurred. const static char READ_ERROR = 7; // ------------------------------------------------------------------------------------ void read_thread ( impl1::bsp_con* con, unsigned long node_id, unsigned long sender_id, impl1::thread_safe_message_queue& msg_buffer ) { try { while(true) { impl1::msg_data msg; deserialize(msg.msg_type, con->stream); msg.sender_id = sender_id; if (msg.msg_type == MESSAGE_HEADER) { msg.data.reset(new std::vector<char>); deserialize(msg.epoch, con->stream); deserialize(*msg.data, con->stream); } msg_buffer.push_and_consume(msg); if (msg.msg_type == NODE_TERMINATE) break; } } catch (std::exception& e) { impl1::msg_data msg; msg.data.reset(new std::vector<char>); vectorstream sout(*msg.data); sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; sout << " Receiving processing node id: " << node_id << std::endl; sout << " Error message in the exception: " << e.what() << std::endl; msg.sender_id = sender_id; msg.msg_type = READ_ERROR; msg_buffer.push_and_consume(msg); } catch (...) { impl1::msg_data msg; msg.data.reset(new std::vector<char>); vectorstream sout(*msg.data); sout << "An exception was thrown while attempting to receive a message from processing node " << sender_id << ".\n"; sout << " Sending processing node address: " << con->con->get_foreign_ip() << ":" << con->con->get_foreign_port() << std::endl; sout << " Receiving processing node address: " << con->con->get_local_ip() << ":" << con->con->get_local_port() << std::endl; sout << " Receiving processing node id: " << node_id << std::endl; msg.sender_id = sender_id; msg.msg_type = READ_ERROR; msg_buffer.push_and_consume(msg); } } // ------------------------------------------------------------------------------------ } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // IMPLEMENTATION OF bsp_context OBJECT MEMBERS // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- void bsp_context:: close_all_connections_gracefully( ) { if (node_id() != 0) { _cons.reset(); while (_cons.move_next()) { // tell the other end that we are intentionally dropping the connection serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); _cons.element().value()->stream.flush(); } } impl1::msg_data msg; // now wait for all the other nodes to terminate while (num_terminated_nodes < _cons.size() ) { if (node_id() == 0 && num_waiting_nodes + num_terminated_nodes == _cons.size() && outstanding_messages == 0) { num_waiting_nodes = 0; broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); ++current_epoch; } if (!msg_buffer.pop(msg)) throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); if (msg.msg_type == impl2::NODE_TERMINATE) { ++num_terminated_nodes; _cons[msg.sender_id]->terminated = true; } else if (msg.msg_type == impl2::READ_ERROR) { throw dlib::socket_error(msg.data_to_string()); } else if (msg.msg_type == impl2::MESSAGE_HEADER) { throw dlib::socket_error("A BSP node received a message after it has terminated."); } else if (msg.msg_type == impl2::GOT_MESSAGE) { --num_waiting_nodes; --outstanding_messages; } else if (msg.msg_type == impl2::SENT_MESSAGE) { ++outstanding_messages; } else if (msg.msg_type == impl2::IN_WAITING_STATE) { ++num_waiting_nodes; } } if (node_id() == 0) { _cons.reset(); while (_cons.move_next()) { // tell the other end that we are intentionally dropping the connection serialize(impl2::NODE_TERMINATE,_cons.element().value()->stream); _cons.element().value()->stream.flush(); } if (outstanding_messages != 0) { std::ostringstream sout; sout << "A BSP job was allowed to terminate before all sent messages have been received.\n"; sout << "There are at least " << outstanding_messages << " messages still in flight. Make sure all sent messages\n"; sout << "have a corresponding call to receive()."; throw dlib::socket_error(sout.str()); } } } // ---------------------------------------------------------------------------------------- bsp_context:: ~bsp_context() { _cons.reset(); while (_cons.move_next()) { _cons.element().value()->con->shutdown(); } msg_buffer.disable(); // this will wait for all the threads to terminate threads.clear(); } // ---------------------------------------------------------------------------------------- bsp_context:: bsp_context( unsigned long node_id_, impl1::map_id_to_con& cons_ ) : outstanding_messages(0), num_waiting_nodes(0), num_terminated_nodes(0), current_epoch(1), _cons(cons_), _node_id(node_id_) { // spawn a bunch of read threads, one for each connection _cons.reset(); while (_cons.move_next()) { std::unique_ptr<thread_function> ptr(new thread_function(&impl2::read_thread, _cons.element().value().get(), _node_id, _cons.element().key(), ref(msg_buffer))); threads.push_back(ptr); } } // ---------------------------------------------------------------------------------------- bool bsp_context:: receive_data ( std::shared_ptr<std::vector<char> >& item, unsigned long& sending_node_id ) { notify_control_node(impl2::IN_WAITING_STATE); while (true) { // If there aren't any nodes left to give us messages then return right now. // We need to check the msg_buffer size to make sure there aren't any // unprocessed message there. Recall that this can happen because status // messages always jump to the front of the message buffer. So we might have // learned about the node terminations before processing their messages for us. if (num_terminated_nodes == _cons.size() && msg_buffer.size() == 0) { return false; } // if all running nodes are currently blocking forever on receive_data() if (node_id() == 0 && outstanding_messages == 0 && num_terminated_nodes + num_waiting_nodes == _cons.size()) { num_waiting_nodes = 0; broadcast_byte(impl2::SEE_ALL_IN_WAITING_STATE); // Note that the reason we have this epoch counter is so we can tell if a // sent message is from before or after one of these "all nodes waiting" // synchronization events. If we didn't have the epoch count we would have // a race condition where one node gets the SEE_ALL_IN_WAITING_STATE // message before others and then sends out a message to another node // before that node got the SEE_ALL_IN_WAITING_STATE message. Then that // node would think the normal message came before SEE_ALL_IN_WAITING_STATE // which would be bad. ++current_epoch; return false; } impl1::msg_data data; if (!msg_buffer.pop(data, current_epoch)) throw dlib::socket_error("Error reading from msg_buffer in dlib::bsp_context."); switch(data.msg_type) { case impl2::MESSAGE_HEADER: { item = data.data; sending_node_id = data.sender_id; notify_control_node(impl2::GOT_MESSAGE); return true; } break; case impl2::IN_WAITING_STATE: { ++num_waiting_nodes; } break; case impl2::GOT_MESSAGE: { --outstanding_messages; --num_waiting_nodes; } break; case impl2::SENT_MESSAGE: { ++outstanding_messages; } break; case impl2::NODE_TERMINATE: { ++num_terminated_nodes; _cons[data.sender_id]->terminated = true; } break; case impl2::SEE_ALL_IN_WAITING_STATE: { ++current_epoch; return false; } break; case impl2::READ_ERROR: { throw dlib::socket_error(data.data_to_string()); } break; default: { throw dlib::socket_error("Unknown message received by dlib::bsp_context"); } break; } // end switch() } // end while (true) } // ---------------------------------------------------------------------------------------- void bsp_context:: notify_control_node ( char val ) { if (node_id() == 0) { using namespace impl2; switch(val) { case SENT_MESSAGE: { ++outstanding_messages; } break; case GOT_MESSAGE: { --outstanding_messages; } break; case IN_WAITING_STATE: { // nothing to do in this case } break; default: DLIB_CASSERT(false,"This should never happen"); } } else { serialize(val, _cons[0]->stream); _cons[0]->stream.flush(); } } // ---------------------------------------------------------------------------------------- void bsp_context:: broadcast_byte ( char val ) { for (unsigned long i = 0; i < number_of_nodes(); ++i) { // don't send to yourself or to terminated nodes if (i == node_id() || _cons[i]->terminated) continue; serialize(val, _cons[i]->stream); _cons[i]->stream.flush(); } } // ---------------------------------------------------------------------------------------- void bsp_context:: send_data( const std::vector<char>& item, unsigned long target_node_id ) { using namespace impl2; if (_cons[target_node_id]->terminated) throw socket_error("Attempt to send a message to a node that has terminated."); serialize(MESSAGE_HEADER, _cons[target_node_id]->stream); serialize(current_epoch, _cons[target_node_id]->stream); serialize(item, _cons[target_node_id]->stream); _cons[target_node_id]->stream.flush(); notify_control_node(SENT_MESSAGE); } // ---------------------------------------------------------------------------------------- } #endif // DLIB_BSP_CPph_