import signal

import torch


def get_world_size():
    """Return the torch.distributed world size, or 1 when it is not initialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        world_size = torch.distributed.get_world_size()
    else:
        world_size = 1
    return world_size


def get_device(local_rank=None):
    """Return the device that matches the active torch.distributed backend."""
    backend = torch.distributed.get_backend()
    if backend == 'nccl':
        if local_rank is None:
            device = torch.device('cuda')
        else:
            device = torch.device(f'cuda:{local_rank}')
    elif backend == 'gloo':
        device = torch.device('cpu')
    else:
        raise RuntimeError(f'Unsupported distributed backend: {backend}')
    return device


def all_gather_item(item, dtype, group=None, async_op=False, local_rank=None):
    """All-gather a single scalar `item` and return one value per rank as a list."""
    if not torch.distributed.is_available() or \
       not torch.distributed.is_initialized():
        return [item]

    device = get_device(local_rank)

    if group is not None:
        group_size = group.size()
    else:
        group_size = get_world_size()

    tensor = torch.tensor([item], device=device, dtype=dtype)
    output_tensors = [
        torch.zeros(1, dtype=tensor.dtype, device=tensor.device)
        for _ in range(group_size)
    ]
    torch.distributed.all_gather(output_tensors, tensor, group, async_op)
    output = [elem.item() for elem in output_tensors]
    return output
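
# Illustrative usage (a sketch, not part of the original module): gather one flag
# per rank, where `local_flag` is a hypothetical per-rank boolean. With four ranks
# and only rank 0 passing True, every rank receives [1, 0, 0, 0]; when
# torch.distributed is not initialized the call simply returns [item].
#
#     flags = all_gather_item(local_flag, dtype=torch.int32)
#     stop_training = any(flags)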


class DistributedSignalHandler:
    """Context manager that installs a handler for `sig` and lets every rank
    check, via an all-gather, whether any rank has received that signal."""

    def __init__(self, sig=signal.SIGTERM):
        self.sig = sig

    def signals_received(self):
        """Return a list with one entry per rank; truthy entries mark ranks
        whose handler has fired."""
        all_received = all_gather_item(
            self._signal_received, dtype=torch.int32
        )
        return all_received

    def __enter__(self):
        self._signal_received = False
        self.released = False
        self.original_handler = signal.getsignal(self.sig)

        def handler(signum, frame):
            self._signal_received = True

        signal.signal(self.sig, handler)

        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.release()

    def release(self):
        """Restore the original handler; return False if already released."""
        if self.released:
            return False

        signal.signal(self.sig, self.original_handler)
        self.released = True
        return True
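

# Minimal usage sketch (illustrative; `loader`, `train_step`, and `save_checkpoint`
# are hypothetical stand-ins, not part of this module). Each rank traps SIGTERM
# locally, and `signals_received()` lets all ranks agree, at an iteration
# boundary, that at least one of them was signalled, so they can stop together:
#
#     with DistributedSignalHandler() as handler:
#         for batch in loader:
#             train_step(batch)
#             if any(handler.signals_received()):
#                 save_checkpoint()
#                 break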