File size: 10,835 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 |
"""Base classes to manage a Client's interaction with a running kernel"""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
import asyncio
import atexit
import time
import typing as t
from queue import Empty
from threading import Event, Thread
import zmq.asyncio
from jupyter_core.utils import ensure_async
from ._version import protocol_version_info
from .channelsabc import HBChannelABC
from .session import Session
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
# during garbage collection of threads at exit
# -----------------------------------------------------------------------------
# Constants and exceptions
# -----------------------------------------------------------------------------
# Major component of the messaging protocol version: the first element of
# ``protocol_version_info`` (imported from ._version).
major_protocol_version = protocol_version_info[0]
class InvalidPortNumber(Exception):  # noqa
    """Raised when a channel is configured with an unusable port (e.g. port 0)."""
class HBChannel(Thread):
    """The heartbeat channel which monitors the kernel heartbeat.

    The channel runs as a daemon thread with its own asyncio event loop.
    Each cycle it sends ``b"ping"`` on a REQ socket and waits up to
    :attr:`time_to_dead` seconds for a reply.  When no reply arrives,
    :meth:`call_handlers` is invoked and the socket is recreated, because
    the REQ/REP cycle has been broken.

    The channel is *not* paused when created.  Use :meth:`pause` and
    :meth:`unpause` to suspend and resume the heartbeat checks.
    """

    session = None
    socket = None
    address = None
    # Set to True by the atexit hook ``_notice_exit`` at interpreter shutdown.
    _exiting = False
    # Seconds to wait for a heartbeat reply before declaring heart failure.
    time_to_dead: float = 1.0
    _running = None
    _pause = None
    _beating = None

    def __init__(
        self,
        context: t.Optional[zmq.Context] = None,
        session: t.Optional[Session] = None,
        address: t.Union[t.Tuple[str, int], str] = "",
    ) -> None:
        """Create the heartbeat monitor thread.

        Parameters
        ----------
        context : :class:`zmq.Context`
            The ZMQ context to use.
        session : :class:`session.Session`
            The session to use.
        address : zmq url or (ip, port) tuple
            Either a zmq url string used as-is, or a standard (ip, port)
            tuple that is turned into ``tcp://ip:port``.

        Raises
        ------
        InvalidPortNumber
            If ``address`` is a tuple whose port is 0.
        """
        super().__init__()
        self.daemon = True
        self.context = context
        self.session = session
        if isinstance(address, tuple):
            if address[1] == 0:
                message = "The port number for a channel cannot be 0."
                raise InvalidPortNumber(message)
            address_str = "tcp://%s:%i" % address
        else:
            address_str = address
        self.address = address_str
        # running is False until the thread body starts executing
        self._running = False
        # Set by stop(); wakes any in-progress wait so the loop can exit promptly.
        self._exit = Event()
        # don't start paused
        self._pause = False
        self.poller = zmq.Poller()

    @staticmethod
    @atexit.register
    def _notice_exit() -> None:
        # Class definitions can be torn down during interpreter shutdown.
        # We only need to set _exiting flag if this hasn't happened.
        if HBChannel is not None:
            HBChannel._exiting = True

    def _create_socket(self) -> None:
        """(Re)create the REQ socket, connect it and register it for polling."""
        if self.socket is not None:
            # close previous socket, before opening a new one
            self.poller.unregister(self.socket)  # type:ignore[unreachable]
            self.socket.close()
        assert self.context is not None
        self.socket = self.context.socket(zmq.REQ)
        self.socket.linger = 1000
        assert self.address is not None
        self.socket.connect(self.address)
        self.poller.register(self.socket, zmq.POLLIN)

    async def _async_run(self) -> None:
        """The thread's main activity. Call start() instead."""
        self._create_socket()
        self._running = True
        self._beating = True
        assert self.socket is not None

        while self._running:
            if self._pause:
                # just sleep, and skip the rest of the loop
                self._exit.wait(self.time_to_dead)
                continue

            # no need to catch EFSM here, because the previous event was
            # either a recv or connect, which cannot be followed by EFSM
            await ensure_async(self.socket.send(b"ping"))
            request_time = time.time()
            # Wait until timeout (returns early if stop() sets _exit).
            self._exit.wait(self.time_to_dead)
            # poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
            self._beating = bool(self.poller.poll(0))
            if self._beating:
                # the poll above guarantees we have something to recv
                await ensure_async(self.socket.recv())
            elif self._running:
                # nothing was received within the time limit, signal heart failure
                since_last_heartbeat = time.time() - request_time
                self.call_handlers(since_last_heartbeat)
                # and close/reopen the socket, because the REQ/REP cycle has been broken
                self._create_socket()

    def run(self) -> None:
        """Run the heartbeat thread.

        Creates a private event loop for this thread and always closes it,
        even if the heartbeat loop raises.
        """
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._async_run())
        finally:
            loop.close()

    def pause(self) -> None:
        """Pause the heartbeat."""
        self._pause = True

    def unpause(self) -> None:
        """Unpause the heartbeat."""
        self._pause = False

    def is_beating(self) -> bool:
        """Is the heartbeat running and responsive (and not paused)."""
        return bool(self.is_alive() and not self._pause and self._beating)

    def stop(self) -> None:
        """Stop the channel's event loop and join its thread."""
        self._running = False
        # Wake the thread if it is blocked in _exit.wait() so join() returns promptly.
        self._exit.set()
        self.join()
        self.close()

    def close(self) -> None:
        """Close the heartbeat socket.

        Best-effort: errors raised while closing are ignored, and the socket
        reference is cleared afterwards.
        """
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                pass
            self.socket = None

    def call_handlers(self, since_last_heartbeat: float) -> None:
        """Handle a missed heartbeat.

        Called from the heartbeat thread when no reply to a ping arrived
        within :attr:`time_to_dead` seconds (and the channel is still running).

        Parameters
        ----------
        since_last_heartbeat : float
            Seconds elapsed since the unanswered ping was sent.

        Subclasses should override this method to react to heart failure.
        It is important to remember that this method is called in the
        heartbeat thread, so some logic must be done to ensure that the
        application level handlers are called in the application thread.
        """
        pass
# Declare HBChannel as implementing the heartbeat-channel interface.
# NOTE(review): HBChannelABC is presumably an abc.ABC, so register() makes
# HBChannel a virtual subclass for isinstance/issubclass checks — confirm in
# channelsabc.
HBChannelABC.register(HBChannel)
class ZMQSocketChannel:
    """A ZMQ socket wrapper"""

    def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations
        """
        super().__init__()
        self.socket: t.Optional[zmq.Socket] = socket
        self.session = session

    def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
        """Receive one multipart message and deserialize it with the session.

        ``kwargs`` are forwarded to ``socket.recv_multipart``.  The identity
        prefix split off by ``session.feed_identities`` is discarded.
        """
        assert self.socket is not None
        msg = self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message.  ``None`` (the default) is passed
            straight to ``socket.poll``, which then waits without a limit.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = self.socket.poll(timeout_ms)
        if ready:
            return self._recv()
        raise Empty

    def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
        """Get all messages that are currently ready.

        Drains the socket without blocking and returns the deserialized
        messages in arrival order (possibly an empty list).
        """
        msgs = []
        while True:
            try:
                # timeout=0: only take messages already queued.  The default
                # (None) would poll forever once the socket is drained, so
                # Empty would never be raised and this loop would hang.
                msgs.append(self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    def msg_ready(self) -> bool:
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(self.socket.poll(timeout=0))

    def close(self) -> None:
        """Close the socket channel.

        Best-effort: errors raised while closing are ignored, and the socket
        reference is cleared afterwards.
        """
        if self.socket is not None:
            try:
                self.socket.close(linger=0)
            except Exception:
                pass
            self.socket = None

    stop = close

    def is_alive(self) -> bool:
        """Test whether the channel is alive (its socket has not been closed)."""
        return self.socket is not None

    def send(self, msg: t.Dict[str, t.Any]) -> None:
        """Pass a message to the ZMQ socket to send"""
        assert self.socket is not None
        self.session.send(self.socket, msg)

    def start(self) -> None:
        """Start the socket channel (no-op for this implementation)."""
        pass
class AsyncZMQSocketChannel(ZMQSocketChannel):
    """A ZMQ socket in an async API"""

    socket: zmq.asyncio.Socket

    def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
        """Create a channel.

        Parameters
        ----------
        socket : :class:`zmq.asyncio.Socket`
            The ZMQ socket to use.
        session : :class:`session.Session`
            The session to use.
        loop
            Unused here, for other implementations

        Raises
        ------
        ValueError
            If ``socket`` is not a :class:`zmq.asyncio.Socket`.
        """
        if not isinstance(socket, zmq.asyncio.Socket):
            msg = "Socket must be asyncio"  # type:ignore[unreachable]
            raise ValueError(msg)
        super().__init__(socket, session)

    async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:  # type:ignore[override]
        """Await one multipart message and deserialize it with the session."""
        assert self.socket is not None
        msg = await self.socket.recv_multipart(**kwargs)
        _, smsg = self.session.feed_identities(msg)
        return self.session.deserialize(smsg)

    async def get_msg(  # type:ignore[override]
        self, timeout: t.Optional[float] = None
    ) -> t.Dict[str, t.Any]:
        """Gets a message if there is one that is ready.

        Parameters
        ----------
        timeout : float, optional
            Seconds to wait for a message.  ``None`` (the default) is passed
            straight to ``socket.poll``, which then waits without a limit.

        Raises
        ------
        queue.Empty
            If no message became ready within ``timeout``.
        """
        assert self.socket is not None
        timeout_ms = None if timeout is None else int(timeout * 1000)  # seconds to ms
        ready = await self.socket.poll(timeout_ms)
        if ready:
            return await self._recv()
        raise Empty

    async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:  # type:ignore[override]
        """Get all messages that are currently ready.

        Drains the socket without waiting and returns the deserialized
        messages in arrival order (possibly an empty list).
        """
        msgs = []
        while True:
            try:
                # timeout=0: only take messages already queued.  The default
                # (None) would await forever once the socket is drained.
                msgs.append(await self.get_msg(timeout=0))
            except Empty:
                break
        return msgs

    async def msg_ready(self) -> bool:  # type:ignore[override]
        """Is there a message that has been received?"""
        assert self.socket is not None
        return bool(await self.socket.poll(timeout=0))
|