Spaces:
Running
Running
#!/usr/bin/env python3 | |
"""Distributed helpers.""" | |
import torch | |
import torch.distributed as dist | |
_LOCAL_PROCESS_GROUP = None | |
def get_world_size() -> int:
    """Return the size of the global process group.

    Falls back to 1 when torch.distributed is unavailable or has not
    been initialized (e.g. single-process runs).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
def get_rank() -> int:
    """Return the global rank of the current process.

    Falls back to 0 when torch.distributed is unavailable or has not
    been initialized (e.g. single-process runs).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
def is_master_process(num_gpus=8):
    """
    Determines if the current process is the master process.

    A process is the master of its machine when its global rank is a
    multiple of ``num_gpus`` (the number of processes per machine).
    Without an initialized process group, the sole process is trivially
    the master.
    """
    if not torch.distributed.is_initialized():
        return True
    return dist.get_rank() % num_gpus == 0
def run(
    local_rank,
    num_proc,
    func,
    init_method,
    shard_id,
    num_shards,
    backend,
    cfg,
    args,
):
    """
    Runs a function from a child process.

    Args:
        local_rank (int): rank of the current process on the current machine.
        num_proc (int): number of processes per machine.
        func (function): function to execute on each of the process.
        init_method (string): method to initialize the distributed training.
            TCP initialization: requiring a network address reachable from all
            processes followed by the port.
            Shared file-system initialization: makes use of a file system that
            is shared and visible from all machines. The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supported, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
        cfg (CfgNode): configs. Details can be found in
            loco/config/defaults.py
        args: extra arguments forwarded to ``func`` together with ``cfg``.
    """
    # Global rank = (machine index * processes per machine) + local offset.
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank
    # Any failure here propagates with its original traceback; the previous
    # ``except Exception as e: raise e`` wrapper added nothing.
    torch.distributed.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    # Pin this process to its GPU only when CUDA exists, so CPU-only
    # backends such as 'gloo' don't crash on machines without GPUs.
    if torch.cuda.is_available():
        torch.cuda.set_device(local_rank)
    func(cfg, args)
def destroy_process_group():
    """Tear down the default (global) process group."""
    torch.distributed.destroy_process_group()
def scaled_all_reduce(cfg, tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of
    the process group (equivalent to cfg.NUM_GPUS).
    """
    # Launch every reduction asynchronously before waiting on any of them,
    # so the collectives can overlap.
    handles = [
        torch.distributed.all_reduce(t, async_op=True) for t in tensors
    ]
    for handle in handles:
        handle.wait()
    # Rescale in place by the inverse of the total process count.
    inv_scale = 1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS
    for t in tensors:
        t.mul_(inv_scale)
    return tensors
def cat_all_gather(tensors):
    """Performs the concatenated all_gather operation on the provided tensors.
    """
    world_size = torch.distributed.get_world_size()
    # One distinct receive buffer per rank; all_gather fills each of them.
    gathered = [torch.ones_like(tensors) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensors, async_op=False)
    return torch.cat(gathered, dim=0)
def local_cat_all_gather(tensors):
    """Performs the concatenated all_gather operation on the provided tensors,
    restricted to the per-machine (local) process group.
    """
    # One distinct receive buffer per local rank.
    gathered = [torch.ones_like(tensors) for _ in range(get_local_size())]
    torch.distributed.all_gather(
        gathered,
        tensors,
        async_op=False,
        group=_LOCAL_PROCESS_GROUP,
    )
    return torch.cat(gathered, dim=0)
def get_local_size():
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    # Same short-circuit order as the availability/initialization guards:
    # is_initialized is only consulted when dist is available.
    if not (dist.is_available() and dist.is_initialized()):
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def get_local_rank():
    """
    Returns:
        The rank of the current process within the local (per-machine)
        process group.
    """
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    # The local group must have been created before local ranks are queried.
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)