mbuali's picture
Upload folder using huggingface_hub
d1ceb73 verified
"""Gateway API handlers."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import asyncio
import logging
import mimetypes
import os
import random
import warnings
from typing import Any, Optional, cast
from jupyter_client.session import Session
from tornado import web
from tornado.concurrent import Future
from tornado.escape import json_decode, url_escape, utf8
from tornado.httpclient import HTTPRequest
from tornado.ioloop import IOLoop, PeriodicCallback
from tornado.websocket import WebSocketHandler, websocket_connect
from traitlets.config.configurable import LoggingConfigurable
from ..base.handlers import APIHandler, JupyterHandler
from ..utils import url_path_join
from .gateway_client import GatewayClient
warnings.warn(
"The jupyter_server.gateway.handlers module is deprecated and will not be supported in Jupyter Server 3.0",
DeprecationWarning,
stacklevel=2,
)
# Keepalive ping interval (default: 30 seconds)
GATEWAY_WS_PING_INTERVAL_SECS = int(os.getenv("GATEWAY_WS_PING_INTERVAL_SECS", "30"))
class WebSocketChannelsHandler(WebSocketHandler, JupyterHandler):
"""Gateway web socket channels handler."""
session = None
gateway = None
kernel_id = None
ping_callback = None
def check_origin(self, origin=None):
"""Check origin for the socket."""
return JupyterHandler.check_origin(self, origin)
def set_default_headers(self):
"""Undo the set_default_headers in JupyterHandler which doesn't make sense for websockets"""
def get_compression_options(self):
"""Get the compression options for the socket."""
# use deflate compress websocket
return {}
def authenticate(self):
"""Run before finishing the GET request
Extend this method to add logic that should fire before
the websocket finishes completing.
"""
# authenticate the request before opening the websocket
if self.current_user is None:
self.log.warning("Couldn't authenticate WebSocket connection")
raise web.HTTPError(403)
if self.get_argument("session_id", None):
assert self.session is not None
self.session.session = self.get_argument("session_id") # type:ignore[unreachable]
else:
self.log.warning("No session ID specified")
def initialize(self):
"""Initialize the socket."""
self.log.debug("Initializing websocket connection %s", self.request.path)
self.session = Session(config=self.config)
self.gateway = GatewayWebSocketClient(gateway_url=GatewayClient.instance().url)
async def get(self, kernel_id, *args, **kwargs):
"""Get the socket."""
self.authenticate()
self.kernel_id = kernel_id
kwargs["kernel_id"] = kernel_id
await super().get(*args, **kwargs)
def send_ping(self):
"""Send a ping to the socket."""
if self.ws_connection is None and self.ping_callback is not None:
self.ping_callback.stop() # type:ignore[unreachable]
return
self.ping(b"")
def open(self, kernel_id, *args, **kwargs):
"""Handle web socket connection open to notebook server and delegate to gateway web socket handler"""
self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000)
self.ping_callback.start()
assert self.gateway is not None
self.gateway.on_open(
kernel_id=kernel_id,
message_callback=self.write_message,
compression_options=self.get_compression_options(),
)
def on_message(self, message):
"""Forward message to gateway web socket handler."""
assert self.gateway is not None
self.gateway.on_message(message)
def write_message(self, message, binary=False):
"""Send message back to notebook client. This is called via callback from self.gateway._read_messages."""
if self.ws_connection: # prevent WebSocketClosedError
if isinstance(message, bytes):
binary = True
super().write_message(message, binary=binary)
elif self.log.isEnabledFor(logging.DEBUG):
msg_summary = WebSocketChannelsHandler._get_message_summary(json_decode(utf8(message)))
self.log.debug(
f"Notebook client closed websocket connection - message dropped: {msg_summary}"
)
def on_close(self):
"""Handle a closing socket."""
self.log.debug("Closing websocket connection %s", self.request.path)
assert self.gateway is not None
self.gateway.on_close()
super().on_close()
@staticmethod
def _get_message_summary(message):
"""Get a summary of a message."""
summary = []
message_type = message["msg_type"]
summary.append(f"type: {message_type}")
if message_type == "status":
summary.append(", state: {}".format(message["content"]["execution_state"]))
elif message_type == "error":
summary.append(
", {}:{}:{}".format(
message["content"]["ename"],
message["content"]["evalue"],
message["content"]["traceback"],
)
)
else:
summary.append(", ...") # don't display potentially sensitive data
return "".join(summary)
class GatewayWebSocketClient(LoggingConfigurable):
"""Proxy web socket connection to a kernel/enterprise gateway."""
def __init__(self, **kwargs):
"""Initialize the gateway web socket client."""
super().__init__()
self.kernel_id = None
self.ws = None
self.ws_future: Future[Any] = Future()
self.disconnected = False
self.retry = 0
async def _connect(self, kernel_id, message_callback):
"""Connect to the socket."""
# websocket is initialized before connection
self.ws = None
self.kernel_id = kernel_id
client = GatewayClient.instance()
assert client.ws_url is not None
ws_url = url_path_join(
client.ws_url,
client.kernels_endpoint,
url_escape(kernel_id),
"channels",
)
self.log.info(f"Connecting to {ws_url}")
kwargs: dict[str, Any] = {}
kwargs = client.load_connection_args(**kwargs)
request = HTTPRequest(ws_url, **kwargs)
self.ws_future = cast("Future[Any]", websocket_connect(request))
self.ws_future.add_done_callback(self._connection_done)
loop = IOLoop.current()
loop.add_future(self.ws_future, lambda future: self._read_messages(message_callback))
def _connection_done(self, fut):
"""Handle a finished connection."""
if (
not self.disconnected and fut.exception() is None
): # prevent concurrent.futures._base.CancelledError
self.ws = fut.result()
self.retry = 0
self.log.debug(f"Connection is ready: ws: {self.ws}")
else:
self.log.warning(
"Websocket connection has been closed via client disconnect or due to error. "
f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}"
)
def _disconnect(self):
"""Handle a disconnect."""
self.disconnected = True
if self.ws is not None:
# Close connection
self.ws.close()
elif not self.ws_future.done():
# Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally
self.ws_future.cancel()
self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")
async def _read_messages(self, callback):
"""Read messages from gateway server."""
while self.ws is not None:
message = None
if not self.disconnected:
try:
message = await self.ws.read_message()
except Exception as e:
self.log.error(
f"Exception reading message from websocket: {e}"
) # , exc_info=True)
if message is None:
if not self.disconnected:
self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
break
callback(
message
) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
else: # ws cancelled - stop reading
break
# NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max:
jitter = random.randint(10, 100) * 0.01 # noqa: S311
retry_interval = (
min(
GatewayClient.instance().gateway_retry_interval * (2**self.retry),
GatewayClient.instance().gateway_retry_interval_max,
)
+ jitter
)
self.retry += 1
self.log.info(
"Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
retry_interval,
self.retry,
GatewayClient.instance().gateway_retry_max,
self.kernel_id,
)
await asyncio.sleep(retry_interval)
loop = IOLoop.current()
loop.spawn_callback(self._connect, self.kernel_id, callback)
def on_open(self, kernel_id, message_callback, **kwargs):
"""Web socket connection open against gateway server."""
loop = IOLoop.current()
loop.spawn_callback(self._connect, kernel_id, message_callback)
def on_message(self, message):
"""Send message to gateway server."""
if self.ws is None:
loop = IOLoop.current()
loop.add_future(self.ws_future, lambda future: self._write_message(message))
else:
self._write_message(message)
def _write_message(self, message):
"""Send message to gateway server."""
try:
if not self.disconnected and self.ws is not None:
self.ws.write_message(message)
except Exception as e:
self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True)
def on_close(self):
"""Web socket closed event."""
self._disconnect()
class GatewayResourceHandler(APIHandler):
"""Retrieves resources for specific kernelspec definitions from kernel/enterprise gateway."""
@web.authenticated
async def get(self, kernel_name, path, include_body=True):
"""Get a gateway resource by name and path."""
mimetype: Optional[str] = None
ksm = self.kernel_spec_manager
kernel_spec_res = await ksm.get_kernel_spec_resource( # type:ignore[attr-defined]
kernel_name, path
)
if kernel_spec_res is None:
self.log.warning(
f"Kernelspec resource '{path}' for '{kernel_name}' not found. Gateway may not support"
" resource serving."
)
else:
mimetype = mimetypes.guess_type(path)[0] or "text/plain"
self.finish(kernel_spec_res, set_content_type=mimetype)
from ..services.kernels.handlers import _kernel_id_regex
from ..services.kernelspecs.handlers import kernel_name_regex
default_handlers = [
(r"/api/kernels/%s/channels" % _kernel_id_regex, WebSocketChannelsHandler),
(r"/kernelspecs/%s/(?P<path>.*)" % kernel_name_regex, GatewayResourceHandler),
]