from typing import Any, Union

import numpy as np
import torch


def sync_across_processes(
    t: Union[torch.Tensor, np.ndarray], world_size: int, group: Any = None
) -> Union[torch.Tensor, np.ndarray]:
    """Concatenates tensors or numpy arrays across processes.

    Args:
        t: input tensor or numpy array
        world_size: number of processes participating in the gather
        group: the process group to work on

    Returns:
        Tensor or numpy array concatenated across all processes, ordered by rank
    """
    torch.distributed.barrier(group=group)

    if isinstance(t, torch.Tensor):
        # Placeholders are overwritten by the gathered values; all_gather
        # requires the tensor to have the same shape on every rank.
        gather_t_tensor = [torch.ones_like(t) for _ in range(world_size)]

        if t.is_cuda:
            # CUDA tensors can use the native all_gather collective.
            torch.distributed.all_gather(gather_t_tensor, t, group=group)
        else:
            # CPU tensors are gathered as pickled objects instead.
            torch.distributed.all_gather_object(gather_t_tensor, t, group=group)

        ret = torch.cat(gather_t_tensor)
    elif isinstance(t, np.ndarray):
        gather_t_array = [np.ones_like(t) for _ in range(world_size)]
        torch.distributed.all_gather_object(gather_t_array, t, group=group)
        ret = np.concatenate(gather_t_array)
    else:
        raise ValueError(f"Can't synchronize {type(t)}.")

    return ret
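

# A minimal usage sketch (not part of the original module): gather per-rank
# predictions onto every rank. Assumes torch.distributed is already
# initialized (e.g. via torch.distributed.init_process_group); the names
# below are illustrative only.
def _example_gather_predictions() -> torch.Tensor:
    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    # Each rank contributes a same-shaped tensor; every rank receives the
    # concatenation ordered by rank: [rank 0 values, rank 1 values, ...].
    local_preds = torch.full((4,), float(rank))
    return sync_across_processes(local_preds, world_size)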


def is_cuda_out_of_memory(exception: BaseException) -> bool:
    """Checks whether `exception` is a CUDA out-of-memory RuntimeError."""
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "CUDA" in exception.args[0]
        and "out of memory" in exception.args[0]
    )
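
# For reference, a typical message matched by is_cuda_out_of_memory (exact
# wording varies across PyTorch versions; shown for illustration only):
#   "CUDA out of memory. Tried to allocate 2.00 GiB (GPU 0; ...)"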


def is_out_of_cpu_memory(exception: BaseException) -> bool:
    """Checks whether `exception` is a CPU out-of-memory RuntimeError."""
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
    )


def is_cudnn_snafu(exception: BaseException) -> bool:
    """Checks whether `exception` is a cuDNN CUDNN_STATUS_NOT_SUPPORTED error.

    cuDNN can raise this instead of an explicit out-of-memory error, so it
    is treated as an OOM-like condition here.
    """
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
    )


def is_oom_error(exception: BaseException) -> bool:
    """Checks whether `exception` indicates any out-of-memory condition
    (CUDA, cuDNN, or CPU)."""
    return (
        is_cuda_out_of_memory(exception)
        or is_cudnn_snafu(exception)
        or is_out_of_cpu_memory(exception)
    )
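

# A hedged usage sketch (not part of the original module): run a memory-hungry
# step and retry once on a smaller batch if the failure was an out-of-memory
# condition. `run_step` and `batch_size` are hypothetical names.
def _example_run_with_oom_fallback(run_step: Any, batch_size: int) -> Any:
    try:
        return run_step(batch_size)
    except RuntimeError as exception:
        if not is_oom_error(exception):
            raise
        # Free cached CUDA memory before retrying with half the batch size.
        torch.cuda.empty_cache()
        return run_step(max(1, batch_size // 2))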