|
"""Default classes for Comm and CommManager, for usage in IPython. |
|
""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import contextlib |
|
import logging |
|
import typing as t |
|
import uuid |
|
|
|
from traitlets.utils.importstring import import_item |
|
|
|
import comm |
|
|
|
if t.TYPE_CHECKING: |
|
from zmq.eventloop.zmqstream import ZMQStream |
|
|
|
# Module-wide logger shared by the comm classes defined in this file.
logger = logging.getLogger("Comm")

# A message is a plain mapping from string keys to arbitrary values.
MessageType = t.Dict[str, t.Any]

# Optional payloads accepted by the open/send/close family of methods.
MaybeDict = t.Optional[MessageType]
BuffersType = t.Optional[t.List[bytes]]

# Callback signatures: per-comm handlers receive a message; target handlers
# receive the comm instance and a message.
CommCallback = t.Callable[[MessageType], None]
CommTargetCallback = t.Callable[["BaseComm", MessageType], None]
|
|
|
|
|
class BaseComm:
    """Class for communicating between a Frontend and a Kernel

    Must be subclassed with a publish_msg method implementation which
    sends comm messages through the iopub channel.
    """

    def __init__(
        self,
        target_name: str = "comm",
        data: MaybeDict = None,
        metadata: MaybeDict = None,
        buffers: BuffersType = None,
        comm_id: str | None = None,
        primary: bool = True,
        target_module: str | None = None,
        topic: bytes | None = None,
        _open_data: MaybeDict = None,
        _close_data: MaybeDict = None,
        **kwargs: t.Any,
    ) -> None:
        """Create a comm and, if it is the primary end, open it right away.

        Parameters
        ----------
        target_name
            Sent as ``target_name`` with the ``comm_open`` message.
        data, metadata, buffers
            Forwarded to :meth:`open` when ``primary`` is true.
        comm_id
            Identifier of the comm; a random ``uuid4`` hex string is
            generated when it is empty or None.
        primary
            When true, :meth:`open` is called from the constructor.
            When false, the comm is only marked as open.
        target_module
            Sent as ``target_module`` with the ``comm_open`` message.
        topic
            Explicit topic bytes; defaults to ``b"comm-<comm_id>"``.
        _open_data, _close_data
            Default ``data`` used by :meth:`open` / :meth:`close` when
            they are called without ``data``.
        **kwargs
            Passed on to the next ``__init__`` in the MRO.
        """
        super().__init__(**kwargs)

        self.comm_id = comm_id if comm_id else uuid.uuid4().hex
        self.primary = primary
        self.target_name = target_name
        self.target_module = target_module
        # UTF-8 yields the same bytes as ASCII for ASCII ids, but does not
        # raise UnicodeEncodeError when an id contains other characters.
        self.topic = topic if topic else f"comm-{self.comm_id}".encode("utf-8")

        self._open_data = _open_data if _open_data else {}
        self._close_data = _close_data if _close_data else {}

        self._msg_callback: CommCallback | None = None
        self._close_callback: CommCallback | None = None

        # Start out closed: if open() fails below, a later close() (including
        # the one issued by __del__) returns early instead of publishing.
        self._closed = True

        if self.primary:
            self.open(data=data, metadata=metadata, buffers=buffers)
        else:
            # A non-primary comm publishes no comm_open; it is just marked open.
            self._closed = False

    def publish_msg(
        self,
        msg_type: str,
        data: MaybeDict = None,
        metadata: MaybeDict = None,
        buffers: BuffersType = None,
        **keys: t.Any,
    ) -> None:
        """Publish a comm message; must be overridden by subclasses.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        msg = "publish_msg Comm method is not implemented"
        raise NotImplementedError(msg)

    def __del__(self) -> None:
        """trigger close on gc"""
        # Best effort only: any exception raised while closing is suppressed.
        # deleting=True makes close() skip unregistering from the manager.
        with contextlib.suppress(Exception):
            self.close(deleting=True)

    def open(
        self, data: MaybeDict = None, metadata: MaybeDict = None, buffers: BuffersType = None
    ) -> None:
        """Open the frontend-side version of this comm

        Registers the comm with the comm manager and publishes a
        ``comm_open`` message. ``data`` defaults to the ``_open_data``
        given to the constructor.

        Raises
        ------
        RuntimeError
            If no comm manager is available.
        Exception
            Anything raised by ``publish_msg``; the comm is unregistered
            again before the exception propagates.
        """
        if data is None:
            data = self._open_data

        comm_manager = comm.get_comm_manager()
        if comm_manager is None:
            msg = "Comms cannot be opened without a comm_manager."
            raise RuntimeError(msg)

        comm_manager.register_comm(self)
        try:
            self.publish_msg(
                "comm_open",
                data=data,
                metadata=metadata,
                buffers=buffers,
                target_name=self.target_name,
                target_module=self.target_module,
            )
            self._closed = False
        except Exception:
            # Undo the registration so a failed open leaves no trace.
            comm_manager.unregister_comm(self)
            raise

    def close(
        self,
        data: MaybeDict = None,
        metadata: MaybeDict = None,
        buffers: BuffersType = None,
        deleting: bool = False,
    ) -> None:
        """Close the frontend-side version of this comm

        Publishes a ``comm_close`` message and, unless ``deleting`` is
        true, unregisters the comm from the comm manager. ``data``
        defaults to the ``_close_data`` given to the constructor.
        Calling this on an already closed comm does nothing.
        """
        if self._closed:
            return

        # Flag first, so a re-entrant or repeated close() is a no-op.
        self._closed = True
        if data is None:
            data = self._close_data
        self.publish_msg(
            "comm_close",
            data=data,
            metadata=metadata,
            buffers=buffers,
        )
        if not deleting:
            comm.get_comm_manager().unregister_comm(self)

    def send(
        self, data: MaybeDict = None, metadata: MaybeDict = None, buffers: BuffersType = None
    ) -> None:
        """Send a message to the frontend-side version of this comm"""
        self.publish_msg(
            "comm_msg",
            data=data,
            metadata=metadata,
            buffers=buffers,
        )

    def on_close(self, callback: CommCallback | None) -> None:
        """Register a callback for comm_close

        The callback is invoked by `handle_close` with the message it
        was given.

        Call `on_close(None)` to disable an existing callback.
        """
        self._close_callback = callback

    def on_msg(self, callback: CommCallback | None) -> None:
        """Register a callback for comm_msg

        The callback is invoked by `handle_msg` with the message it
        was given.

        Call `on_msg(None)` to disable an existing callback.
        """
        self._msg_callback = callback

    def handle_close(self, msg: MessageType) -> None:
        """Handle a comm_close message"""
        logger.debug("handle_close[%s](%s)", self.comm_id, msg)
        if self._close_callback:
            self._close_callback(msg)

    def handle_msg(self, msg: MessageType) -> None:
        """Handle a comm_msg message

        When a message callback is registered it is run between the
        IPython ``pre_execute`` and ``post_execute`` events (if an IPython
        shell is active). ``post_execute`` is triggered even when the
        callback raises; the exception then propagates to the caller.
        """
        logger.debug("handle_msg[%s](%s)", self.comm_id, msg)
        if not self._msg_callback:
            return

        # Imported lazily: only needed when a callback is actually registered.
        from IPython import get_ipython

        shell = get_ipython()
        if shell:
            shell.events.trigger("pre_execute")
        try:
            self._msg_callback(msg)
        finally:
            # Keep pre/post events paired even if the callback fails.
            if shell:
                shell.events.trigger("post_execute")
|
|
|
|
|
class CommManager:
    """Default CommManager singleton implementation for Comms in the Kernel"""

    def __init__(self) -> None:
        # comm_id -> registered comm
        self.comms: dict[str, BaseComm] = {}
        # target name -> callback invoked when a comm_open names that target
        self.targets: dict[str, CommTargetCallback] = {}

    def register_target(self, target_name: str, f: CommTargetCallback | str) -> None:
        """Register a callable f for a given target name

        f will be called with two arguments when a comm_open message is
        received with `target_name`:

        - the Comm instance
        - the `comm_open` message itself.

        f can be a Python callable or an import string for one.
        """
        if isinstance(f, str):
            f = import_item(f)

        # The string form of the alias keeps the cast purely static.
        self.targets[target_name] = t.cast("CommTargetCallback", f)

    def unregister_target(self, target_name: str, f: CommTargetCallback) -> CommTargetCallback:
        """Unregister a callable registered with register_target

        ``f`` is accepted for interface compatibility but is not consulted;
        whatever is registered under ``target_name`` is removed and returned.

        Raises
        ------
        KeyError
            If no target is registered under ``target_name``.
        """
        return self.targets.pop(target_name)

    def register_comm(self, comm: BaseComm) -> str:
        """Register a new comm and return its comm_id"""
        comm_id = comm.comm_id
        self.comms[comm_id] = comm
        return comm_id

    def unregister_comm(self, comm: BaseComm) -> None:
        """Remove a comm from the registry

        Only the registry entry is dropped; the comm itself is not closed.

        Raises
        ------
        KeyError
            If the comm is not registered (unlike `get_comm`, which
            returns None for unknown ids).
        """
        del self.comms[comm.comm_id]

    def get_comm(self, comm_id: str) -> BaseComm | None:
        """Get a comm with a particular id

        Returns the comm if found, otherwise None.

        This will not raise an error,
        it will log messages if the comm cannot be found.
        """
        try:
            return self.comms[comm_id]
        except KeyError:
            logger.warning("No such comm: %s", comm_id)
            # Building the id list is only worth it when it will be emitted.
            if logger.isEnabledFor(logging.DEBUG):
                logger.debug("Current comms: %s", list(self.comms))
            return None

    def comm_open(self, stream: ZMQStream, ident: str, msg: MessageType) -> None:
        """Handler for comm_open messages

        Creates a non-primary comm for ``msg["content"]["comm_id"]``,
        registers it, and passes it with the message to the callback
        registered for ``msg["content"]["target_name"]``.

        If no such target is registered, or the callback raises, the
        problem is logged, the comm is closed, and it is removed from the
        registry. ``stream`` and ``ident`` are not used here.
        """
        from comm import create_comm

        content = msg["content"]
        comm_id = content["comm_id"]
        target_name = content["target_name"]
        callback = self.targets.get(target_name)

        new_comm = create_comm(
            comm_id=comm_id,
            primary=False,
            target_name=target_name,
        )
        self.register_comm(new_comm)

        if callback is None:
            logger.error("No such comm target registered: %s", target_name)
        else:
            try:
                callback(new_comm, msg)
            except Exception:
                logger.exception("Exception opening comm with target: %s", target_name)
            else:
                return

        # Failure: tear the comm down again.
        try:
            new_comm.close()
        except Exception:
            logger.exception(
                "Could not close comm during `comm_open` failure clean-up. "
                "The comm may not have been opened yet."
            )
        finally:
            # Make sure the failed comm never lingers in the registry, even
            # when close() raised before it could unregister itself.
            self.comms.pop(comm_id, None)

    def comm_msg(self, stream: ZMQStream, ident: str, msg: MessageType) -> None:
        """Handler for comm_msg messages

        Dispatches the message to the registered comm's ``handle_msg``.
        Unknown comm ids are ignored (after `get_comm` logs them), and
        exceptions from the handler are logged rather than propagated.
        """
        comm_id = msg["content"]["comm_id"]
        target = self.get_comm(comm_id)
        if target is None:
            return

        try:
            target.handle_msg(msg)
        except Exception:
            logger.exception("Exception in comm_msg for %s", comm_id)

    def comm_close(self, stream: ZMQStream, ident: str, msg: MessageType) -> None:
        """Handler for comm_close messages

        Marks the comm as closed, removes it from the registry and then
        calls its ``handle_close``. Unknown comm ids are ignored (after
        `get_comm` logs them), and exceptions from the handler are logged
        rather than propagated.
        """
        comm_id = msg["content"]["comm_id"]
        closing = self.get_comm(comm_id)
        if closing is None:
            return

        # Flag as closed first so the comm does not publish its own
        # comm_close later (BaseComm.close returns early when _closed is set).
        closing._closed = True
        del self.comms[comm_id]

        try:
            closing.handle_close(msg)
        except Exception:
            logger.exception("Exception in comm_close for %s", comm_id)
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "CommManager",
    "BaseComm",
]
|
|