// Copyright (C) 2012 Davis E. King (davis@dlib.net) // License: Boost Software License See LICENSE.txt for the full license. #ifndef DLIB_BsP_Hh_ #define DLIB_BsP_Hh_ #include "bsp_abstract.h" #include <memory> #include <queue> #include <vector> #include "../sockets.h" #include "../array.h" #include "../sockstreambuf.h" #include "../string.h" #include "../serialize.h" #include "../map.h" #include "../ref.h" #include "../vectorstream.h" namespace dlib { // ---------------------------------------------------------------------------------------- namespace impl1 { inline void null_notify( unsigned short ) {} struct bsp_con { bsp_con( const network_address& dest ) : con(connect(dest)), buf(con), stream(&buf), terminated(false) { con->disable_nagle(); } bsp_con( std::unique_ptr<connection>& conptr ) : buf(conptr), stream(&buf), terminated(false) { // make sure we own the connection conptr.swap(con); con->disable_nagle(); } std::unique_ptr<connection> con; sockstreambuf buf; std::iostream stream; bool terminated; }; typedef dlib::map<unsigned long, std::unique_ptr<bsp_con> >::kernel_1a_c map_id_to_con; void connect_all ( map_id_to_con& cons, const std::vector<network_address>& hosts, unsigned long node_id ); /*! ensures - creates connections to all the given hosts and stores them into cons !*/ void send_out_connection_orders ( map_id_to_con& cons, const std::vector<network_address>& hosts ); // ------------------------------------------------------------------------------------ struct hostinfo { hostinfo() {} hostinfo ( const network_address& addr_, unsigned long node_id_ ) : addr(addr_), node_id(node_id_) { } network_address addr; unsigned long node_id; }; inline void serialize ( const hostinfo& item, std::ostream& out ) { dlib::serialize(item.addr, out); dlib::serialize(item.node_id, out); } inline void deserialize ( hostinfo& item, std::istream& in ) { dlib::deserialize(item.addr, in); dlib::deserialize(item.node_id, in); } // ------------------------------------------------------------------------------------ void connect_all_hostinfo ( map_id_to_con& cons, const std::vector<hostinfo>& hosts, unsigned long node_id, std::string& error_string ); // ------------------------------------------------------------------------------------ template < typename port_notify_function_type > void listen_and_connect_all( unsigned long& node_id, map_id_to_con& cons, unsigned short port, port_notify_function_type port_notify_function ) { cons.clear(); std::unique_ptr<listener> list; const int status = create_listener(list, port); if (status == PORTINUSE) { throw socket_error("Unable to create listening port " + cast_to_string(port) + ". The port is already in use"); } else if (status != 0) { throw socket_error("Unable to create listening port " + cast_to_string(port) ); } port_notify_function(list->get_listening_port()); std::unique_ptr<connection> con; if (list->accept(con)) { throw socket_error("Error occurred while accepting new connection"); } std::unique_ptr<bsp_con> temp(new bsp_con(con)); unsigned long remote_node_id; dlib::deserialize(remote_node_id, temp->stream); dlib::deserialize(node_id, temp->stream); std::vector<hostinfo> targets; dlib::deserialize(targets, temp->stream); unsigned long num_incoming_connections; dlib::deserialize(num_incoming_connections, temp->stream); cons.add(remote_node_id,temp); // make a thread that will connect to all the targets map_id_to_con cons2; std::string error_string; thread_function thread(connect_all_hostinfo, dlib::ref(cons2), dlib::ref(targets), node_id, dlib::ref(error_string)); if (error_string.size() != 0) throw socket_error(error_string); // accept any incoming connections for (unsigned long i = 0; i < num_incoming_connections; ++i) { // If it takes more than 10 seconds for the other nodes to connect to us // then something has gone horribly wrong and it almost certainly will // never connect at all. So just give up if that happens. const unsigned long timeout_milliseconds = 10000; if (list->accept(con, timeout_milliseconds)) { throw socket_error("Error occurred while accepting new connection"); } temp.reset(new bsp_con(con)); dlib::deserialize(remote_node_id, temp->stream); cons.add(remote_node_id,temp); } // put all the connections created by the thread into cons thread.wait(); while (cons2.size() > 0) { unsigned long id; std::unique_ptr<bsp_con> temp; cons2.remove_any(id,temp); cons.add(id,temp); } } // ------------------------------------------------------------------------------------ struct msg_data { std::shared_ptr<std::vector<char> > data; unsigned long sender_id; char msg_type; dlib::uint64 epoch; msg_data() : sender_id(0xFFFFFFFF), msg_type(-1), epoch(0) {} std::string data_to_string() const { if (data && data->size() != 0) return std::string(&(*data)[0], data->size()); else return ""; } }; // ------------------------------------------------------------------------------------ class thread_safe_message_queue : noncopyable { /*! WHAT THIS OBJECT REPRESENTS This is a simple message queue for msg_data objects. Note that it has the special property that, while messages will generally leave the queue in the order they are inserted, any message with a smaller epoch value will always be popped out first. But for all messages with equal epoch values the queue functions as a normal FIFO queue. !*/ private: struct msg_wrap { msg_wrap( const msg_data& data_, const dlib::uint64& sequence_number_ ) : data(data_), sequence_number(sequence_number_) {} msg_wrap() : sequence_number(0){} msg_data data; dlib::uint64 sequence_number; // Make it so that when msg_wrap objects are in a std::priority_queue, // messages with a smaller epoch number always come first. Then, within an // epoch, messages are ordered by their sequence number (so smaller first // there as well). bool operator<(const msg_wrap& item) const { if (data.epoch < item.data.epoch) { return false; } else if (data.epoch > item.data.epoch) { return true; } else { if (sequence_number < item.sequence_number) return false; else return true; } } }; public: thread_safe_message_queue() : sig(class_mutex),disabled(false),next_seq_num(1) {} ~thread_safe_message_queue() { disable(); } void disable() { auto_mutex lock(class_mutex); disabled = true; sig.broadcast(); } unsigned long size() const { auto_mutex lock(class_mutex); return data.size(); } void push_and_consume( msg_data& item) { auto_mutex lock(class_mutex); data.push(msg_wrap(item, next_seq_num++)); // do this here so that we don't have to worry about different threads touching the shared_ptr. item.data.reset(); sig.signal(); } bool pop ( msg_data& item ) /*! ensures - if (this function returns true) then - #item == the next thing from the queue - else - this object is disabled !*/ { auto_mutex lock(class_mutex); while (data.size() == 0 && !disabled) sig.wait(); if (disabled) return false; item = data.top().data; data.pop(); return true; } bool pop ( msg_data& item, const dlib::uint64& max_epoch ) /*! ensures - if (this function returns true) then - #item == the next thing from the queue that has an epoch <= max_epoch - else - this object is disabled !*/ { auto_mutex lock(class_mutex); while ((data.size() == 0 || data.top().data.epoch > max_epoch) && !disabled) sig.wait(); if (disabled) return false; item = data.top().data; data.pop(); return true; } private: std::priority_queue<msg_wrap> data; dlib::mutex class_mutex; dlib::signaler sig; bool disabled; dlib::uint64 next_seq_num; }; } // ---------------------------------------------------------------------------------------- class bsp_context : noncopyable { public: template <typename T> void send( const T& item, unsigned long target_node_id ) { // make sure requires clause is not broken DLIB_CASSERT(target_node_id < number_of_nodes() && target_node_id != node_id(), "\t void bsp_context::send()" << "\n\t Invalid arguments were given to this function." << "\n\t target_node_id: " << target_node_id << "\n\t node_id(): " << node_id() << "\n\t number_of_nodes(): " << number_of_nodes() << "\n\t this: " << this ); std::vector<char> buf; vectorstream sout(buf); serialize(item, sout); send_data(buf, target_node_id); } template <typename T> void broadcast ( const T& item ) { std::vector<char> buf; vectorstream sout(buf); serialize(item, sout); for (unsigned long i = 0; i < number_of_nodes(); ++i) { // Don't send to yourself. if (i == node_id()) continue; send_data(buf, i); } } unsigned long node_id ( ) const { return _node_id; } unsigned long number_of_nodes ( ) const { return _cons.size()+1; } void receive ( ) { unsigned long id; std::shared_ptr<std::vector<char> > temp; if (receive_data(temp,id)) throw dlib::socket_error("Call to bsp_context::receive() got an unexpected message."); } template <typename T> void receive ( T& item ) { if(!try_receive(item)) throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); } template <typename T> bool try_receive ( T& item ) { unsigned long sending_node_id; return try_receive(item, sending_node_id); } template <typename T> void receive ( T& item, unsigned long& sending_node_id ) { if(!try_receive(item, sending_node_id)) throw dlib::socket_error("bsp_context::receive(): no messages to receive, all nodes currently blocked."); } template <typename T> bool try_receive ( T& item, unsigned long& sending_node_id ) { std::shared_ptr<std::vector<char> > temp; if (receive_data(temp, sending_node_id)) { vectorstream sin(*temp); deserialize(item, sin); if (sin.peek() != EOF) throw serialization_error("deserialize() did not consume all bytes produced by serialize(). " "This probably means you are calling a receive method with a different type " "of object than the one which was sent."); return true; } else { return false; } } ~bsp_context(); private: bsp_context(); bsp_context( unsigned long node_id_, impl1::map_id_to_con& cons_ ); void close_all_connections_gracefully(); /*! ensures - closes all the connections to other nodes and lets them know that we are terminating normally rather than as the result of some kind of error. !*/ bool receive_data ( std::shared_ptr<std::vector<char> >& item, unsigned long& sending_node_id ); void notify_control_node ( char val ); void broadcast_byte ( char val ); void send_data( const std::vector<char>& item, unsigned long target_node_id ); /*! requires - target_node_id < number_of_nodes() - target_node_id != node_id() ensures - sends a copy of item to the node with the given id. !*/ unsigned long outstanding_messages; unsigned long num_waiting_nodes; unsigned long num_terminated_nodes; dlib::uint64 current_epoch; impl1::thread_safe_message_queue msg_buffer; impl1::map_id_to_con& _cons; const unsigned long _node_id; array<std::unique_ptr<thread_function> > threads; // ----------------------------------- template < typename funct_type > friend void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct ); template < typename funct_type, typename ARG1 > friend void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1 ); template < typename funct_type, typename ARG1, typename ARG2 > friend void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1, ARG2 arg2 ); template < typename funct_type, typename ARG1, typename ARG2, typename ARG3 > friend void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); template < typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > friend void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); // ----------------------------------- template < typename port_notify_function_type, typename funct_type > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct ); template < typename port_notify_function_type, typename funct_type, typename ARG1 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1 ); template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2 ); template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ); template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > friend void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ); // ----------------------------------- }; // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename funct_type > void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1 > void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1 ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj,arg1); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2 > void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1, ARG2 arg2 ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj,arg1,arg2); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3 > void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj,arg1,arg2,arg3); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > void bsp_connect ( const std::vector<network_address>& hosts, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ) { impl1::map_id_to_con cons; const unsigned long node_id = 0; connect_all(cons, hosts, node_id); send_out_connection_orders(cons, hosts); bsp_context obj(node_id, cons); funct(obj,arg1,arg2,arg3,arg4); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename funct_type > void bsp_listen ( unsigned short listening_port, funct_type funct ) { // make sure requires clause is not broken DLIB_CASSERT(listening_port != 0, "\t void bsp_listen()" << "\n\t Invalid arguments were given to this function." ); bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1 ) { // make sure requires clause is not broken DLIB_CASSERT(listening_port != 0, "\t void bsp_listen()" << "\n\t Invalid arguments were given to this function." ); bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1, ARG2 arg2 ) { // make sure requires clause is not broken DLIB_CASSERT(listening_port != 0, "\t void bsp_listen()" << "\n\t Invalid arguments were given to this function." ); bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ) { // make sure requires clause is not broken DLIB_CASSERT(listening_port != 0, "\t void bsp_listen()" << "\n\t Invalid arguments were given to this function." ); bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3); } // ---------------------------------------------------------------------------------------- template < typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > void bsp_listen ( unsigned short listening_port, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ) { // make sure requires clause is not broken DLIB_CASSERT(listening_port != 0, "\t void bsp_listen()" << "\n\t Invalid arguments were given to this function." ); bsp_listen_dynamic_port(listening_port, impl1::null_notify, funct, arg1, arg2, arg3, arg4); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct ) { impl1::map_id_to_con cons; unsigned long node_id; listen_and_connect_all(node_id, cons, listening_port, port_notify_function); bsp_context obj(node_id, cons); funct(obj); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1 ) { impl1::map_id_to_con cons; unsigned long node_id; listen_and_connect_all(node_id, cons, listening_port, port_notify_function); bsp_context obj(node_id, cons); funct(obj,arg1); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2 ) { impl1::map_id_to_con cons; unsigned long node_id; listen_and_connect_all(node_id, cons, listening_port, port_notify_function); bsp_context obj(node_id, cons); funct(obj,arg1,arg2); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3 ) { impl1::map_id_to_con cons; unsigned long node_id; listen_and_connect_all(node_id, cons, listening_port, port_notify_function); bsp_context obj(node_id, cons); funct(obj,arg1,arg2,arg3); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- template < typename port_notify_function_type, typename funct_type, typename ARG1, typename ARG2, typename ARG3, typename ARG4 > void bsp_listen_dynamic_port ( unsigned short listening_port, port_notify_function_type port_notify_function, funct_type funct, ARG1 arg1, ARG2 arg2, ARG3 arg3, ARG4 arg4 ) { impl1::map_id_to_con cons; unsigned long node_id; listen_and_connect_all(node_id, cons, listening_port, port_notify_function); bsp_context obj(node_id, cons); funct(obj,arg1,arg2,arg3,arg4); obj.close_all_connections_gracefully(); } // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------- } #ifdef NO_MAKEFILE #include "bsp.cpp" #endif #endif // DLIB_BsP_Hh_