"""Thin helpers around torch.distributed for torchrun-launched jobs."""

import os
import warnings
from typing import Any, List, Optional

from torch import distributed as dist

__all__ = [
    "init",
    "is_initialized",
    "size",
    "rank",
    "local_size",
    "local_rank",
    "is_main",
    "barrier",
    "gather",
    "all_gather",
]

def init() -> None:
    """Initialize the default process group from torchrun-style env vars."""
    if "RANK" not in os.environ:
        warnings.warn("Environment variable `RANK` is not set. Skipping distributed initialization.")
        return
    dist.init_process_group(backend="nccl", init_method="env://")
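
# Usage sketch: `init()` assumes a torchrun-style launcher, e.g.
# `torchrun --nproc_per_node=4 train.py`, which sets RANK, WORLD_SIZE,
# and LOCAL_RANK in the environment. A typical setup then pins each
# process to its local GPU (`torch` import assumed):
#
#     init()
#     torch.cuda.set_device(local_rank())
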
def is_initialized() -> bool:
    """Return True if distributed is available and the process group is up."""
    return dist.is_available() and dist.is_initialized()

def size() -> int:
    """Global number of processes (WORLD_SIZE), 1 if unset."""
    return int(os.environ.get("WORLD_SIZE", 1))


def rank() -> int:
    """Global rank of this process (RANK), 0 if unset."""
    return int(os.environ.get("RANK", 0))


def local_size() -> int:
    """Number of processes on this node (LOCAL_WORLD_SIZE), 1 if unset."""
    return int(os.environ.get("LOCAL_WORLD_SIZE", 1))


def local_rank() -> int:
    """Rank of this process on its node (LOCAL_RANK), 0 if unset."""
    return int(os.environ.get("LOCAL_RANK", 0))

def is_main() -> bool:
    """Return True on the global rank-0 process."""
    return rank() == 0
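
# Usage sketch: restrict side effects that should happen once per job,
# such as logging or checkpointing, to the main process (`model` is a
# hypothetical torch module; `torch` import assumed):
#
#     if is_main():
#         torch.save(model.state_dict(), "checkpoint.pt")
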
def barrier() -> None:
    """Synchronize all processes; a no-op outside distributed runs."""
    if is_initialized():
        dist.barrier()
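
# Usage sketch: let the main process prepare shared state once while the
# other ranks wait (`download_dataset` is a hypothetical helper):
#
#     if is_main():
#         download_dataset()
#     barrier()  # all ranks proceed only after rank 0 is done
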
def gather(obj: Any, dst: int = 0) -> Optional[List[Any]]:
    """Gather picklable objects onto rank `dst`; other ranks receive None."""
    if not is_initialized():
        return [obj]
    # `gather_object` requires the output list only on the destination rank;
    # the original rank-0 check broke for dst != 0.
    if rank() == dst:
        objs: List[Any] = [None] * size()
        dist.gather_object(obj, objs, dst=dst)
        return objs
    else:
        dist.gather_object(obj, dst=dst)
        return None
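
# Usage sketch: collect a hypothetical per-rank scalar `loss` on the
# destination rank, which alone sees the full list:
#
#     losses = gather(loss, dst=0)
#     if is_main():
#         print(sum(losses) / len(losses))
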
def all_gather(obj: Any) -> List[Any]:
    """Gather picklable objects from all ranks; every rank gets the full list."""
    if not is_initialized():
        return [obj]
    objs: List[Any] = [None] * size()
    dist.all_gather_object(objs, obj)
    return objs
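
# Usage sketch: merge per-rank evaluation outputs on every process
# (`preds` is a hypothetical list of predictions computed on this rank):
#
#     merged = [p for chunk in all_gather(preds) for p in chunk]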