import logging
import multiprocessing
import multiprocessing.connection
import os
import pickle
import signal
import sys
import tempfile
import time
import warnings
from typing import Optional

from . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]

log = logging.getLogger(__name__)


class ProcessException(Exception):
    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid

    def __reduce__(self):
        return type(self), (self.msg, self.error_index, self.pid)


class ProcessRaisedException(ProcessException):
    """Exception raised when a process failed due to an exception raised by the code."""

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
    ):
        super().__init__(msg, error_index, error_pid)


class ProcessExitedException(ProcessException):
    """Exception raised when a process failed due to a signal or exited with a specific code."""

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code = exit_code
        self.signal_name = signal_name

    def __reduce__(self):
        return (
            type(self),
            (self.msg, self.error_index, self.pid, self.exit_code, self.signal_name),
        )


def _wrap(fn, i, args, error_file):
    # prctl(2) is a Linux-specific system call.
    # On other systems the following function call has no effect.
    # This is set to ensure that non-daemonic child processes can
    # terminate if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        pass  # SIGINT; killed by parent, do nothing
    except Exception:
        # Propagate the exception to the parent process, keeping the
        # original traceback.
        import traceback

        with open(error_file, "wb") as fh:
            pickle.dump(traceback.format_exc(), fh)
        sys.exit(1)


class ProcessContext:
    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.
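
        Example::

            >>> # A minimal sketch of the intended polling pattern; assumes
            >>> # ``context`` was returned by ``start_processes(..., join=False)``.
            >>> while not context.join(timeout=1.0):
            ...     pass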
""" | |
        # Ensure this function can be called even when we're done.
        if len(self.sentinels) == 0:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        # Return if there was no error.
        if error_index is None:
            # Return whether or not all processes have been joined.
            return len(self.sentinels) == 0

        # Assume failure. Terminate processes that are still alive:
        # try SIGTERM first, then SIGKILL if a process refuses to go down.
        # Python signal handling runs only on the main thread; if that thread
        # is stuck in C/C++ code, it never gets a chance to handle SIGTERM.
        # We have seen processes hang this way, hence the SIGKILL fallback.
        kill_timeout: int = 30
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()

        end = time.monotonic() + kill_timeout
        for process in self.processes:
            time_to_wait = max(0, end - time.monotonic())
            process.join(time_to_wait)
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shut down process %s via SIGTERM, forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
                process.join()

        # The file will only be created if the process crashed.
        failed_process = self.processes[error_index]
        if not os.access(self.error_files[error_index], os.R_OK):
            exitcode = self.processes[error_index].exitcode
            if exitcode < 0:
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                raise ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            else:
                raise ProcessExitedException(
                    "process %d terminated with exit code %d" % (error_index, exitcode),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                )

        with open(self.error_files[error_index], "rb") as fh:
            original_trace = pickle.load(fh)
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        raise ProcessRaisedException(msg, error_index, failed_process.pid)


class SpawnContext(ProcessContext):
    def __init__(self, processes, error_files):
        warnings.warn("SpawnContext was renamed to ProcessContext in the 1.4 release.")
        super().__init__(processes, error_files)


# Note: [start_processes]
# mp.start_processes handles both start_method='spawn' and 'fork'. It is meant
# to be a more general API than mp.spawn. Currently we only document mp.spawn,
# as it is the CUDA-compatible start_method. However, in environments like
# IPython notebooks, 'fork' works better than 'spawn'. Every helper function
# we created for mp.spawn is general enough, and backends like XLA can reuse
# them in Colab notebooks as well. For now we only add this API; we can
# consider adding it to the documentation as needed in the future. A usage
# sketch follows below.
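#
# A minimal usage sketch (``demo_worker`` is a hypothetical top-level
# function, not part of this module):
#
#     def demo_worker(i, greeting):
#         print(f"process {i}: {greeting}")
#
#     ctx = start_processes(demo_worker, args=("hello",), nprocs=2,
#                           join=False, start_method="fork")
#     while not ctx.join():
#         pass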
def start_processes(
    fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"
):
    mp = multiprocessing.get_context(start_method)
    error_files = []
    processes = []
    for i in range(nprocs):
        # Each process is assigned a file to write tracebacks to. We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown). Note: this previously
        # used a multiprocessing.Queue but that can be prone to
        # deadlocks, so we went with a simpler solution for a one-shot
        # message between processes.
        tf = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        os.unlink(tf.name)
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, tf.name),
            daemon=daemon,
        )
        process.start()
        error_files.append(tf.name)
        processes.append(process)

    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass


def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If one of the processes exits with a non-zero exit status, the
    remaining processes are killed and an exception is raised with the
    cause of termination. If an exception was caught in the child
    process, it is forwarded and its traceback is included in the
    exception raised in the parent process.

    Args:
        fn (function): Function is called as the entrypoint of the
            spawned process. This function must be defined at the top
            level of a module so it can be pickled and spawned. This
            is a requirement imposed by multiprocessing.

            The function is called as ``fn(i, *args)``, where ``i`` is
            the process index and ``args`` is the passed-through tuple
            of arguments.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
            daemonic processes will be created.
        start_method (str): (deprecated) this method will always use ``spawn``
            as the start method. To use a different start method
            use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``
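
    Example::

        >>> # A minimal sketch; ``demo_fn`` is an illustrative name. In real
        >>> # use the function must live at the top level of a module so it
        >>> # can be pickled (shown inline here for brevity only).
        >>> def demo_fn(i, offset):
        ...     print(i + offset)
        >>> spawn(demo_fn, args=(10,), nprocs=2)  # prints 10 and 11, order may vary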
""" | |
if start_method != "spawn": | |
msg = ( | |
"This method only supports start_method=spawn (got: %s).\n" | |
"To use a different start_method use:\n\t\t" | |
" torch.multiprocessing.start_processes(...)" % start_method | |
) | |
warnings.warn(msg) | |
return start_processes(fn, args, nprocs, join, daemon, start_method="spawn") | |