File size: 10,055 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 |
"""Kernel Provisioner Classes"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import os
import signal
import sys
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ..connect import KernelConnectionInfo, LocalPortCache
from ..launcher import launch_kernel
from ..localinterfaces import is_local_ip, local_ips
from .provisioner_base import KernelProvisionerBase
class LocalProvisioner(KernelProvisionerBase):  # type:ignore[misc]
    """
    :class:`LocalProvisioner` is a concrete class of ABC :py:class:`KernelProvisionerBase`
    and is the out-of-box default implementation used when no kernel provisioner is
    specified in the kernel specification (``kernel.json``).  It provides functional
    parity to existing applications by launching the kernel locally and using
    :class:`subprocess.Popen` to manage its lifecycle.

    This class is intended to be subclassed for customizing local kernel environments
    and serve as a reference implementation for other custom provisioners.
    """

    # Handle returned by ``launch_kernel``; cleared by ``wait()`` once the process is reaped.
    process = None
    _exit_future = None
    # Kernel process id and process-group id (persisted via get/load_provisioner_info).
    pid = None
    pgid = None
    ip = None
    # True while this provisioner holds ports reserved from ``LocalPortCache``.
    ports_cached = False

    @property
    def has_process(self) -> bool:
        """Return True while a kernel process handle is held."""
        return self.process is not None

    async def poll(self) -> Optional[int]:
        """Poll the provisioner.

        Returns the process's ``poll()`` result (None while it is still running),
        or 0 when no process is held.
        """
        ret = 0
        if self.process:
            ret = self.process.poll()  # type:ignore[unreachable]
        return ret

    async def wait(self) -> Optional[int]:
        """Wait for the provisioner process.

        Returns the process's exit status, or 0 when no process is held. On return
        the process's stdio handles are closed and ``process`` is cleared.
        """
        ret = 0
        if self.process:
            # Use busy loop at 100ms intervals, polling until the process is
            # not alive.  If we find the process is no longer alive, complete
            # its cleanup via the blocking wait().  Callers are responsible for
            # bounding calls to wait() with a timeout.
            while await self.poll() is None:  # type:ignore[unreachable]
                await asyncio.sleep(0.1)

            # Process is no longer alive, wait and clear
            ret = self.process.wait()
            # Make sure all the fds get closed.
            for attr in ["stdout", "stderr", "stdin"]:
                fid = getattr(self.process, attr)
                if fid:
                    fid.close()
            self.process = None  # allow has_process to now return False
        return ret

    async def send_signal(self, signum: int) -> None:
        """Sends a signal to the process group of the kernel (this
        usually includes the kernel and any subprocesses spawned by
        the kernel).

        Note that since only SIGTERM is supported on Windows, we will
        check if the desired signal is for interrupt and apply the
        applicable code on Windows in that case.

        Errors from the final per-process ``send_signal`` are left to the caller.
        """
        if self.process:
            if signum == signal.SIGINT and sys.platform == "win32":  # type:ignore[unreachable]
                from ..win_interrupt import send_interrupt

                send_interrupt(self.process.win32_interrupt_event)
                return

            # Prefer process-group over process
            if self.pgid and hasattr(os, "killpg"):
                try:
                    os.killpg(self.pgid, signum)
                    return
                except OSError:
                    pass  # We'll retry sending the signal to only the process below

            # If we're here, send the signal to the process and let caller handle exceptions
            self.process.send_signal(signum)
            return

    async def kill(self, restart: bool = False) -> None:
        """Kill the provisioner and optionally restart."""
        if self.process:
            if hasattr(signal, "SIGKILL"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process-group over `kill()`.
                try:
                    await self.send_signal(signal.SIGKILL)
                    return
                except OSError:
                    pass  # fall back to Popen.kill() below
            try:
                self.process.kill()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    async def terminate(self, restart: bool = False) -> None:
        """Terminate the provisioner and optionally restart."""
        if self.process:
            if hasattr(signal, "SIGTERM"):  # type:ignore[unreachable]
                # If available, give preference to signalling the process group over `terminate()`.
                try:
                    await self.send_signal(signal.SIGTERM)
                    return
                except OSError:
                    pass  # fall back to Popen.terminate() below
            try:
                self.process.terminate()
            except OSError as e:
                LocalProvisioner._tolerate_no_process(e)

    @staticmethod
    def _tolerate_no_process(os_error: OSError) -> None:
        """Ignore ``os_error`` if it only means the process already exited; otherwise re-raise it.

        The error is re-raised explicitly (rather than via a bare ``raise``) so this
        helper also behaves correctly when called outside an ``except`` clause.
        """
        if sys.platform == "win32":
            # In Windows, we will get an Access Denied error (winerror 5) if the
            # process has already terminated. Ignore it.
            if getattr(os_error, "winerror", None) != 5:
                raise os_error
        else:
            # On Unix, we may get an ESRCH error (or ProcessLookupError instance) if
            # the process has already terminated. Ignore it.
            from errno import ESRCH

            if not isinstance(os_error, ProcessLookupError) or os_error.errno != ESRCH:
                raise os_error

    async def cleanup(self, restart: bool = False) -> None:
        """Clean up the resources used by the provisioner and optionally restart.

        When not restarting, ports reserved in ``pre_launch`` are returned to the
        ``LocalPortCache`` exactly once.
        """
        if self.ports_cached and not restart:
            # provisioner is about to be destroyed, return cached ports
            lpc = LocalPortCache.instance()
            ports = (
                self.connection_info["shell_port"],
                self.connection_info["iopub_port"],
                self.connection_info["stdin_port"],
                self.connection_info["hb_port"],
                self.connection_info["control_port"],
            )
            for port in ports:
                if TYPE_CHECKING:
                    assert isinstance(port, int)
                lpc.return_port(port)
            # The ports no longer belong to this provisioner: don't return them
            # again on a repeated cleanup(), and let a later pre_launch() reserve
            # fresh ones instead of reusing released ports.
            self.ports_cached = False

    async def pre_launch(self, **kwargs: Any) -> Dict[str, Any]:
        """Perform any steps in preparation for kernel process launch.

        This includes applying additional substitutions to the kernel launch command and env.
        It also includes preparation of launch parameters.

        Returns the updated kwargs.

        Raises:
            RuntimeError: if the parent kernel manager uses TCP on a non-local IP.
        """
        # build the Popen cmd
        extra_arguments = kwargs.pop("extra_arguments", [])

        # This should be considered temporary until a better division of labor can be defined.
        km = self.parent
        if km:
            if km.transport == "tcp" and not is_local_ip(km.ip):
                msg = (
                    "Can only launch a kernel on a local interface. "
                    f"This one is not: {km.ip}. "
                    "Make sure that the '*_address' attributes are "
                    "configured properly. "
                    f"Currently valid addresses are: {local_ips()}"
                )
                raise RuntimeError(msg)

            # write connection file / get default ports
            # TODO - change when handshake pattern is adopted
            if km.cache_ports and not self.ports_cached:
                lpc = LocalPortCache.instance()
                km.shell_port = lpc.find_available_port(km.ip)
                km.iopub_port = lpc.find_available_port(km.ip)
                km.stdin_port = lpc.find_available_port(km.ip)
                km.hb_port = lpc.find_available_port(km.ip)
                km.control_port = lpc.find_available_port(km.ip)
                self.ports_cached = True

            env = kwargs.get("env")
            if env is not None:
                jupyter_session = env.get("JPY_SESSION_NAME", "")
                km.write_connection_file(jupyter_session=jupyter_session)
            else:
                # No env supplied (or env=None): nothing to read a session name from.
                km.write_connection_file()
            self.connection_info = km.get_connection_info()

            kernel_cmd = km.format_kernel_cmd(
                extra_arguments=extra_arguments
            )  # This needs to remain here for b/c
        else:
            kernel_cmd = self.kernel_spec.argv + extra_arguments

        return await super().pre_launch(cmd=kernel_cmd, **kwargs)

    async def launch_kernel(self, cmd: List[str], **kwargs: Any) -> KernelConnectionInfo:
        """Launch a kernel with a command.

        Records the new process's pid and, where ``os.getpgid`` exists, its
        process-group id, then returns ``self.connection_info``.
        """
        scrubbed_kwargs = LocalProvisioner._scrub_kwargs(kwargs)
        # Resolves to the module-level ``launch_kernel`` function, not this method.
        self.process = launch_kernel(cmd, **scrubbed_kwargs)
        pgid = None
        if hasattr(os, "getpgid"):
            try:
                pgid = os.getpgid(self.process.pid)
            except OSError:
                pass  # leave pgid as None; send_signal then targets the process only

        self.pid = self.process.pid
        self.pgid = pgid
        return self.connection_info

    @staticmethod
    def _scrub_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Remove any keyword arguments that Popen does not tolerate.

        Returns a shallow copy; the caller's dict is not modified.
        """
        keywords_to_scrub: List[str] = ["extra_arguments", "kernel_id"]
        scrubbed_kwargs = kwargs.copy()
        for kw in keywords_to_scrub:
            scrubbed_kwargs.pop(kw, None)
        return scrubbed_kwargs

    async def get_provisioner_info(self) -> Dict:
        """Captures the base information necessary for persistence relative to this instance."""
        provisioner_info = await super().get_provisioner_info()
        provisioner_info.update({"pid": self.pid, "pgid": self.pgid, "ip": self.ip})
        return provisioner_info

    async def load_provisioner_info(self, provisioner_info: Dict) -> None:
        """Loads the base information necessary for persistence relative to this instance."""
        await super().load_provisioner_info(provisioner_info)
        self.pid = provisioner_info["pid"]
        self.pgid = provisioner_info["pgid"]
        self.ip = provisioner_info["ip"]
|