from typing import Any

import numpy as np
import torch
import torch.distributed as dist


def sync_across_processes(
    t: torch.Tensor | np.ndarray, world_size: int, group: Any = None
) -> torch.Tensor | np.ndarray:
    """Concatenates tensors across processes.

    Args:
        t: input tensor or numpy array
        world_size: world size
        group (ProcessGroup, optional): The process group to work on

    Returns:
        Tensor or numpy array concatenated across all processes
    """

    # Make sure all processes reach this point before gathering.
    dist.barrier(group=group)
    ret: torch.Tensor | np.ndarray

    if isinstance(t, torch.Tensor):
        # Buffers are fully overwritten by the gather, so they need not be initialized.
        gather_t_tensor = [torch.empty_like(t) for _ in range(world_size)]

        if t.is_cuda:
            # CUDA tensors can be gathered directly (e.g. via the NCCL backend).
            dist.all_gather(gather_t_tensor, t, group=group)
        else:
            # CPU tensors are gathered as pickled objects instead.
            dist.all_gather_object(gather_t_tensor, t, group=group)

        ret = torch.cat(gather_t_tensor)
    elif isinstance(t, np.ndarray):
        gather_t_array = [np.empty_like(t) for _ in range(world_size)]
        dist.all_gather_object(gather_t_array, t, group=group)
        ret = np.concatenate(gather_t_array)
    else:
        raise TypeError(f"Can't synchronize {type(t)}.")

    return ret
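

# Usage sketch (hypothetical helper, not part of the original module): gather
# per-rank predictions so every process sees the full set. Assumes
# torch.distributed has already been initialized, e.g. via torchrun.
def gather_predictions(local_preds: torch.Tensor) -> torch.Tensor:
    # Every rank receives the concatenation of all ranks' tensors,
    # ordered by rank.
    return sync_across_processes(local_preds, dist.get_world_size())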


# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
def is_cuda_out_of_memory(exception: BaseException) -> bool:
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "CUDA" in exception.args[0]
        and "out of memory" in exception.args[0]
    )


# based on https://github.com/BlackHC/toma/blob/master/toma/cpu_memory.py
def is_out_of_cpu_memory(exception: BaseException) -> bool:
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "DefaultCPUAllocator: can't allocate memory" in exception.args[0]
    )


# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
def is_cudnn_snafu(exception: BaseException) -> bool:
    # For/because of https://github.com/pytorch/pytorch/issues/4107
    return (
        isinstance(exception, RuntimeError)
        and len(exception.args) == 1
        and "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED." in exception.args[0]
    )


# based on https://github.com/BlackHC/toma/blob/master/toma/torch_cuda_memory.py
def is_oom_error(exception: BaseException) -> bool:
    return (
        is_cuda_out_of_memory(exception)
        or is_cudnn_snafu(exception)
        or is_out_of_cpu_memory(exception)
    )
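

# Usage sketch (hypothetical helper, not part of the original module): retry
# a memory-hungry operation with a halved batch size whenever one of the OOM
# patterns above is detected, mirroring how toma-style helpers use these
# predicates.
def run_with_oom_fallback(fn, batch_size: int, min_batch_size: int = 1) -> Any:
    while batch_size >= min_batch_size:
        try:
            return fn(batch_size)
        except RuntimeError as exception:
            if not is_oom_error(exception):
                raise
            # Release cached GPU memory before retrying with half the batch.
            torch.cuda.empty_cache()
            batch_size //= 2
    raise RuntimeError("Out of memory even at the minimum batch size.")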