Spaces:
Running
Running
# adapted from vllm | |
# https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py | |
import asyncio | |
import multiprocessing | |
import os | |
import socket | |
import sys | |
import threading | |
import traceback | |
import uuid | |
from dataclasses import dataclass | |
from multiprocessing import Queue | |
from multiprocessing.connection import wait | |
from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union | |
from videosys.utils.logging import create_logger | |
# Generic payload type carried by Result / ResultFuture.
T = TypeVar("T")

# Sentinel placed on a queue to make its consumer loop (iter(q.get, _TERMINATE)) stop.
_TERMINATE = "TERMINATE"

# ANSI escape sequences used to colorize the per-worker output prefix.
CYAN, RESET = "\033[1;36m", "\033[0;0m"

# Timeout, in seconds, passed to Process.join() during shutdown.
JOIN_TIMEOUT_S = 2

# Worker processes are started with "spawn"; the original note says "fork" cannot
# work here. NOTE(review): presumably because fork is unsafe once CUDA is initialized — confirm.
mp_method = "spawn"
mp = multiprocessing.get_context(mp_method)

logger = create_logger()
def get_distributed_init_method(ip: str, port: int) -> str:
    """Build a ``tcp://host:port`` init-method URL.

    An address containing ``:`` is treated as IPv6 and wrapped in square
    brackets. Anything else (IPv4 or a hostname) is used verbatim, because
    brackets are not permitted there
    (see https://github.com/python/cpython/issues/103848).
    """
    host = f"[{ip}]" if ":" in ip else ip
    return f"tcp://{host}:{port}"
def get_open_port() -> int:
    """Return a TCP port number that the OS reports as currently free.

    Binds port 0 on an IPv4 socket so the kernel picks a free port. If that
    raises ``OSError``, the same is tried with IPv6. The probe socket is
    closed before returning, so the port is not reserved: another process
    may take it before the caller binds it.
    """

    def _probe(family: int) -> int:
        with socket.socket(family, socket.SOCK_STREAM) as sock:
            sock.bind(("", 0))
            return sock.getsockname()[1]

    try:
        return _probe(socket.AF_INET)
    except OSError:
        # IPv4 unavailable; fall back to IPv6.
        return _probe(socket.AF_INET6)
@dataclass
class Result(Generic[T]):
    """Result of a task dispatched to a worker.

    ``exception`` is set when the task failed; otherwise ``value`` holds the
    called method's return value. Instances are sent through a
    multiprocessing queue, so both fields must be picklable.

    ``@dataclass`` is required: callers construct this with keyword arguments
    (``Result(task_id=..., value=..., exception=...)``), which a plain class
    without an ``__init__`` would reject with ``TypeError``.
    """

    task_id: uuid.UUID  # id generated when the task was enqueued
    value: Optional[T] = None
    exception: Optional[BaseException] = None
class ResultFuture(threading.Event, Generic[T]):
    """Blocking future used on the synchronous (non-async) call path.

    A producer calls :meth:`set_result`; consumers block in :meth:`get`
    until that happens.
    """

    def __init__(self):
        super().__init__()
        # None until set_result() stores the worker's Result.
        self.result: Optional[Result[T]] = None

    def set_result(self, result: Result[T]):
        """Store ``result`` and wake every thread blocked in :meth:`get`."""
        self.result = result
        self.set()

    def get(self) -> T:
        """Wait for the result; return its value, or raise the exception it carries."""
        self.wait()
        outcome = self.result
        assert outcome is not None
        failure = outcome.exception
        if failure is None:
            return outcome.value  # type: ignore[return-value]
        raise failure
def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
    """Deliver ``result`` to ``future``, which is either sync or asyncio-based.

    A :class:`ResultFuture` is completed directly. An :class:`asyncio.Future`
    belongs to an event loop that may run in a different thread from the
    caller (the ResultHandler background thread), so its completion is
    scheduled on that loop with ``call_soon_threadsafe``. If the loop is
    already closed, the result is dropped.
    """
    if isinstance(future, ResultFuture):
        future.set_result(result)
        return

    loop = future.get_loop()
    if loop.is_closed():
        return

    def _complete() -> None:
        # Runs on the future's own loop thread, so checking done() here is
        # race-free. The awaiting coroutine may already have been cancelled
        # (e.g. by a timeout); completing a done future would raise
        # asyncio.InvalidStateError inside this loop callback.
        if future.done():
            return
        if result.exception is not None:
            future.set_exception(result.exception)
        else:
            future.set_result(result.value)

    loop.call_soon_threadsafe(_complete)
class ResultHandler(threading.Thread):
    """Background thread that routes worker results to their waiting futures.

    Worker processes put :class:`Result` objects on ``result_queue``. This
    thread pops the future registered in ``tasks`` under the same
    ``task_id`` and completes it.
    """

    def __init__(self) -> None:
        super().__init__(daemon=True)
        # Shared with every ProcessWorkerWrapper, which hands it to its worker process.
        self.result_queue = mp.Queue()
        # Outstanding futures keyed by task id; ProcessWorkerWrapper registers them here.
        self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}

    def run(self):
        # Consume results until close() enqueues the _TERMINATE sentinel.
        for result in iter(self.result_queue.get, _TERMINATE):
            future = self.tasks.pop(result.task_id)
            _set_future_result(future, result)

        # Ensure that all remaining waiters receive an exception instead of
        # blocking forever. Drain with popitem() rather than iterating
        # items(): other threads may insert into or delete from `tasks`
        # concurrently, which would make iteration raise "dictionary changed
        # size during iteration" and leave the rest of the waiters hung.
        while self.tasks:
            try:
                task_id, future = self.tasks.popitem()
            except KeyError:  # emptied concurrently between the check and the pop
                break
            _set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))

    def close(self):
        """Make :meth:`run` stop by enqueueing the termination sentinel."""
        self.result_queue.put(_TERMINATE)
class WorkerMonitor(threading.Thread):
    """Background thread that tears down every worker once any worker exits.

    ``run`` blocks until at least one worker process exits. Unless
    :meth:`close` already started shutdown, it then logs non-zero exit codes,
    kills all workers and closes the result handler. ``close`` is the
    graceful path: it asks each worker to terminate via its task queue.
    """

    def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
        super().__init__(daemon=True)
        self.workers = workers
        self.result_handler = result_handler
        # Set by whichever of run()/close() begins shutdown first.
        self._close = False

    def run(self) -> None:
        # Block until at least one worker process has exited.
        exited = wait([w.process.sentinel for w in self.workers])

        if not self._close:
            self._reap_and_kill_workers(exited)

        for worker in self.workers:
            worker.process.join(JOIN_TIMEOUT_S)

    def _reap_and_kill_workers(self, exited: list) -> None:
        """Log abnormal worker exits, kill every worker, then stop the result handler."""
        self._close = True

        for proc in (w.process for w in self.workers):
            if proc.sentinel in exited:
                proc.join(JOIN_TIMEOUT_S)
            if proc.exitcode not in (None, 0):
                logger.error("Worker %s pid %s died, exit code: %s", proc.name, proc.pid, proc.exitcode)

        logger.info("Killing local worker processes")
        for worker in self.workers:
            worker.kill_worker()
        # Close the result handler only after every worker task queue is closed.
        self.result_handler.close()

    def close(self):
        """Gracefully terminate all workers; no-op if shutdown already began."""
        if self._close:
            return
        self._close = True
        logger.info("Terminating local worker processes")
        for worker in self.workers:
            worker.terminate_worker()
        # Close the result handler only after every worker task queue is closed.
        self.result_handler.close()
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
    """Prepend each output line with process-specific prefix.

    Replaces ``file.write`` in place so that every line written to ``file``
    begins with a colored ``(worker_name pid=pid)`` tag. The attribute
    ``file.start_new_line`` records whether the next character written
    starts a new line, so writes that do not end in a newline are not
    re-prefixed on the following write.

    Args:
        file: Text stream to patch (e.g. ``sys.stdout``).
        worker_name: Name shown in the prefix.
        pid: Process id shown in the prefix.
    """
    prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
    file_write = file.write

    def write_with_prefix(s: str) -> int:
        # TextIOBase.write must return the number of characters written.
        # Report len(s): the caller's data, not counting the decorative prefixes.
        if not s:
            return 0
        if file.start_new_line:  # type: ignore[attr-defined]
            file_write(prefix)
        idx = 0
        while (next_idx := s.find("\n", idx)) != -1:
            next_idx += 1
            file_write(s[idx:next_idx])
            if next_idx == len(s):
                # Ended exactly on a newline: prefix the next write instead.
                file.start_new_line = True  # type: ignore[attr-defined]
                return len(s)
            file_write(prefix)
            idx = next_idx
        file_write(s[idx:])
        file.start_new_line = False  # type: ignore[attr-defined]
        return len(s)

    file.start_new_line = True  # type: ignore[attr-defined]
    file.write = write_with_prefix  # type: ignore[method-assign]
def _run_worker_process(
    worker_factory: Callable[[], Any],
    task_queue: Queue,
    result_queue: Queue,
) -> None:
    """Main loop of a worker process.

    Builds the worker with ``worker_factory``, then takes
    ``(task_id, method, args, kwargs)`` tuples from ``task_queue`` until the
    ``_TERMINATE`` sentinel arrives. For each tuple it calls
    ``getattr(worker, method)(*args, **kwargs)`` and puts a :class:`Result`
    on ``result_queue``. An exception raised by the called method is logged
    and returned inside the Result instead of stopping the loop.
    """
    # Tag everything this process prints with its name and pid.
    name = mp.current_process().name
    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, name, pid)

    worker = worker_factory()
    # Drop the local reference to the factory now that the worker exists.
    del worker_factory

    logger.info("Worker ready; awaiting tasks")
    try:
        for task_id, method, args, kwargs in iter(task_queue.get, _TERMINATE):
            try:
                output = getattr(worker, method)(*args, **kwargs)
            except BaseException as e:
                logger.error(
                    "Exception in worker %s while processing method %s: %s, %s",
                    name,
                    method,
                    e,
                    traceback.format_exc(),
                )
                result = Result(task_id=task_id, exception=e)
            else:
                result = Result(task_id=task_id, value=output)
            result_queue.put(result)
    except KeyboardInterrupt:
        pass
    except Exception:
        logger.exception("Worker failed")

    logger.info("Worker exiting")
class ProcessWorkerWrapper:
    """Local process wrapper for handling single-node multi-GPU.

    Owns one worker process and its private task queue. Results come back
    through the ResultHandler's shared ``result_queue`` and are matched to
    futures through the shared ``tasks`` dict.
    """

    def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
        self._task_queue = mp.Queue()
        self.result_queue = result_handler.result_queue
        self.tasks = result_handler.tasks
        worker_kwargs = {
            "worker_factory": worker_factory,
            "task_queue": self._task_queue,
            "result_queue": self.result_queue,
        }
        self.process = mp.Process(  # type: ignore[attr-defined]
            target=_run_worker_process,
            name="VideoSysWorkerProcess",
            kwargs=worker_kwargs,
            daemon=True,
        )
        self.process.start()

    def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
        task_id = uuid.uuid4()
        # Register before sending, so the result handler always finds the
        # future even if the worker replies immediately.
        self.tasks[task_id] = future
        try:
            self._task_queue.put((task_id, method, args, kwargs))
        except BaseException as e:
            # Sending failed: unregister and report it as a dead worker.
            del self.tasks[task_id]
            raise ChildProcessError("worker died") from e

    def execute_method(self, method: str, *args, **kwargs):
        """Send ``method`` to the worker; return a ResultFuture (call ``.get()`` to block)."""
        future: ResultFuture = ResultFuture()
        self._enqueue_task(future, method, args, kwargs)
        return future

    async def execute_method_async(self, method: str, *args, **kwargs):
        """Send ``method`` to the worker and await its result."""
        loop = asyncio.get_running_loop()
        future = loop.create_future()
        self._enqueue_task(future, method, args, kwargs)
        return await future

    def terminate_worker(self):
        """Ask the worker to exit via the sentinel; kill it if sending fails."""
        try:
            self._task_queue.put(_TERMINATE)
        except ValueError:
            # multiprocessing.Queue.put raises ValueError once the queue is closed.
            self.process.kill()
        self._task_queue.close()

    def kill_worker(self):
        """Close the task queue and forcibly kill the worker process."""
        self._task_queue.close()
        self.process.kill()