File size: 7,361 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
"""Gateway connection classes."""
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from __future__ import annotations
import asyncio
import logging
import random
from typing import Any, cast
import tornado.websocket as tornado_websocket
from tornado.concurrent import Future
from tornado.escape import json_decode, url_escape, utf8
from tornado.httpclient import HTTPRequest
from tornado.ioloop import IOLoop
from traitlets import Bool, Instance, Int, Unicode
from ..services.kernels.connection.base import BaseKernelWebsocketConnection
from ..utils import url_path_join
from .gateway_client import GatewayClient
class GatewayWebSocketConnection(BaseKernelWebsocketConnection):
    """Web socket connection that proxies to a kernel/enterprise gateway."""

    # Established connection to the gateway's kernel "channels" endpoint (None until connected).
    ws = Instance(klass=tornado_websocket.WebSocketClientConnection, allow_none=True)
    # Future returned by the most recent websocket_connect() call.
    ws_future = Instance(klass=Future, allow_none=True)
    # Set by disconnect() when the notebook client goes away; nothing in this class resets it.
    disconnected = Bool(False)
    # Reconnect attempts since the last successful connection (reset in _connection_done).
    retry = Int(0)

    # When opening ws connection to gateway, server already negotiated subprotocol with notebook client.
    # Same protocol must be used for client and gateway, so legacy ws subprotocol for client is enforced here.
    kernel_ws_protocol = Unicode("", allow_none=True, config=True)

    async def connect(self):
        """Open (or re-open) the websocket to the gateway's kernel channels endpoint.

        The connection is established asynchronously: ``self.ws_future`` tracks it,
        ``_connection_done`` stores the resulting socket in ``self.ws``, and
        ``_read_messages`` is scheduled to start relaying once the future completes.
        """
        # websocket is initialized before connection
        self.ws = None
        gateway_client = GatewayClient.instance()
        ws_url = url_path_join(
            gateway_client.ws_url or "",
            gateway_client.kernels_endpoint,
            url_escape(self.kernel_id),
            "channels",
        )
        if self.session_id:
            ws_url += f"?session_id={url_escape(self.session_id)}"
        self.log.info(f"Connecting to {ws_url}")
        kwargs: dict[str, Any] = gateway_client.load_connection_args()
        request = HTTPRequest(ws_url, **kwargs)
        self.ws_future = cast("Future[Any]", tornado_websocket.websocket_connect(request))
        # Registration order matters: _connection_done must run (and set self.ws)
        # before the _read_messages callback below is dispatched.
        self.ws_future.add_done_callback(self._connection_done)
        loop = IOLoop.current()
        loop.add_future(self.ws_future, lambda future: self._read_messages())

    def _connection_done(self, fut):
        """Handle completion of the pending gateway connection future.

        On success, and only if the client is still connected, the socket is
        stored in ``self.ws`` and the retry counter is reset. A socket that
        completes after the client has disconnected is closed, not leaked.
        """
        # Check cancelled() first: exception() raises CancelledError on a cancelled future.
        succeeded = not fut.cancelled() and fut.exception() is None
        if succeeded and not self.disconnected:
            self.ws = fut.result()
            self.retry = 0
            self.log.debug(f"Connection is ready: ws: {self.ws}")
            return
        if succeeded:
            # The client went away while we were connecting: nobody will read from or
            # close this gateway connection, so close it now.
            fut.result().close()
        self.log.warning(
            "Websocket connection has been closed via client disconnect or due to error. "
            f"Kernel with ID '{self.kernel_id}' may not be terminated on GatewayClient: {GatewayClient.instance().url}"
        )

    def disconnect(self):
        """Handle a disconnect of the notebook client."""
        self.disconnected = True
        if self.ws is not None:
            # Close connection
            self.ws.close()
        elif self.ws_future is not None and not self.ws_future.done():
            # Cancel the pending connection. The disconnected flag set above is what
            # _connection_done and _read_messages consult, so the cancellation is
            # honoured even if the connect attempt completes anyway.
            self.ws_future.cancel()
            self.log.debug(f"_disconnect: future cancelled, disconnected: {self.disconnected}")

    async def _read_messages(self):
        """Relay messages from the gateway to the notebook client until the socket ends.

        If the connection is lost (or was never established) while the client is
        still connected, schedule a reconnect. The delay uses exponential back-off,
        capped at ``gateway_retry_interval_max``, plus 0.1-1.0 s of jitter.
        """
        while self.ws is not None:
            if self.disconnected:  # ws cancelled - stop reading
                break
            message = None
            try:
                message = await self.ws.read_message()
            except Exception as e:
                self.log.error(f"Exception reading message from websocket: {e}")
            if message is None:
                if not self.disconnected:
                    self.log.warning(f"Lost connection to Gateway: {self.kernel_id}")
                break
            if isinstance(message, bytes):
                message = message.decode("utf8")
            # pass back to notebook client (see self.on_open and WebSocketChannelsHandler.open)
            self.handle_outgoing_message(message)

        # NOTE(esevan): if websocket is not disconnected by client, try to reconnect.
        gateway_client = GatewayClient.instance()
        if not self.disconnected and self.retry < gateway_client.gateway_retry_max:
            jitter = random.randint(10, 100) * 0.01  # noqa: S311
            backoff = gateway_client.gateway_retry_interval * (2**self.retry)
            retry_interval = min(backoff, gateway_client.gateway_retry_interval_max) + jitter
            self.retry += 1
            self.log.info(
                "Attempting to re-establish the connection to Gateway in %s secs (%s/%s): %s",
                retry_interval,
                self.retry,
                gateway_client.gateway_retry_max,
                self.kernel_id,
            )
            await asyncio.sleep(retry_interval)
            if self.disconnected:
                # The client went away during the back-off; a new gateway connection
                # would never be read from or closed.
                return
            loop = IOLoop.current()
            loop.spawn_callback(self.connect)

    def handle_outgoing_message(self, incoming_msg: str, *args: Any) -> None:
        """Send message to the notebook client."""
        try:
            self.websocket_handler.write_message(incoming_msg)
        except tornado_websocket.WebSocketClosedError:
            if self.log.isEnabledFor(logging.DEBUG):
                msg_summary = GatewayWebSocketConnection._get_message_summary(
                    json_decode(utf8(incoming_msg))
                )
                self.log.debug(
                    f"Notebook client closed websocket connection - message dropped: {msg_summary}"
                )

    def handle_incoming_message(self, message: str) -> None:
        """Send a message from the notebook client to the gateway server.

        While the gateway connection is still pending, the message is re-dispatched
        once ``self.ws_future`` completes.
        """
        if self.disconnected:
            # _write_message drops everything once the client has disconnected, and
            # the flag is never reset. Deferring on an already-completed (e.g.
            # cancelled) ws_future would re-invoke this method forever, so drop now.
            return
        if self.ws is None and self.ws_future is not None:
            loop = IOLoop.current()
            loop.add_future(self.ws_future, lambda future: self.handle_incoming_message(message))
        else:
            self._write_message(message)

    def _write_message(self, message):
        """Send message to gateway server (best-effort; failures are logged)."""
        try:
            if not self.disconnected and self.ws is not None:
                self.ws.write_message(message)
        except Exception as e:
            self.log.error(f"Exception writing message to websocket: {e}")

    @staticmethod
    def _get_message_summary(message):
        """Return a short, log-safe summary of a decoded kernel message.

        Only ``status`` and ``error`` messages include content details; other
        types are elided to avoid logging potentially sensitive data.
        """
        message_type = message["msg_type"]
        summary = [f"type: {message_type}"]
        if message_type == "status":
            summary.append(f", state: {message['content']['execution_state']}")
        elif message_type == "error":
            content = message["content"]
            summary.append(f", {content['ename']}:{content['evalue']}:{content['traceback']}")
        else:
            summary.append(", ...")  # don't display potentially sensitive data
        return "".join(summary)
|