import logging
from datetime import timedelta

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from detectron2.utils import comm

__all__ = ["DEFAULT_TIMEOUT", "launch"]

DEFAULT_TIMEOUT = timedelta(minutes=30)


def _find_free_port():
    import socket

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Binding to port 0 makes the OS pick a free port for us.
    sock.bind(("", 0))
    port = sock.getsockname()[1]
    sock.close()
    # NOTE: the port could still be taken by another process before it is used.
    return port


def launch(
    main_func,
    num_gpus_per_machine,
    num_machines=1,
    machine_rank=0,
    dist_url=None,
    args=(),
    timeout=DEFAULT_TIMEOUT,
):
    """
    Launch multi-process or distributed training.
    This function must be called on all machines involved in the training.
    It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine.

    Args:
        main_func: a function that will be called by `main_func(*args)`
        num_gpus_per_machine (int): number of processes per machine. When
            using GPUs, this should be the number of GPUs.
        num_machines (int): the total number of machines
        machine_rank (int): the rank of this machine
        dist_url (str): url to connect to for distributed jobs, including protocol
            e.g. "tcp://127.0.0.1:8686".
            Can be set to "auto" to automatically select a free port on localhost
        timeout (timedelta): timeout of the distributed workers
        args (tuple): arguments passed to main_func
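
    Example:
        A minimal sketch of a single-machine launch; ``main`` and ``cfg`` below
        are placeholders for a user-defined entry point and its argument, not
        part of this module::

            def main(cfg):
                ...  # build the model and run training

            launch(
                main,
                num_gpus_per_machine=8,
                num_machines=1,
                machine_rank=0,
                dist_url="auto",
                args=(cfg,),
            )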
    """
    world_size = num_machines * num_gpus_per_machine
    if world_size > 1:
        if dist_url == "auto":
            assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs."
            port = _find_free_port()
            dist_url = f"tcp://127.0.0.1:{port}"
        if num_machines > 1 and dist_url.startswith("file://"):
            logger = logging.getLogger(__name__)
            logger.warning(
                "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://"
            )

        # Spawn one child process per GPU on this machine.
        mp.start_processes(
            _distributed_worker,
            nprocs=num_gpus_per_machine,
            args=(
                main_func,
                world_size,
                num_gpus_per_machine,
                machine_rank,
                dist_url,
                args,
                timeout,
            ),
            daemon=False,
        )
    else:
        main_func(*args)


def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    machine_rank,
    dist_url,
    args,
    timeout=DEFAULT_TIMEOUT,
):
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        assert num_gpus_per_machine <= torch.cuda.device_count()
    # Global rank of this process across all machines.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    try:
        dist.init_process_group(
            backend="NCCL" if has_gpu else "GLOO",
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
            timeout=timeout,
        )
    except Exception as e:
        logger = logging.getLogger(__name__)
        logger.error("Process group URL: {}".format(dist_url))
        raise e

    # Set up the per-machine (local) process group.
    comm.create_local_process_group(num_gpus_per_machine)
    if has_gpu:
        torch.cuda.set_device(local_rank)

    # Synchronize all workers to prevent a possible timeout right after init_process_group.
    comm.synchronize()

    main_func(*args)