from functools import lru_cache
from typing import Callable, Tuple, List, Any

import numpy as np
import torch

from .default_helper import error_wrapper
from .fake_linklink import FakeLink
from .import_helper import try_import_link


# Cache the imported backend so that ``try_import_link`` only runs once.
@lru_cache()
def get_link():
    return try_import_link()


def is_fake_link():
    return isinstance(get_link(), FakeLink)


def get_rank() -> int:
    """
    Overview:
        Get the rank of the current process in the ``linklink`` group; return 0 when ``FakeLink`` is in use.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_rank``.
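
    Examples:
        Illustrative sketch only; when no real ``linklink`` backend is available (``FakeLink``), the default value is returned.

        >>> get_rank()
        0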
""" | |
if is_fake_link(): | |
return 0 | |
return error_wrapper(get_link().get_rank, 0, "[WARNING]: call linklink error, return default_ret.")() | |


def get_world_size() -> int:
    """
    Overview:
        Get the ``world_size`` of the ``linklink`` group; return 1 when ``FakeLink`` is in use.

    .. note::
        Reference ``import_helper.try_import_link`` and ``linklink.get_world_size``.
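
    Examples:
        Illustrative sketch only; when no real ``linklink`` backend is available (``FakeLink``), the default value is returned.

        >>> get_world_size()
        1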
""" | |
if is_fake_link(): | |
return 1 | |
return error_wrapper(get_link().get_world_size, 1, "[WARNING]: call linklink error, return default_ret.")() | |


def broadcast(value: torch.Tensor, rank: int) -> None:
    """
    Overview:
        Call ``linklink.broadcast``; raise ``NotImplementedError`` when using ``FakeLink``.
    Arguments:
        - value (:obj:`torch.Tensor`): the tensor to broadcast (modified in-place on receiving ranks)
        - rank (:obj:`int`): the source rank to broadcast from
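    Examples:
        Illustrative sketch; assumes a working ``linklink`` backend with more than one process.

        >>> value = torch.zeros(4)
        >>> if get_rank() == 0:
        ...     value += 1.
        >>> broadcast(value, 0)  # afterwards every rank holds the tensor from rank 0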
""" | |
if is_fake_link(): | |
raise NotImplementedError | |
get_link().broadcast(value, rank) | |


def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce`` on the data in-place. With ``op='sum'`` the result is additionally
        divided by ``world_size``, i.e. it yields the mean across processes.
    Arguments:
        - data (:obj:`torch.Tensor`): the tensor to reduce
        - op (:obj:`str`): the reduce operation, one of ``['sum', 'max']``
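    Examples:
        Illustrative sketch; assumes a working ``linklink`` backend (with ``FakeLink`` the tensor is returned unchanged).

        >>> grad = torch.ones(4)
        >>> allreduce(grad)         # in-place mean of ``grad`` across all processes
        >>> allreduce(grad, 'max')  # in-place element-wise maximum across all processes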
""" | |
link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max} | |
if op not in link_op_map.keys(): | |
raise KeyError("not support allreduce op type: {}".format(op)) | |
else: | |
link_op = link_op_map[op] | |
if is_fake_link(): | |
return data | |
get_link().allreduce(data, reduce_op=link_op) | |
if op == 'sum': | |
data.div_(get_world_size()) | |


def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
    """
    Overview:
        Call ``linklink.allreduce_async`` on the data in-place. With ``op='sum'`` the data is divided
        by ``world_size`` before the asynchronous reduce, so the completed result is the mean across processes.
    Arguments:
        - data (:obj:`torch.Tensor`): the tensor to reduce
        - op (:obj:`str`): the reduce operation, one of ``['sum', 'max']``
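    Examples:
        Illustrative sketch; assumes a working ``linklink`` backend and a later ``synchronize()`` call
        to wait for the asynchronous reduce to complete.

        >>> grad = torch.ones(4)
        >>> allreduce_async(grad)
        >>> synchronize()  # block until the pending asynchronous operations have finished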
""" | |
link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max} | |
if op not in link_op_map.keys(): | |
raise KeyError("not support allreduce op type: {}".format(op)) | |
else: | |
link_op = link_op_map[op] | |
if is_fake_link(): | |
return data | |
if op == 'sum': | |
data.div_(get_world_size()) | |
get_link().allreduce_async(data, reduce_op=link_op) | |


def get_group(group_size: int) -> List:
    """
    Overview:
        Split the whole ``world_size`` into groups of ``group_size`` processes each and
        return the group that the current rank belongs to.
    Arguments:
        - group_size (:obj:`int`): the number of processes per group; ``None`` means a single group containing all processes
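    Examples:
        Illustrative sketch; assumes 8 running processes (``world_size == 8``) and a working backend.

        >>> group = get_group(4)  # ranks 0-3 form one group, ranks 4-7 another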
""" | |
rank = get_rank() | |
world_size = get_world_size() | |
if group_size is None: | |
group_size = world_size | |
assert (world_size % group_size == 0) | |
return simple_group_split(world_size, rank, world_size // group_size) | |


def dist_mode(func: Callable) -> Callable:
    """
    Overview:
        Decorator that wraps a function so that ``dist_init`` is called automatically before
        each call and ``dist_finalize`` after it.
    Arguments:
        - func (:obj:`Callable`): the function to wrap
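    Examples:
        Illustrative sketch; ``main`` is a hypothetical entry point and a working ``linklink`` backend is assumed.

        >>> @dist_mode
        ... def main():
        ...     print(get_rank(), get_world_size())
        >>> main()  # initializes the backend, runs ``main``, then finalizes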
""" | |
def wrapper(*args, **kwargs): | |
dist_init() | |
func(*args, **kwargs) | |
dist_finalize() | |
return wrapper | |


def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]:
    """
    Overview:
        Initialize the distributed backend and bind the current process to a CUDA device.
    Arguments:
        - method (:obj:`str`): launch method, support ``['slurm', 'single_node']``
        - device_id (:obj:`int`): default device when using the ``single_node`` method
    Returns:
        - rank_and_world_size (:obj:`Tuple[int, int]`): the rank of the current process and the total world size
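    Examples:
        Illustrative sketch; assumes the process was launched under slurm with at least one visible GPU.

        >>> rank, world_size = dist_init(method='slurm')
        >>> print(rank, world_size)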
""" | |
get_link().initialize() | |
world_size = get_link().get_world_size() | |
rank = get_link().get_rank() | |
if method == 'slurm': | |
# proc_id = int(os.environ['SLURM_PROCID']) | |
# ntasks = int(os.environ['SLURM_NTASKS']) | |
# node_list = os.environ['SLURM_NODELIST'] | |
num_gpus = torch.cuda.device_count() | |
torch.cuda.set_device(rank % num_gpus) | |
elif method == 'single_node': | |
torch.cuda.set_device(device_id) | |
return rank, world_size | |


def dist_finalize() -> None:
    """
    Overview:
        Finalize ``linklink``, see ``linklink.finalize()``.
    """
    get_link().finalize()


class DistContext:
    """
    Overview:
        A context manager that initializes the ``linklink`` distributed backend on enter and
        finalizes it on exit.
    Interfaces:
        ``__init__``, ``__enter__``, ``__exit__``
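    Examples:
        Illustrative sketch; assumes a working ``linklink`` backend.

        >>> with DistContext():
        ...     rank, world_size = get_rank(), get_world_size()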
""" | |
def __init__(self) -> None: | |
""" | |
Overview: | |
Initialize the ``DistContext`` | |
""" | |
pass | |
def __enter__(self) -> None: | |
""" | |
Overview: | |
Initialize ``linklink`` distribution | |
""" | |
dist_init() | |
def __exit__(self, *args, **kwargs) -> Any: | |
""" | |
Overview: | |
Finalize ``linklink`` distribution | |
Arugments: | |
- args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function. | |
- kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function. | |
""" | |
dist_finalize() | |


def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
    """
    Overview:
        Split all ranks evenly into ``num_groups`` groups and return the group that ``rank`` belongs to.
    Arguments:
        - world_size (:obj:`int`): the total number of processes
        - rank (:obj:`int`): the rank of the current process
        - num_groups (:obj:`int`): the number of groups to create

    .. note::
        If ``world_size`` is not divisible by ``num_groups``, ``np.split`` raises
        ``array split does not result in an equal division``.
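
    Examples:
        Illustrative sketch; assumes ``world_size == 8`` and a working backend, so each of the 4 groups covers 2 consecutive ranks.

        >>> group = simple_group_split(8, get_rank(), 4)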
""" | |
groups = [] | |
rank_list = np.split(np.arange(world_size), num_groups) | |
rank_list = [list(map(int, x)) for x in rank_list] | |
for i in range(num_groups): | |
groups.append(get_link().new_group(rank_list[i])) | |
group_size = world_size // num_groups | |
return groups[rank // group_size] | |


def synchronize():
    """
    Overview:
        Synchronize the processes, see ``linklink.synchronize()``.
    """
    get_link().synchronize()