|
"""Gateway API handlers.""" |
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
import asyncio |
|
import logging |
|
import mimetypes |
|
import os |
|
import random |
|
import warnings |
|
from typing import Any, Optional, cast |
|
|
|
from jupyter_client.session import Session |
|
from tornado import web |
|
from tornado.concurrent import Future |
|
from tornado.escape import json_decode, url_escape, utf8 |
|
from tornado.httpclient import HTTPRequest |
|
from tornado.ioloop import IOLoop, PeriodicCallback |
|
from tornado.websocket import WebSocketHandler, websocket_connect |
|
from traitlets.config.configurable import LoggingConfigurable |
|
|
|
from ..base.handlers import APIHandler, JupyterHandler |
|
from ..utils import url_path_join |
|
from .gateway_client import GatewayClient |
|
|
|
warnings.warn(
    "The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0",
    category=DeprecationWarning,
    stacklevel=2,
)

# Interval, in seconds, between pings sent by WebSocketChannelsHandler to the
# notebook client (converted to milliseconds for the PeriodicCallback).
# Overridable through the environment variable of the same name; defaults to 30.
GATEWAY_WS_PING_INTERVAL_SECS = int(os.environ.get("GATEWAY_WS_PING_INTERVAL_SECS", "30"))
|
|
|
|
|
class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
    """Gateway web socket channels handler.

    Accepts the kernel-channels websocket from the notebook client and relays
    traffic to and from the gateway through a ``GatewayWebSocketClient``.
    """

    # Per-connection state, populated by initialize() / get() / open().
    session = None
    gateway = None
    kernel_id = None
    ping_callback = None

    def check_origin(self, origin=None):
        """Check origin for the socket using JupyterHandler's origin policy."""
        return JupyterHandler.check_origin(self, origin)

    def set_default_headers(self):
        """Undo the set_default_headers in JupyterHandler which doesn't make sense for websockets"""

    def get_compression_options(self):
        """Get the compression options for the socket.

        The same (empty) options dict is forwarded to the gateway client in
        ``open``. Tornado treats any dict, even an empty one, as enabling
        compression with default settings.
        """
        return {}

    def authenticate(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.

        Raises
        ------
        tornado.web.HTTPError
            403 when there is no authenticated user.
        """
        if self.current_user is None:
            self.log.warning("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        if self.get_argument("session_id", None):
            assert self.session is not None
            self.session.session = self.get_argument("session_id")
        else:
            self.log.warning("No session ID specified")

    def initialize(self):
        """Initialize the socket: create the Session and the gateway proxy client."""
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)
        self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)

    async def get(self, kernel_id, *args, **kwargs):
        """Authenticate, remember the kernel id, then perform the websocket upgrade."""
        self.authenticate()
        self.kernel_id = kernel_id
        kwargs["kernel_id"] = kernel_id
        await super().get(*args, **kwargs)

    def send_ping(self):
        """Send a keep-alive ping to the notebook client.

        Once the client connection is gone or closing, the periodic ping
        callback is stopped instead, because ``ping`` would raise
        ``WebSocketClosedError`` on such a connection.
        """
        if self.ws_connection is None or self.ws_connection.is_closing():
            if self.ping_callback is not None:
                self.ping_callback.stop()
            return

        self.ping(b"")

    def open(self, kernel_id, *args, **kwargs):
        """Handle web socket connection open to notebook server and delegate to gateway web socket handler"""
        self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
        self.ping_callback.start()

        assert self.gateway is not None
        self.gateway.on_open(
            kernel_id=kernel_id,
            message_callback=self.write_message,
            compression_options=self.get_compression_options(),
        )

    def on_message(self, message):
        """Forward message to gateway web socket handler."""
        assert self.gateway is not None
        self.gateway.on_message(message)

    def write_message(self, message, binary=False):
        """Send message back to notebook client. This is called via callback from self.gateway._read_messages.

        ``bytes`` payloads are always sent as binary frames. If the client
        connection is gone or closing, the message is dropped (and summarized
        at DEBUG level) instead of raising into the gateway read loop.
        """
        if self.ws_connection and not self.ws_connection.is_closing():
            if isinstance(message, bytes):
                binary = True
            super().write_message(message, binary=binary)
        elif self.log.isEnabledFor(logging.DEBUG):
            try:
                msg_summary = WebSocketChannelsHandler._get_message_summary(
                    json_decode(utf8(message))
                )
            except (ValueError, KeyError, TypeError):
                # Binary frames / non-JSON payloads, or messages missing the
                # fields the summary reads: an exception here would propagate
                # into the gateway client's read loop just for a debug log.
                msg_summary = "<unparseable message>"
            self.log.debug(
                f"Notebook client closed websocket connection - message dropped: {msg_summary}"
            )

    def on_close(self):
        """Handle a closing socket: stop pinging and tear down the gateway connection."""
        self.log.debug("Closing websocket connection %s", self.request.path)
        # Stop the keep-alive timer now rather than waiting for its next tick.
        if self.ping_callback is not None:
            self.ping_callback.stop()
        assert self.gateway is not None
        self.gateway.on_close()
        super().on_close()

    @staticmethod
    def _get_message_summary(message):
        """Get a one-line summary of a decoded kernel message for logging.

        Reads ``msg_type`` and, for ``status``/``error`` messages, fields of
        ``content``; raises ``KeyError`` if those are missing.
        """
        message_type = message["msg_type"]
        summary = [f"type: {message_type}"]

        if message_type == "status":
            summary.append(f", state: {message['content']['execution_state']}")
        elif message_type == "error":
            content = message["content"]
            summary.append(f", {content['ename']}:{content['evalue']}:{content['traceback']}")
        else:
            summary.append(", ...")

        return "".join(summary)
|
|
|
|
|
class GatewayWebSocketClient(LoggingConfigurable):
    """Proxy web socket connection to a kernel/enterprise gateway."""

    def __init__(self, **kwargs):
        """Initialize the gateway web socket client.

        ``kwargs`` are accepted but not forwarded to ``LoggingConfigurable``;
        the gateway websocket URL is read from ``GatewayClient.instance()``
        when connecting.
        """
        super().__init__()
        self.kernel_id = None
        # Established gateway connection; set once a connect future succeeds.
        self.ws = None
        # Future of the in-flight (or most recent) connection attempt.
        self.ws_future: Future[Any] = Future()
        # Set by _disconnect(); stops reading, writing and reconnecting.
        self.disconnected = False
        # Consecutive reconnect attempts; reset after a successful connection.
        self.retry = 0

    async def _connect(self, kernel_id, message_callback):
        """Start connecting to the gateway's channels websocket for ``kernel_id``.

        ``message_callback`` receives every message read from the gateway once
        the connection future resolves.
        """
        self.ws = None
        self.kernel_id = kernel_id
        client = GatewayClient.instance()
        assert client.ws_url is not None

        ws_url = url_path_join(
            client.ws_url,
            client.kernels_endpoint,
            url_escape(kernel_id),
            "channels",
        )
        self.log.info(f"Connecting to {ws_url}")
        kwargs: dict[str, Any] = client.load_connection_args()

        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = cast("Future[Any]", websocket_connect(request))
        self.ws_future.add_done_callback(self._connection_done)

        loop = IOLoop.current()
        loop.add_future(self.ws_future, lambda future: self._read_messages(message_callback))

    def _connection_done(self, fut):
        """Handle a finished connection attempt.

        On success, while still connected, store the connection and reset the
        retry counter. Otherwise log a warning; a connection that completed
        after the client disconnected is closed rather than left open.
        """
        # exception()/result() raise CancelledError on a cancelled future
        # (cancelled by _disconnect), so test cancelled() first.
        if fut.cancelled() or fut.exception() is not None:
            connection = None
        else:
            connection = fut.result()

        if connection is not None and not self.disconnected:
            self.ws = connection
            self.retry = 0
            self.log.debug(f"Connection is ready: ws: {self.ws}")
            return

        if connection is not None:
            # _disconnect() ran after the future resolved but before this
            # callback: it saw ``self.ws is None`` and a done future, so it
            # neither closed nor cancelled anything. Close the socket here.
            connection.close()
        self.log.warning(
            "Websocket connection has been closed via client disconnect or due to error. "
            f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}"
        )

    def _disconnect(self):
        """Handle a disconnect: close the open connection or cancel a pending one."""
        self.disconnected = True
        if self.ws is not None:
            # Connection established: close it.
            self.ws.close()
        elif not self.ws_future.done():
            # Still connecting: cancel the pending attempt.
            self.ws_future.cancel()
            self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

    async def _read_messages(self, callback):
        """Read messages from gateway server, passing each one to ``callback``.

        When the connection is lost (``read_message`` returns None or raises)
        and the client has not disconnected, reconnect with exponential
        backoff plus jitter, up to ``gateway_retry_max`` attempts.
        """
        while self.ws is not None:
            message = None
            if not self.disconnected:
                try:
                    message = await self.ws.read_message()
                except Exception as e:
                    self.log.error(
                        f"Exception reading message from websocket: {e}"
                    )
                if message is None:
                    if not self.disconnected:
                        self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
                    break
                callback(
                    message
                )
            else:
                break

        # Reconnect unless the client disconnected or retries are exhausted.
        if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
            jitter = random.randint(10, 100) * 0.01
            retry_interval = (
                min(
                    GatewayClient.instance().gateway_retry_interval * (2**self.retry),
                    GatewayClient.instance().gateway_retry_interval_max,
                )
                + jitter
            )
            self.retry += 1
            self.log.info(
                "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
                retry_interval,
                self.retry,
                GatewayClient.instance().gateway_retry_max,
                self.kernel_id,
            )
            await asyncio.sleep(retry_interval)
            loop = IOLoop.current()
            loop.spawn_callback(self._connect, self.kernel_id, callback)

    def on_open(self, kernel_id, message_callback, **kwargs):
        """Web socket connection open against gateway server; connects in the background."""
        loop = IOLoop.current()
        loop.spawn_callback(self._connect, kernel_id, message_callback)

    def on_message(self, message):
        """Send message to gateway server, deferring until the connect future resolves if needed."""
        if self.ws is None:
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self._write_message(message))
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server; dropped if disconnected or not connected."""
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error(f"Exception writing message to websocket: {e}")

    def on_close(self):
        """Web socket closed event."""
        self._disconnect()
|
|
|
|
|
class GatewayResourceHandler(APIHandler):
    """Retrieves resources for specific kernelspec definitions from kernel/enterprise gateway."""

    @web.authenticated
    async def get(self, kernel_name, path, include_body=True):
        """Get a gateway resource by name and path.

        Parameters
        ----------
        kernel_name : str
            Name of the kernelspec that owns the resource.
        path : str
            Resource path within the kernelspec; also used to guess the
            response content type (``text/plain`` when unknown).
        include_body : bool
            Accepted for signature compatibility; not used.

        Raises
        ------
        tornado.web.HTTPError
            404 when the kernel spec manager returns no resource.
        """
        ksm = self.kernel_spec_manager
        kernel_spec_res = await ksm.get_kernel_spec_resource(kernel_name, path)
        if kernel_spec_res is None:
            self.log.warning(
                f"Kernelspec resource '{path}' for '{kernel_name}' not found. Gateway may not support"
                " resource serving."
            )
            # Nothing to serve: respond 404 instead of finishing with no body
            # and no content type.
            raise web.HTTPError(404)

        mimetype = mimetypes.guess_type(path)[0] or "text/plain"
        self.finish(kernel_spec_res, set_content_type=mimetype)
|
|
|
|
|
from ..services.kernels.handlers import _kernel_id_regex |
|
from ..services.kernelspecs.handlers import kernel_name_regex |
|
|
|
# URL routes served by this module: the kernel channels websocket proxy and
# the kernelspec resource proxy.
default_handlers = [
    (rf"/api/kernels/{_kernel_id_regex}/channels", WebSocketChannelsHandler),
    (rf"/kernelspecs/{kernel_name_regex}/(?P<path>.*)", GatewayResourceHandler),
]
|
|