scene-sketch-seg / vpt /src /utils /distributed.py
ahmedbrs's picture
first
254fdf2
#!/usr/bin/env python3
"""Distributed helpers."""
import torch
import torch.distributed as dist
# Process group handed as `group=` to the local_* helpers in this module.
# Nothing visible in this file assigns it; while it stays None the
# torch.distributed calls fall back to the default (global) group.
# NOTE(review): presumably meant to be set by the launcher to a per-machine
# group -- confirm against the code that spawns the workers.
_LOCAL_PROCESS_GROUP = None
def get_world_size() -> int:
    """Return the number of processes in the default process group.

    Returns:
        int: ``dist.get_world_size()`` when torch.distributed is available and
        a process group has been initialized, otherwise 1.
    """
    # Short-circuit keeps is_initialized() from being called when the
    # distributed package is unavailable.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if distributed_ready else 1
def get_rank() -> int:
    """Return the rank of the current process in the default process group.

    Returns:
        int: ``dist.get_rank()`` when torch.distributed is available and a
        process group has been initialized, otherwise 0.
    """
    # Short-circuit keeps is_initialized() from being called when the
    # distributed package is unavailable.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
def is_master_process(num_gpus=8):
    """
    Determines if the current process is the master process.

    A process is treated as master when its global rank is a multiple of
    ``num_gpus``. When ``num_gpus`` equals the number of processes launched
    per machine, that is the first process on each machine.

    Args:
        num_gpus (int): divisor applied to the global rank (default 8).
    Returns:
        bool: True when the rank is a multiple of ``num_gpus``; always True
        when torch.distributed is unavailable or not initialized.
    """
    # Check is_available() first, like the other helpers in this module:
    # builds without distributed support do not expose is_initialized().
    if not (dist.is_available() and dist.is_initialized()):
        return True
    return dist.get_rank() % num_gpus == 0
def run(
    local_rank,
    num_proc,
    func,
    init_method,
    shard_id,
    num_shards,
    backend,
    cfg,
    args,
):
    """
    Runs a function from a child process.
    Args:
        local_rank (int): rank of the current process on the current machine.
        num_proc (int): number of processes per machine.
        func (function): function to execute on each of the process. It is
            called as ``func(cfg, args)``.
        init_method (string): method to initialize the distributed training.
            TCP initialization: requiring a network address reachable from all
            processes followed by the port.
            Shared file-system initialization: makes use of a file system that
            is shared and visible from all machines. The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supported, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
        cfg (CfgNode): configs. Details can be found in
            loco/config/defaults.py
        args: extra arguments forwarded unchanged to ``func`` as its second
            positional argument.
    """
    # Global layout: machines are ordered by shard_id and each contributes
    # num_proc consecutive ranks.
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank

    # Initialize the process group. Failures propagate to the caller as-is.
    torch.distributed.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    # Bind this process to the CUDA device matching its local rank before
    # handing control to the user function.
    torch.cuda.set_device(local_rank)
    func(cfg, args)
def destroy_process_group():
    """Tears down the default torch.distributed process group."""
    dist.destroy_process_group()
def scaled_all_reduce(cfg, tensors):
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by
    1 / (cfg.NUM_GPUS * cfg.NUM_SHARDS), i.e. the inverse of the configured
    total number of processes.

    Args:
        cfg: config object exposing ``NUM_GPUS`` and ``NUM_SHARDS``.
        tensors (list): tensors to reduce across processes.
    Returns:
        list: the same ``tensors`` list that was passed in.
    """
    # A single-process configuration has nothing to reduce and scales by 1,
    # so return early instead of issuing a collective that would fail when
    # no process group has been initialized.
    if cfg.NUM_GPUS * cfg.NUM_SHARDS == 1:
        return tensors
    # Queue all reductions asynchronously so they can overlap
    reductions = [
        torch.distributed.all_reduce(tensor, async_op=True)
        for tensor in tensors
    ]
    # Wait for reductions to finish
    for reduction in reductions:
        reduction.wait()
    # Scale the summed results down to a mean (factor is loop-invariant)
    scale = 1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS
    for tensor in tensors:
        tensor.mul_(scale)
    return tensors
def cat_all_gather(tensors):
    """Gathers ``tensors`` from every process and concatenates along dim 0.

    Args:
        tensors: the local tensor to contribute; one same-shaped receive
            buffer is allocated per process in the default group.
    Returns:
        The gathered tensors concatenated along dimension 0.
    """
    world_size = torch.distributed.get_world_size()
    # One receive buffer per process; contents are overwritten by all_gather.
    gathered = []
    for _ in range(world_size):
        gathered.append(torch.ones_like(tensors))
    torch.distributed.all_gather(gathered, tensors, async_op=False)
    return torch.cat(gathered, dim=0)
def local_cat_all_gather(tensors):
    """Gathers ``tensors`` within ``_LOCAL_PROCESS_GROUP`` and concatenates.

    Args:
        tensors: the local tensor to contribute; one same-shaped receive
            buffer is allocated per process reported by get_local_size().
    Returns:
        The gathered tensors concatenated along dimension 0.
    """
    local_size = get_local_size()
    # One receive buffer per process; contents are overwritten by all_gather.
    gathered = []
    for _ in range(local_size):
        gathered.append(torch.ones_like(tensors))
    torch.distributed.all_gather(
        gathered,
        tensors,
        async_op=False,
        group=_LOCAL_PROCESS_GROUP,
    )
    return torch.cat(gathered, dim=0)
def get_local_size():
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine. Falls back to 1 when
        torch.distributed is unavailable or not initialized.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if distributed_ready:
        return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
    return 1
def get_local_rank():
    """
    Returns:
        The rank of the current process within the local (per-machine)
        process group. Falls back to 0 when torch.distributed is unavailable
        or not initialized.
    """
    distributed_ready = dist.is_available() and dist.is_initialized()
    if not distributed_ready:
        return 0
    # The local group must have been set before ranks can be queried from it.
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)