# Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/jupyter_server/services/kernels/websocket.py
"""Tornado handlers for WebSocket <-> ZMQ sockets.""" | |
# Copyright (c) Jupyter Development Team. | |
# Distributed under the terms of the Modified BSD License. | |
from jupyter_core.utils import ensure_async | |
from tornado import web | |
from tornado.websocket import WebSocketHandler | |
from jupyter_server.auth.decorator import ws_authenticated | |
from jupyter_server.base.handlers import JupyterHandler | |
from jupyter_server.base.websocket import WebSocketMixin | |
# Authorization resource name for kernel endpoints; it is assigned to
# ``KernelWebsocketHandler.auth_resource`` in this module.
AUTH_RESOURCE: str = "kernels"
class KernelWebsocketHandler(WebSocketMixin, WebSocketHandler, JupyterHandler):  # type:ignore[misc]
    """Websocket handler that connects a client to a running kernel.

    On ``get`` it authorizes the request, builds a kernel websocket
    connection object (of the class configured in the
    ``kernel_websocket_connection_class`` setting) and then forwards
    websocket events (open / message / close / subprotocol selection) to it.
    """

    # Resource name for authorization checks on this handler.
    # NOTE(review): presumably read by jupyter_server's auth machinery
    # (e.g. ``ws_authenticated`` / the authorizer); confirm there.
    auth_resource = AUTH_RESOURCE

    @property
    def kernel_websocket_connection_class(self):
        """The kernel websocket connection class from the server settings.

        This must be a property: ``pre_get`` instantiates the returned class
        directly via ``self.kernel_websocket_connection_class(...)``. As a
        plain method, that call would pass the constructor's keyword
        arguments to this accessor and raise ``TypeError``.
        """
        return self.settings.get("kernel_websocket_connection_class")

    def set_default_headers(self):
        """Undo the set_default_headers in JupyterHandler.

        The default headers set by JupyterHandler don't make sense for
        websockets, so this override intentionally does nothing.
        """

    def get_compression_options(self):
        """Return the websocket compression options from the settings, or None."""
        return self.settings.get("websocket_compression_options", None)

    async def pre_get(self):
        """Authorize the user and create the kernel websocket connection.

        Stores the new connection object on ``self.connection``.

        Raises:
            web.HTTPError: 403 if the current user may not "execute" on the
                kernels resource.
        """
        user = self.current_user

        # Authorize the user for the kernels resource.
        authorized = await ensure_async(
            self.authorizer.is_authorized(self, user, "execute", AUTH_RESOURCE)
        )
        if not authorized:
            raise web.HTTPError(403)

        kernel = self.kernel_manager.get_kernel(self.kernel_id)
        self.connection = self.kernel_websocket_connection_class(
            parent=kernel, websocket_handler=self, config=self.config
        )

        # Adopt the client-supplied session id, if one was given.
        session_id = self.get_argument("session_id", None)
        if session_id:
            self.connection.session.session = session_id
        else:
            self.log.warning("No session ID specified")

        # For backwards compatibility with older versions
        # of the websocket connection, call a prepare method if found.
        if hasattr(self.connection, "prepare"):
            await self.connection.prepare()

    # ``ws_authenticated`` (imported by this module) guards the upgrade
    # request. NOTE(review): presumably it rejects requests without an
    # authenticated user before ``pre_get`` runs; see
    # jupyter_server.auth.decorator.
    @ws_authenticated
    async def get(self, kernel_id):
        """Handle the websocket upgrade request for ``kernel_id``."""
        self.kernel_id = kernel_id
        await self.pre_get()
        await super().get(kernel_id=kernel_id)

    async def open(self, kernel_id):
        """Open a kernel websocket and connect it to the kernel."""
        # Need to call super here to make sure we
        # begin a ping-pong loop with the client.
        super().open()
        # Wait for the kernel to emit an idle status.
        self.log.info("Connecting to kernel %s.", self.kernel_id)
        await self.connection.connect()

    def on_message(self, ws_message):
        """Get a kernel message from the websocket and turn it into a ZMQ message."""
        self.connection.handle_incoming_message(ws_message)

    def on_close(self):
        """Handle a socket closure by disconnecting and dropping the connection.

        Safe to call when no connection was created or it was already cleared.
        """
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.disconnect()
        self.connection = None

    def select_subprotocol(self, subprotocols):
        """Select the websocket subprotocol to use, if any.

        The connection's ``kernel_ws_protocol`` decides the preference:
        ``None`` means prefer "v1.kernel.websocket.jupyter.org", and an
        empty string means use no subprotocol. The preference is only
        returned if the client offered it; otherwise None is returned.
        None selects the default, "legacy" protocol.
        """
        preferred_protocol = self.connection.kernel_ws_protocol
        if preferred_protocol is None:
            preferred_protocol = "v1.kernel.websocket.jupyter.org"
        elif preferred_protocol == "":
            preferred_protocol = None
        return preferred_protocol if preferred_protocol in subprotocols else None