#!/usr/bin/env python3
"""Distributed helpers."""
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None


def get_world_size() -> int:
    """Returns the number of processes in the default process group."""
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    """Returns the rank of the current process in the default process group."""
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def is_master_process(num_gpus=8):
    """
    Determines if the current process is the master process, i.e. a process
    whose global rank is divisible by `num_gpus`. Returns True for every
    process when torch.distributed is not initialized.
    """
    if torch.distributed.is_initialized():
        return dist.get_rank() % num_gpus == 0
    else:
        return True
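

# Hedged usage sketch (not part of the original module): restricting side
# effects such as checkpointing to the master process. The helper name
# `_example_save_checkpoint` and its arguments are illustrative only.
def _example_save_checkpoint(model, path, num_gpus=8):
    if is_master_process(num_gpus):
        torch.save(model.state_dict(), path)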


def run(
    local_rank,
    num_proc,
    func,
    init_method,
    shard_id,
    num_shards,
    backend,
    cfg,
    args,
):
    """
    Runs a function from a child process.
    Args:
        local_rank (int): rank of the current process on the current machine.
        num_proc (int): number of processes per machine.
        func (function): function to execute on each of the processes.
        init_method (string): method to initialize the distributed training.
            TCP initialization: requires a network address reachable from all
            processes followed by the port.
            Shared file-system initialization: makes use of a file system that
            is shared and visible from all machines. The URL should start with
            file:// and contain a path to a non-existent file on a shared file
            system.
        shard_id (int): the rank of the current machine.
        num_shards (int): number of overall machines for the distributed
            training job.
        backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
            supported, each with different capabilities. Details can be found
            here:
            https://pytorch.org/docs/stable/distributed.html
        cfg (CfgNode): configs. Details can be found in
            loco/config/defaults.py
        args: additional arguments forwarded to `func` together with `cfg`.
    """
    # Initialize the process group. The global rank is derived from the
    # machine index (shard_id) and the process index on this machine.
    world_size = num_proc * num_shards
    rank = shard_id * num_proc + local_rank
    torch.distributed.init_process_group(
        backend=backend,
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )
    # Bind this process to its GPU before running the target function.
    torch.cuda.set_device(local_rank)
    func(cfg, args)
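

# Hedged usage sketch (not part of the original module): one way to launch
# `run` on a single machine with torch.multiprocessing.spawn. The entry point
# `func`, the config `cfg`, the extra `args`, and the TCP address are
# placeholders supplied by the caller.
def _example_launch_single_machine(func, cfg, args, num_proc=8):
    import torch.multiprocessing as mp

    mp.spawn(
        run,
        nprocs=num_proc,
        # spawn passes the process index as the first argument (local_rank);
        # the remaining parameters of `run` are forwarded via `args`.
        args=(num_proc, func, "tcp://localhost:9999", 0, 1, "nccl", cfg, args),
    )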


def destroy_process_group():
    """Destroys the default process group."""
    torch.distributed.destroy_process_group()


def scaled_all_reduce(cfg, tensors):
    """Performs the scaled all_reduce operation on the provided tensors.
    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse of the
    total number of processes (cfg.NUM_GPUS * cfg.NUM_SHARDS), i.e. averaged
    across the whole job.
    """
    # Queue the asynchronous reductions.
    reductions = []
    for tensor in tensors:
        reduction = torch.distributed.all_reduce(tensor, async_op=True)
        reductions.append(reduction)
    # Wait for the reductions to finish.
    for reduction in reductions:
        reduction.wait()
    # Scale the results by the inverse of the total number of processes.
    for tensor in tensors:
        tensor.mul_(1.0 / cfg.NUM_GPUS / cfg.NUM_SHARDS)
    return tensors
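

# Hedged usage sketch (not part of the original module): averaging per-process
# training metrics. `cfg` is assumed to expose NUM_GPUS and NUM_SHARDS as in
# loco/config/defaults.py; `loss` and `err` are per-GPU scalar tensors.
def _example_average_metrics(cfg, loss, err):
    loss, err = scaled_all_reduce(cfg, [loss, err])
    return loss.item(), err.item()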


def cat_all_gather(tensors):
    """Performs the concatenated all_gather operation on the provided tensors
    across the default (global) process group.
    """
    tensors_gather = [
        torch.ones_like(tensors)
        for _ in range(torch.distributed.get_world_size())
    ]
    torch.distributed.all_gather(tensors_gather, tensors, async_op=False)
    output = torch.cat(tensors_gather, dim=0)
    return output
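

# Hedged usage sketch (not part of the original module): building a global
# batch of embeddings from every GPU, e.g. for a contrastive loss. Note that
# torch.distributed.all_gather does not propagate gradients back to the other
# processes, so the gathered copies are effectively detached.
def _example_gather_embeddings(embeddings):
    # embeddings: (local_batch, dim) tensor on the current GPU.
    # Returns a (world_size * local_batch, dim) tensor, identical on all ranks.
    return cat_all_gather(embeddings)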


def local_cat_all_gather(tensors):
    """Performs the concatenated all_gather operation on the provided tensors
    within the local (per-machine) process group.
    """
    tensors_gather = [
        torch.ones_like(tensors)
        for _ in range(get_local_size())
    ]
    torch.distributed.all_gather(
        tensors_gather,
        tensors,
        async_op=False,
        group=_LOCAL_PROCESS_GROUP,
    )
    output = torch.cat(tensors_gather, dim=0)
    return output


def get_local_size():
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def get_local_rank():
    """
    Returns:
        The rank of the current process within the local (per-machine)
        process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
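

# Hedged sketch (not part of the original module): _LOCAL_PROCESS_GROUP is
# never created in this file, so the local helpers above assume it is set up
# elsewhere. One common pattern, run once on every rank after
# init_process_group, is to create one subgroup per machine with
# dist.new_group; the helper below is illustrative only.
def _example_init_local_process_group(num_proc_per_machine):
    global _LOCAL_PROCESS_GROUP
    assert _LOCAL_PROCESS_GROUP is None
    num_machines = get_world_size() // num_proc_per_machine
    for i in range(num_machines):
        ranks_on_machine = list(
            range(i * num_proc_per_machine, (i + 1) * num_proc_per_machine)
        )
        # new_group must be called collectively with identical arguments on
        # every rank, even on ranks that are not members of the group.
        pg = dist.new_group(ranks_on_machine)
        if get_rank() // num_proc_per_machine == i:
            _LOCAL_PROCESS_GROUP = pg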