# torch/distributed/elastic/control_plane.py
import os
from contextlib import contextmanager, ExitStack
from typing import Generator

from torch.distributed.elastic.multiprocessing.errors import record

__all__ = [
    "worker_main",
]

TORCH_WORKER_SERVER_SOCKET = "TORCH_WORKER_SERVER_SOCKET"
def _worker_server(socket_path: str) -> Generator[None, None, None]: | |
from torch._C._distributed_c10d import _WorkerServer | |
server = _WorkerServer(socket_path) | |
try: | |
yield | |
finally: | |
server.shutdown() | |
@contextmanager
@record
def worker_main() -> Generator[None, None, None]:
    """
    This is a context manager that wraps your main entry function. This combines
    the existing ``errors.record`` logic as well as a new ``_WorkerServer`` that
    exposes handlers via a unix socket specified by
    ``TORCH_WORKER_SERVER_SOCKET``.

    Example

    ::

     @worker_main()
     def main():
         pass

     if __name__=="__main__":
        main()
    """
    # ExitStack lets us conditionally hold the server context: when the env
    # var is unset we simply yield with nothing on the stack.
    with ExitStack() as stack:
        socket_path = os.environ.get(TORCH_WORKER_SERVER_SOCKET)
        if socket_path is not None:
            stack.enter_context(_worker_server(socket_path))

        yield