|
import os |
|
|
|
import torch |
|
import torch.distributed as dist |
|
import pdb |
|
|
|
|
|
def dist_pdb(rank, in_rank=0):
    """Drop a single process into pdb while the rest wait.

    The process whose ``rank`` equals ``in_rank`` enters an interactive
    ``pdb`` session; every process (including the debugged one, after the
    session ends) then meets at a collective barrier so the group stays in
    lock-step.
    """
    if rank == in_rank:
        pdb.set_trace()
    # Every rank reaches the barrier exactly once, so no process can run
    # ahead while rank `in_rank` is being debugged.
    dist.barrier()
|
|
|
|
|
def init_distributed_mode(args):
    """Initialize torch.distributed from environment variables.

    Supports two launchers:
      * ``torchrun`` / ``torch.distributed.launch`` — reads ``RANK``,
        ``WORLD_SIZE`` and ``LOCAL_RANK``.
      * SLURM (``srun``) — reads ``SLURM_PROCID`` / ``SLURM_NTASKS`` and
        derives the local GPU index from the per-node device count.

    Mutates ``args`` in place: sets ``rank``, ``world_size``, ``gpu``,
    ``distributed`` and ``dist_backend``.  When neither launcher's
    variables are present, sets ``args.distributed = False`` and returns
    without touching CUDA or the process group.

    Expects ``args.dist_url`` to be provided by the caller (used as the
    ``init_method``).
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        # Bug fix: world_size was never set on the SLURM path, so the
        # init_process_group call below raised AttributeError.  SLURM
        # exposes the total task count as SLURM_NTASKS.
        args.world_size = int(os.environ['SLURM_NTASKS'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    # Synchronize all ranks before returning so no process races ahead of
    # group initialization.
    dist.barrier()
|
|
|
|
|
def cleanup():
    """Tear down the default process group created by ``init_distributed_mode``."""
    dist.destroy_process_group()
|
|
|
|
|
def is_dist_avail_and_initialized():
    """Return True iff torch.distributed is compiled in AND a process group is up."""
    # Short-circuits: is_initialized() may only be called when the
    # distributed package is available.
    return dist.is_available() and dist.is_initialized()
|
|
|
|
|
def get_world_size():
    """Return the number of distributed processes, or 1 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
|
|
|
|
def get_rank():
    """Return this process's rank in the group, or 0 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
|
|
|
|
def is_main_process():
    """Return True on rank 0, and always True when running non-distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    # Single-process runs are trivially the "main" process.
    return True
|
|
|
|
|
def reduce_value(value, average=True):
    """All-reduce ``value`` (a tensor) across processes.

    With ``average=True`` the sum is divided by the world size, yielding the
    mean; otherwise the raw sum is returned.  ``value`` is modified in place
    by the collective.  In a single-process run the input is returned
    untouched.
    """
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1

    # Nothing to reduce with fewer than two participants.
    if world_size < 2:
        return value

    with torch.no_grad():
        dist.all_reduce(value)
        if average:
            value /= world_size

    return value
|
|