"""Gateway connection classes.""" # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. from __future__ import annotations import asyncio import logging import random from typing import Any, cast import tornado.websocket as tornado_websocket from tornado.concurrent import Future from tornado.escape import json_decode, url_escape, utf8 from tornado.httpclient import HTTPRequest from tornado.ioloop import IOLoop from traitlets import Bool, Instance, Int, Unicode from ..services.kernels.connection.base import BaseKernelWebsocketConnection from ..utils import url_path_join from .gateway_client import GatewayClient class GatewayWebSocketConnection(BaseKernelWebsocketConnection): """Web socket connection that proxies to a kernel/enterprise gateway.""" ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True) ws_future = Instance(klass=Future, allow_none=True) disconnected = Bool(False) retry = Int(0) # When opening ws connection to gateway, server already negotiated subprotocol with notebook client. # Same protocol must be used for client and gateway, so legacy ws subprotocol for client is enforced here. kernel_ws_protocol = Unicode("", allow_none=True, config=True) async def connect(self): """Connect to the socket.""" # websocket is initialized before connection self.ws = None ws_url = url_path_join( GatewayClient.instance().ws_url or "", GatewayClient.instance().kernels_endpoint, url_escape(self.kernel_id), "channels", ) if self.session_id: ws_url += f"?session_id={url_escape(self.session_id)}" self.log.info(f"Connecting to {ws_url}") kwargs: dict[str, Any] = {} kwargs = GatewayClient.instance().load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) self.ws_future = cast("Future[Any]", tornado_websocket.websocket_connect(request)) self.ws_future.add_done_callback(self._connection_done) loop = IOLoop.current() loop.add_future(self.ws_future, lambda future: self._read_messages()) def _connection_done(self, fut): """Handle a finished connection.""" if ( not self.disconnected and fut.exception() is None ): # prevent concurrent.futures._base.CancelledError self.ws = fut.result() self.retry = 0 self.log.debug(f"Connection is ready: ws: {self.ws}") else: self.log.warning( "Websocket connection has been closed via client disconnect or due to error. " f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}" ) def disconnect(self): """Handle a disconnect.""" self.disconnected = True if self.ws is not None: # Close connection self.ws.close() elif self.ws_future and not self.ws_future.done(): # Cancel pending connection. Since future.cancel() is a noop on tornado, we'll track cancellation locally self.ws_future.cancel() self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}") async def _read_messages(self): """Read messages from gateway server.""" while self.ws is not None: message = None if not self.disconnected: try: message = await self.ws.read_message() except Exception as e: self.log.error( f"Exception reading message from websocket: {e}" ) # , exc_info=True) if message is None: if not self.disconnected: self.log.warning(f"Lost connection to Gateway: {self.kernel_id}") break if isinstance(message, bytes): message = message.decode("utf8") self.handle_outgoing_message( message ) # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open) else: # ws cancelled - stop reading break # NOTE(esevan): if websocket is not disconnected by client, try to reconnect. if not self.disconnected and self.retry < GatewayClient.instance().gateway_retry_max: jitter = random.randint(10, 100) * 0.01 # noqa: S311 retry_interval = ( min( GatewayClient.instance().gateway_retry_interval * (2**self.retry), GatewayClient.instance().gateway_retry_interval_max, ) + jitter ) self.retry += 1 self.log.info( "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s", retry_interval, self.retry, GatewayClient.instance().gateway_retry_max, self.kernel_id, ) await asyncio.sleep(retry_interval) loop = IOLoop.current() loop.spawn_callback(self.connect) def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None: """Send message to the notebook client.""" try: self.websocket_handler.write_message(incoming_msg) except tornado_websocket.WebSocketClosedError: if self.log.isEnabledFor(logging.DEBUG): msg_summary = GatewayWebSocketConnection._get_message_summary( json_decode(utf8(incoming_msg)) ) self.log.debug( f"Notebook client closed websocket connection - message dropped: {msg_summary}" ) def handle_incoming_message(self, message: str) -> None: """Send message to gateway server.""" if self.ws is None and self.ws_future is not None: loop = IOLoop.current() loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message)) else: self._write_message(message) def _write_message(self, message): """Send message to gateway server.""" try: if not self.disconnected and self.ws is not None: self.ws.write_message(message) except Exception as e: self.log.error(f"Exception writing message to websocket: {e}") # , exc_info=True) @staticmethod def _get_message_summary(message): """Get a summary of a message.""" summary = [] message_type = message["msg_type"] summary.append(f"type: {message_type}") if message_type == "status": summary.append(", state: {}".format(message["content"]["execution_state"])) elif message_type == "error": summary.append( ", {}:{}:{}".format( message["content"]["ename"], message["content"]["evalue"], message["content"]["traceback"], ) ) else: summary.append(", ...") # don't display potentially sensitive data return "".join(summary)