# Barrier middleware: synchronizes the task step of every visible process.
from time import sleep, time

from ditk import logging

from ding.framework import task
from ding.utils.design_helper import SingletonMetaclass
from ding.utils.lock_helper import LockContext, LockContextType
class BarrierRuntime(metaclass=SingletonMetaclass):

    def __init__(self, node_id: int, max_world_size: int = 100):
        """
        Overview:
            'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
            class 'Parallel' starts MQ, otherwise the messages sent by other nodes may be lost after
            the detection is completed. We don't have a message retransmission mechanism, and losing
            a message means deadlock.
        Arguments:
            - node_id (int): Process ID.
            - max_world_size (int, optional): The maximum total number of processes that can be
                synchronized, the default value is 100.
        """
        self.node_id = node_id
        self._has_detected = False
        # Number of decimal digits reserved for the node id when an (epoch, node_id) pair is
        # packed into a single integer tag; one extra digit of headroom keeps every
        # node_id <= max_world_size collision-free.
        self._range_len = len(str(max_world_size)) + 1

        # Per-peer queues of barrier-request epochs, ack list, and the current epoch counter.
        self._barrier_epoch = 0
        self._barrier_recv_peers_buff = dict()
        self._barrier_recv_peers = dict()
        self._barrier_ack_peers = []
        self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)
        self.mq_type = task.router.mq_type

        # Liveness-detection bookkeeping: peer_id -> timestamp of last detect reply.
        self._connected_peers = dict()
        self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
        self._keep_alive_daemon = False

        # Event names used on the message queue for detection, requests and acks.
        self._event_name_detect = "b_det"
        self.event_name_req = "b_req"
        self.event_name_ack = "b_ack"

    def _alive_msg_handler(self, peer_id):
        # Handler for the detect event: remember when we last heard from this peer.
        with self._connected_peers_lock:
            self._connected_peers[peer_id] = time()

    def _add_barrier_req(self, msg):
        # Handler for a barrier request: decode the integer tag back into (peer, epoch)
        # and queue the epoch under the sending peer.
        peer, epoch = self._unpickle_barrier_tag(msg)
        logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
        with self._barrier_lock:
            if peer not in self._barrier_recv_peers:
                self._barrier_recv_peers[peer] = []
            self._barrier_recv_peers[peer].append(epoch)

    def _add_barrier_ack(self, peer):
        # Handler for a barrier ack: the peer has observed all of its own peers ready.
        logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
        with self._barrier_lock:
            self._barrier_ack_peers.append(peer)

    def _unpickle_barrier_tag(self, msg):
        # Inverse of pickle_barrier_tag: the node id occupies the low `_range_len` decimal
        # digits, the epoch the remaining high digits.
        # Bug fix: the original used `msg % self._range_len` / `msg // self._range_len`,
        # i.e. divided by the digit COUNT instead of `10 ** count`, which garbles the
        # (peer, epoch) pair for any node_id >= _range_len.
        return msg % 10 ** self._range_len, msg // 10 ** self._range_len

    def pickle_barrier_tag(self):
        # Pack (epoch, node_id) into one integer so a single event payload carries both.
        # Must mirror _unpickle_barrier_tag (base 10 ** _range_len).
        return int(self._barrier_epoch * 10 ** self._range_len + self.node_id)

    def reset_all_peers(self):
        """
        Finish the current barrier epoch: pop this epoch's entry from every peer's request
        queue, clear the ack list and advance the epoch counter.
        """
        with self._barrier_lock:
            for peer, q in self._barrier_recv_peers.items():
                if len(q) != 0:
                    # Internal consistency check: the head of each queue must be this epoch.
                    assert q.pop(0) == self._barrier_epoch
            self._barrier_ack_peers = []
            self._barrier_epoch += 1

    def get_recv_num(self):
        """
        Return how many peers have already sent a barrier request for the current epoch.
        """
        count = 0
        with self._barrier_lock:
            if len(self._barrier_recv_peers) > 0:
                for _, q in self._barrier_recv_peers.items():
                    if len(q) > 0 and q[0] == self._barrier_epoch:
                        count += 1
        return count

    def get_ack_num(self):
        """
        Return how many peers have acknowledged the current barrier.
        """
        with self._barrier_lock:
            return len(self._barrier_ack_peers)

    def detect_alive(self, expected, timeout):
        """
        Overview:
            Register the barrier event handlers, then broadcast detect messages until exactly
            `expected` distinct peers have replied; raises TimeoutError if that does not
            happen within `timeout` seconds. Runs at most once per process (guarded by
            `_has_detected`).
        Arguments:
            - expected (int): Number of peers that must reply before detection succeeds.
            - timeout (float): Maximum number of seconds to wait.
        """
        # The barrier can only block other nodes within the visible range of the current node.
        # If the 'attch_to' list of a node is empty, it does not know how many nodes will attach to him,
        # so we cannot specify the effective range of a barrier in advance.
        assert task._running
        task.on(self._event_name_detect, self._alive_msg_handler)
        task.on(self.event_name_req, self._add_barrier_req)
        task.on(self.event_name_ack, self._add_barrier_ack)
        start = time()
        while True:
            sleep(0.1)
            task.emit(self._event_name_detect, self.node_id, only_remote=True)
            # In case the other node has not had time to receive our detect message,
            # we will send an additional round.
            if self._has_detected:
                break
            with self._connected_peers_lock:
                if len(self._connected_peers) == expected:
                    self._has_detected = True
            if time() - start > timeout:
                raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
        # Only the detect handler is removed; req/ack handlers stay registered for the
        # lifetime of the process so later barriers keep receiving messages.
        task.off(self._event_name_detect)
        logging.info(
            "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
        )
class BarrierContext:
    """
    Context manager wrapping one barrier round: entry ensures the peer-detection
    handshake has completed, exit rolls the runtime over to the next barrier epoch
    (printing the traceback first if an exception escaped the body; the exception
    itself keeps propagating).
    """

    def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
        # Store references only; detection work is deferred to __enter__.
        self._runtime = runtime
        self._timeout = detect_timeout
        self._expected_peer_num = expected_peer_num

    def __enter__(self):
        # Detection runs at most once per process; skip if it already succeeded.
        if self._runtime._has_detected:
            return
        self._runtime.detect_alive(self._expected_peer_num, self._timeout)

    def __exit__(self, exc_type, exc_value, tb):
        failed = exc_type is not None
        if failed:
            import traceback
            traceback.print_exception(exc_type, exc_value, tb)
        # Advance to the next epoch whether or not the body raised; the implicit
        # None return means exceptions are never suppressed.
        self._runtime.reset_all_peers()
class Barrier:

    def __init__(self, attch_from_nums: int, timeout: int = 60):
        """
        Overview:
            Barrier() is a middleware for debug or profiling. It can synchronize the task step of each
            process within the scope of all visible processes. When using Barrier(), you need to pay
            attention to the following points:

            1. All processes must call the same number of Barrier(), otherwise a deadlock occurs.

            2. 'attch_from_nums' is a very important variable, This value indicates the number of times
                the current process will be attached to by other processes (the number of connections
                established).
                For example:
                    Node0: address: 127.0.0.1:12345, attach_to = []
                    Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
                For Node0, the 'attch_from_nums' value is 1. (It will be attached by Node1)
                For Node1, the 'attch_from_nums' value is 0. (No one will attach to Node1)
                Please note that this value must be given correctly, otherwise, for a node whose 'attach_to'
                list is empty, it cannot perceive how many processes will establish connections with it,
                resulting in any form of synchronization cannot be performed.

            3. Barrier() is thread-safe, but it is not recommended to use barrier in multithreading. You need
                to carefully calculate the number of times each thread calls Barrier() to avoid deadlock.

            4. In normal training tasks, please do not use Barrier(), which will force the step synchronization
                between each process, so it will greatly damage the training efficiency. In addition, if your
                training task has dynamic processes, do not use Barrier() to prevent deadlock.
        Arguments:
            - attch_from_nums (int): The number of inbound connections this node expects, i.e. how
                many other processes list this node in their 'attach_to' and will connect to it.
            - timeout (int, optional): The timeout for successful detection of 'expected_peer_num'
                number of nodes, the default value is 60 seconds.
        """
        self.node_id = task.router.node_id
        self.timeout = timeout
        self._runtime: BarrierRuntime = task.router.barrier_runtime
        # Total number of directly connected peers: outbound (attach_to) plus inbound.
        self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums
        logging.info(
            "Node:[{}], attach to num is:{}, attach from num is:{}".format(
                self.node_id, task.get_attch_to_len(), attch_from_nums
            )
        )

    def __call__(self, ctx):
        # Middleware protocol: synchronize once before the wrapped task body runs
        # and once again after it finishes.
        self._wait_barrier(ctx)
        yield
        self._wait_barrier(ctx)

    def _wait_barrier(self, ctx):
        # Run one full barrier round: request -> wait for peers' requests -> ack ->
        # wait for peers' acks. BarrierContext resets the epoch on exit.
        self_ready = False
        with BarrierContext(self._runtime, self.timeout, self._barrier_peers_nums):
            logging.debug("Node:[{}] enter barrier".format(self.node_id))
            # Step1: Notifies all the attached nodes that we have reached the barrier.
            task.emit(self._runtime.event_name_req, self._runtime.pickle_barrier_tag(), only_remote=True)
            logging.debug("Node:[{}] sended barrier request".format(self.node_id))

            # Step2: We check the number of flags we have received.
            # In the current CI design of DI-engine, there will always be a node whose 'attach_to' list is empty,
            # so there will always be a node that will send ACK unconditionally, so deadlock will not occur.
            if self._runtime.get_recv_num() == self._barrier_peers_nums:
                self_ready = True

            # Step3: Waiting for our own to be ready.
            # Even if the current process has reached the barrier, we will not send an ack immediately,
            # we need to wait for the slowest directly connected or indirectly connected peer to
            # reach the barrier.
            start = time()
            if not self_ready:
                while True:
                    if time() - start > self.timeout:
                        raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
                    if self._runtime.get_recv_num() != self._barrier_peers_nums:
                        sleep(0.1)
                    else:
                        break

            # Step4: Notifies all attached nodes that we are ready.
            task.emit(self._runtime.event_name_ack, self.node_id, only_remote=True)
            logging.debug("Node:[{}] sended barrier ack".format(self.node_id))

            # Step5: Wait until all directly or indirectly connected nodes are ready.
            start = time()
            while True:
                if time() - start > self.timeout:
                    raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
                if self._runtime.get_ack_num() != self._barrier_peers_nums:
                    sleep(0.1)
                else:
                    break
            logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))