"""Thin convenience wrappers around ``torch.distributed`` for multi-GPU jobs.

Every helper degrades gracefully to single-process behaviour when the process
group has not been initialized, so the same code runs with or without
``torchrun``.
"""

import os
import warnings
from typing import Any, List, Optional

from torch import distributed as dist

__all__ = [
    "init",
    "is_initialized",
    "size",
    "rank",
    "local_size",
    "local_rank",
    "is_main",
    "barrier",
    "gather",
    "all_gather",
]


def init() -> None:
    """Initialize the default NCCL process group from torchrun-style environment variables."""
    if "RANK" not in os.environ:
        warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
        return
    dist.init_process_group(backend="nccl", init_method="env://")


def is_initialized() -> bool:
    """Return ``True`` if the default process group has been initialized."""
    return dist.is_initialized()


def size() -> int:
    """World size from the ``WORLD_SIZE`` environment variable (1 if unset)."""
    return int(os.environ.get("WORLD_SIZE", 1))


def rank() -> int:
    """Global rank from the ``RANK`` environment variable (0 if unset)."""
    return int(os.environ.get("RANK", 0))


def local_size() -> int:
    """Number of processes on this node, from ``LOCAL_WORLD_SIZE`` (1 if unset)."""
    return int(os.environ.get("LOCAL_WORLD_SIZE", 1))


def local_rank() -> int:
    """Rank of this process within its node, from ``LOCAL_RANK`` (0 if unset)."""
    return int(os.environ.get("LOCAL_RANK", 0))


def is_main() -> bool:
    """Return ``True`` on the global rank-0 process."""
    return rank() == 0


def barrier() -> None:
    """Synchronize all processes; a no-op when the process group is not initialized."""
    if not is_initialized():
        return
    dist.barrier()


def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    """Gather ``obj`` from all ranks onto rank ``dst``; other ranks receive ``None``."""
    if not is_initialized():
        return [obj]
    if rank() == dst:
        # Only the destination rank provides the output list.
        objs: List[Any] = [None for _ in range(size())]
        dist.gather_object(obj, objs, dst=dst)
        return objs
    dist.gather_object(obj, dst=dst)
    return None


def all_gather(obj: Any) -> List[Any]:
    """Gather ``obj`` from every rank and return the full list on all ranks."""
    if not is_initialized():
        return [obj]
    objs: List[Any] = [None for _ in range(size())]
    dist.all_gather_object(objs, obj)
    return objs
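

if __name__ == "__main__":
    # Minimal usage sketch (assumption: launched via
    # `torchrun --nproc_per_node=<N> this_module.py`; without torchrun the
    # helpers fall back to single-process behaviour).
    init()
    results = all_gather(rank() ** 2)
    if is_main():
        print(f"gathered from {size()} process(es): {results}")
    barrier()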