|
""" Distributed training/validation utils |
|
|
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
import logging |
|
import os |
|
from typing import Optional |
|
|
|
import torch |
|
from torch import distributed as dist |
|
|
|
from .model import unwrap_model |
|
|
|
# Module-level logger, named after this module via __name__ so its output can be
# configured/filtered independently of other loggers.
_logger = logging.getLogger(__name__)
|
|
|
|
|
def reduce_tensor(tensor, n):
    """Sum ``tensor`` over the default process group, then divide by ``n``.

    The caller's tensor is never modified: the collective runs on a clone.
    Passing the world size as ``n`` yields the mean across ranks.

    Args:
        tensor: tensor to reduce.
        n: divisor applied to the summed result.

    Returns:
        A new tensor equal to ``sum_over_ranks(tensor) / n``.
    """
    reduced = tensor.clone()
    dist.all_reduce(reduced, dist.ReduceOp.SUM)
    # In-place division on the private copy, returned directly.
    return reduced.div_(n)
|
|
|
|
|
def distribute_bn(model, world_size, reduce=False):
    """Synchronise BatchNorm-style running statistics across distributed ranks.

    ``unwrap_model`` is applied to ``model``, and every buffer (searched
    recursively) whose name contains ``running_mean`` or ``running_var`` is
    updated in place.

    Args:
        model: model whose buffers are synchronised.
        world_size: number of ranks; used as the divisor when averaging.
        reduce: if True, average the statistics over all ranks; otherwise
            overwrite every rank's copy with rank 0's values.
    """
    stat_keys = ('running_mean', 'running_var')
    for name, buf in unwrap_model(model).named_buffers(recurse=True):
        if not any(key in name for key in stat_keys):
            continue
        if not reduce:
            # Rank 0's statistics become authoritative everywhere.
            dist.broadcast(buf, 0)
            continue
        # Sum across ranks, then divide in place -> mean of the statistics.
        dist.all_reduce(buf, op=dist.ReduceOp.SUM)
        buf /= float(world_size)
|
|
|
|
|
def is_global_primary(args):
    """Return True when ``args.rank`` is 0, i.e. this is the global primary process.

    Args:
        args: any object exposing a ``rank`` attribute.
    """
    rank = args.rank
    return rank == 0
|
|
|
|
|
def is_local_primary(args):
    """Return True when ``args.local_rank`` is 0, i.e. this is the node-local primary process.

    Args:
        args: any object exposing a ``local_rank`` attribute.
    """
    local_rank = args.local_rank
    return local_rank == 0
|
|
|
|
|
def is_primary(args, local=False):
    """Return whether this process is the primary one.

    Args:
        args: object exposing ``rank`` and/or ``local_rank``.
        local: when True, test the node-local rank (``is_local_primary``);
            otherwise test the global rank (``is_global_primary``).
    """
    if local:
        return is_local_primary(args)
    return is_global_primary(args)
|
|
|
|
|
def is_distributed_env():
    """Return True if the environment describes a job with more than one process.

    ``WORLD_SIZE`` is consulted first; only when it is absent is
    ``SLURM_NTASKS`` checked. With neither set, the job is treated as
    single-process.
    """
    for var in ('WORLD_SIZE', 'SLURM_NTASKS'):
        value = os.environ.get(var)
        if value is not None:
            # The first variable present decides; later ones are ignored.
            return int(value) > 1
    return False
|
|
|
|
|
def world_info_from_env():
    """Read local rank, global rank and world size from launcher environment variables.

    For each value, the first variable present in ``os.environ`` wins,
    checked in the order listed:

    * local rank: LOCAL_RANK, MPI_LOCALRANKID, SLURM_LOCALID,
      OMPI_COMM_WORLD_LOCAL_RANK (default 0)
    * global rank: RANK, PMI_RANK, SLURM_PROCID, OMPI_COMM_WORLD_RANK (default 0)
    * world size: WORLD_SIZE, PMI_SIZE, SLURM_NTASKS, OMPI_COMM_WORLD_SIZE (default 1)

    Returns:
        Tuple ``(local_rank, global_rank, world_size)`` of ints.

    Raises:
        ValueError: if the first present variable is not an integer string.
    """
    def first_int(names, default):
        # Only the first present variable is parsed; later candidates are ignored.
        found = next((name for name in names if name in os.environ), None)
        return default if found is None else int(os.environ[found])

    local_rank = first_int(
        ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
    global_rank = first_int(
        ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
    world_size = first_int(
        ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)
    return local_rank, global_rank, world_size
|
|
|
|
|
def init_distributed_device(args):
    """Initialise distributed state and resolve the compute device for an args namespace.

    Reads ``device``, ``dist_backend`` and ``dist_url`` from ``args``, falling
    back to 'cuda', None and None when an attribute is absent. It delegates to
    ``init_distributed_device_so`` and copies the results back onto ``args``
    as ``device``, ``world_size``, ``rank``, ``local_rank`` and ``distributed``.

    Returns:
        ``torch.device`` built from the resolved device string.
    """
    # Single-process defaults first, so these fields remain set even if
    # initialisation below raises.
    args.distributed = False
    args.world_size = 1
    args.rank = 0
    args.local_rank = 0

    result = init_distributed_device_so(
        device=getattr(args, 'device', 'cuda'),
        dist_backend=getattr(args, 'dist_backend', None),
        dist_url=getattr(args, 'dist_url', None),
    )

    # (attribute on args, key in result), applied in this order.
    field_map = (
        ('device', 'device'),
        ('world_size', 'world_size'),
        ('rank', 'global_rank'),
        ('local_rank', 'local_rank'),
        ('distributed', 'distributed'),
    )
    for attr, key in field_map:
        setattr(args, attr, result[key])

    return torch.device(args.device)
|
|
|
|
|
def init_distributed_device_so(
        device: str = 'cuda',
        dist_backend: Optional[str] = None,
        dist_url: Optional[str] = None,
):
    """Initialise ``torch.distributed`` when the environment requests it, and resolve the device.

    Args:
        device: requested device string such as 'cuda', 'cuda:1' or 'cpu'. In
            distributed mode (and when not exactly 'cpu') any explicit index
            is discarded and replaced by the local rank.
        dist_backend: process-group backend. Defaults to 'nccl' when
            ``device`` contains 'cuda', otherwise 'gloo'.
        dist_url: ``init_method`` passed to ``init_process_group``. Defaults
            to 'env://'.

    Returns:
        dict with keys ``device`` (resolved device string), ``global_rank``,
        ``local_rank``, ``world_size`` and ``distributed``.

    Raises:
        AssertionError: if ``device`` contains 'cuda' but CUDA is unavailable.
    """
    distributed = False
    world_size = 1
    global_rank = 0
    local_rank = 0

    if dist_backend is None:
        # NCCL for CUDA devices, Gloo for everything else.
        dist_backend = 'nccl' if 'cuda' in device else 'gloo'
    dist_url = dist_url or 'env://'

    if is_distributed_env():
        if 'SLURM_PROCID' in os.environ:
            # SLURM launch: resolve ranks from the environment and pass them
            # explicitly. Also mirror them into LOCAL_RANK / RANK / WORLD_SIZE
            # so later readers of those variables see consistent values.
            local_rank, global_rank, world_size = world_info_from_env()
            os.environ['LOCAL_RANK'] = str(local_rank)
            os.environ['RANK'] = str(global_rank)
            os.environ['WORLD_SIZE'] = str(world_size)
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
                world_size=world_size,
                rank=global_rank,
            )
        else:
            # Other launchers: rank and world size are not passed in; they are
            # read back from the initialised group. Only the local rank comes
            # from the environment.
            local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=dist_backend,
                init_method=dist_url,
            )
            world_size = torch.distributed.get_world_size()
            global_rank = torch.distributed.get_rank()
        distributed = True

    if 'cuda' in device:
        assert torch.cuda.is_available(), f'CUDA is not available but {device} was specified.'

    if distributed and device != 'cpu':
        # Keep the caller's original request: the warning must report what was
        # actually specified (e.g. 'cuda:1'), not the stripped device type.
        requested = device
        device_type, *device_idx = requested.split(':', maxsplit=1)

        # Ignore any manually specified device index in distributed mode and
        # bind to the resolved local rank instead.
        if device_idx:
            _logger.warning('device index %s removed from specified (%s).', device_idx[0], requested)

        device = f'{device_type}:{local_rank}'

    if device.startswith('cuda:'):
        torch.cuda.set_device(device)

    return dict(
        device=device,
        global_rank=global_rank,
        local_rank=local_rank,
        world_size=world_size,
        distributed=distributed,
    )
|
|