|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import sys |
|
import uuid |
|
from dataclasses import dataclass, field |
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
|
|
|
import torch.distributed.elastic.rendezvous.registry as rdzv_registry |
|
from torch.distributed.elastic import events, metrics |
|
from torch.distributed.elastic.agent.server.api import WorkerSpec |
|
from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent |
|
from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs, LogsSpecs, SignalException |
|
from torch.distributed.elastic.multiprocessing.errors import ChildFailedError |
|
from torch.distributed.elastic.rendezvous import RendezvousParameters |
|
from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint |
|
from torch.distributed.elastic.utils.logging import get_logger |
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ['LaunchConfig', 'elastic_launch', 'launch_agent']

# Module-level logger, created through the torchelastic ``get_logger`` helper.
logger = get_logger(__name__)
|
|
|
|
|
@dataclass |
|
class LaunchConfig: |
|
""" |
|
Creates a rendezvous config. |
|
|
|
Args: |
|
min_nodes: Minimum amount of nodes that the user function will |
|
be launched on. Elastic agent ensures that the user |
|
function start only when the min_nodes amount enters |
|
the rendezvous. |
|
max_nodes: Maximum amount of nodes that the user function |
|
will be launched on. |
|
nproc_per_node: On each node the elastic agent will launch |
|
this amount of workers that will execute user |
|
defined function. |
|
rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). |
|
rdzv_endpoint: The endpoint of the rdzv sync. storage. |
|
rdzv_configs: Key, value pair that specifies rendezvous specific configuration. |
|
rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going |
|
to be removed in future versions, see the note below. The default timeout is 900 seconds. |
|
run_id: The unique run id of the job (if not passed a unique one will be |
|
deduced from run environment - flow workflow id in flow - or auto generated). |
|
role: User defined role of the worker (defaults to "trainer"). |
|
max_restarts: The maximum amount of restarts that elastic agent will conduct |
|
on workers before failure. |
|
monitor_interval: The interval in seconds that is used by the elastic_agent |
|
as a period of monitoring workers. |
|
start_method: The method is used by the elastic agent to start the |
|
workers (spawn, fork, forkserver). |
|
metrics_cfg: configuration to initialize metrics. |
|
local_addr: address of the local node if any. If not set, a lookup on the local |
|
machine's FQDN will be performed. |
|
local_ranks_filter: ranks for which to show logs in console. If not set, show from all. |
|
..note: |
|
`rdzv_timeout` is a legacy argument that will be removed in future. |
|
Set the timeout via `rdzv_configs['timeout']` |
|
|
|
""" |
|
|
|
min_nodes: int |
|
max_nodes: int |
|
nproc_per_node: int |
|
logs_specs: Optional[LogsSpecs] = None |
|
run_id: str = "" |
|
role: str = "default_role" |
|
rdzv_endpoint: str = "" |
|
rdzv_backend: str = "etcd" |
|
rdzv_configs: Dict[str, Any] = field(default_factory=dict) |
|
rdzv_timeout: int = -1 |
|
max_restarts: int = 3 |
|
monitor_interval: float = 0.1 |
|
start_method: str = "spawn" |
|
log_line_prefix_template: Optional[str] = None |
|
metrics_cfg: Dict[str, str] = field(default_factory=dict) |
|
local_addr: Optional[str] = None |
|
|
|
def __post_init__(self): |
|
default_timeout = 900 |
|
if self.rdzv_timeout != -1: |
|
self.rdzv_configs["timeout"] = self.rdzv_timeout |
|
elif "timeout" not in self.rdzv_configs: |
|
self.rdzv_configs["timeout"] = default_timeout |
|
|
|
|
|
if self.logs_specs is None: |
|
self.logs_specs = DefaultLogsSpecs() |
|
|
|
|
|
class elastic_launch:
    """
    Launches a torchelastic agent on the container that invoked the entrypoint.

    An instance binds a launch configuration to an ``entrypoint``; calling the
    instance starts the agent through ``launch_agent``.

    1. Pass the ``entrypoint`` arguments positionally (``__call__`` accepts no
       keyword arguments). ``entrypoint`` can be a function or a command.
    2. The return value is a map of each worker's output, keyed by the
       worker's global rank.

    Usage

    ::

        def worker_fn(foo):
            # ...

        def main():
            # entrypoint is a function.
            outputs = elastic_launch(LaunchConfig, worker_fn)(foo)
            # return rank 0's output
            return outputs[0]

        # entrypoint is a command and ``script.py`` is the python module.
        outputs = elastic_launch(LaunchConfig, "script.py")(args)
        outputs = elastic_launch(LaunchConfig, "python")("script.py")
    """

    def __init__(
        self,
        config: LaunchConfig,
        entrypoint: Union[Callable, str, None],
    ):
        # Stored as-is; nothing is validated or started until __call__.
        self._config, self._entrypoint = config, entrypoint

    def __call__(self, *args):
        """Run the entrypoint with ``args`` and return ``launch_agent``'s result."""
        entrypoint_args = [*args]
        return launch_agent(self._config, self._entrypoint, entrypoint_args)
|
|
|
|
|
def _get_entrypoint_name(
    entrypoint: Union[Callable, str, None], args: List[Any]
) -> str:
    """Retrieve entrypoint name with the rule:
    1. If entrypoint is a callable, use ``entrypoint.__name__``. Callables that
       have no ``__name__`` (for example ``functools.partial`` objects) fall
       back to the name of their type.
    2. If entrypoint is a string, check its value:
        2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the
            first element from ``args`` which does not start with a hyphen
            (for example, "-u" will be skipped). Non-string elements are
            compared and returned in their ``str()`` form.
        2.2 otherwise, use ``entrypoint`` value.
    3. Otherwise, return empty string.

    Args:
        entrypoint: Function, command string or ``None``.
        args: Arguments that will be passed to the entrypoint.

    Returns:
        The derived name, or ``""`` when no name can be derived.
    """
    if callable(entrypoint):
        # Not every callable carries __name__; do not fail just to build a label.
        name = getattr(entrypoint, "__name__", None)
        return name if name is not None else type(entrypoint).__name__

    if not isinstance(entrypoint, str):
        return ""

    if entrypoint != sys.executable:
        return entrypoint

    for arg in args:
        # ``args`` is List[Any]: normalize to text so that empty strings and
        # non-string values cannot raise while looking for the script name.
        text = arg if isinstance(arg, str) else str(arg)
        if not text.startswith("-"):
            return text
    return ""
|
|
|
|
|
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    """Extract the master address and port from a "static" rendezvous endpoint.

    Args:
        rdzv_parameters: Rendezvous parameters; only ``backend`` and
            ``endpoint`` are read.

    Returns:
        ``(None, None)`` for every backend other than ``"static"``; otherwise
        the ``(address, port)`` pair parsed from the endpoint.

    Raises:
        ValueError: If the backend is ``"static"`` and the endpoint is blank
            or carries no port.
    """
    if rdzv_parameters.backend == "static":
        stripped = rdzv_parameters.endpoint.strip()
        if not stripped:
            raise ValueError(
                "Endpoint is missing in endpoint. Try to add --master-addr and --master-port"
            )

        # -1 is a sentinel: the parser returns it when the endpoint has no port.
        addr, port = parse_rendezvous_endpoint(stripped, default_port=-1)
        if port == -1:
            raise ValueError(
                f"port is missing in endpoint: {stripped}. Try to specify --master-port"
            )
        return (addr, port)

    return (None, None)
|
|
|
|
|
def launch_agent(
    config: LaunchConfig,
    entrypoint: Union[Callable, str, None],
    args: List[Any],
) -> Dict[int, Any]:
    """Run ``entrypoint`` under a ``LocalElasticAgent`` built from ``config``.

    Args:
        config: Launch settings. If ``config.run_id`` is empty, a random id is
            generated and written back to ``config.run_id``.
        entrypoint: Function or command executed by the workers.
        args: Arguments handed to the entrypoint.

    Returns:
        ``return_values`` of the agent's run result.

    Raises:
        ChildFailedError: If the run result reports a failure.
        Exception: Anything raised while running the agent is re-raised after
            a failure event has been recorded.
    """
    if not config.run_id:
        generated_id = str(uuid.uuid4().int)
        logger.warning("config has no run_id, generated a random run_id: %s", generated_id)
        config.run_id = generated_id

    entrypoint_name = _get_entrypoint_name(entrypoint, args)

    launch_summary = {
        "entrypoint": entrypoint_name,
        "min_nodes": config.min_nodes,
        "max_nodes": config.max_nodes,
        "nproc_per_node": config.nproc_per_node,
        "run_id": config.run_id,
        "rdzv_backend": config.rdzv_backend,
        "rdzv_endpoint": config.rdzv_endpoint,
        "rdzv_configs": config.rdzv_configs,
        "max_restarts": config.max_restarts,
        "monitor_interval": config.monitor_interval,
        "log_dir": config.logs_specs.root_log_dir,
        "metrics_cfg": config.metrics_cfg,
    }
    logger.info(
        "Starting elastic_operator with launch configs:\n"
        " entrypoint : %(entrypoint)s\n"
        " min_nodes : %(min_nodes)s\n"
        " max_nodes : %(max_nodes)s\n"
        " nproc_per_node : %(nproc_per_node)s\n"
        " run_id : %(run_id)s\n"
        " rdzv_backend : %(rdzv_backend)s\n"
        " rdzv_endpoint : %(rdzv_endpoint)s\n"
        " rdzv_configs : %(rdzv_configs)s\n"
        " max_restarts : %(max_restarts)s\n"
        " monitor_interval : %(monitor_interval)s\n"
        " log_dir : %(log_dir)s\n"
        " metrics_cfg : %(metrics_cfg)s\n",
        launch_summary,
    )

    rdzv_parameters = RendezvousParameters(
        backend=config.rdzv_backend,
        endpoint=config.rdzv_endpoint,
        run_id=config.run_id,
        min_nodes=config.min_nodes,
        max_nodes=config.max_nodes,
        local_addr=config.local_addr,
        **config.rdzv_configs,
    )

    # Only the "static" backend yields an explicit master address/port.
    master_addr, master_port = _get_addr_and_port(rdzv_parameters)

    worker_args = tuple(args)
    rdzv_handler = rdzv_registry.get_rendezvous_handler(rdzv_parameters)
    spec = WorkerSpec(
        role=config.role,
        local_world_size=config.nproc_per_node,
        entrypoint=entrypoint,
        args=worker_args,
        rdzv_handler=rdzv_handler,
        max_restarts=config.max_restarts,
        monitor_interval=config.monitor_interval,
        master_addr=master_addr,
        master_port=master_port,
        local_addr=config.local_addr,
    )

    agent = LocalElasticAgent(
        spec=spec,
        logs_specs=config.logs_specs,
        start_method=config.start_method,
        log_line_prefix_template=config.log_line_prefix_template,
    )

    # Cleared only when the agent is interrupted by a signal (see below).
    close_rendezvous = True
    try:
        metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg))

        run_result = agent.run()

        # Recorded as soon as agent.run() returns, before the result is
        # inspected for worker failures.
        events.record(agent.get_event_succeeded())

        if not run_result.is_failed():
            return run_result.return_values

        raise ChildFailedError(
            name=entrypoint_name,
            failures=run_result.failures,
        )
    except ChildFailedError:
        # Propagated untouched: no failure event is recorded for this case.
        raise
    except SignalException:
        # A signal-terminated agent skips rdzv_handler.shutdown() in the
        # finally clause. NOTE(review): presumably so the rendezvous stays
        # usable by the remaining agents -- confirm against the handler docs.
        close_rendezvous = False
        events.record(agent.get_event_failed())
        raise
    except Exception:
        events.record(agent.get_event_failed())
        raise
    finally:
        if close_rendezvous:
            spec.rdzv_handler.shutdown()
|
|