|
"""Implementation of the WebSocket protocol. |
|
|
|
`WebSockets <http://dev.w3.org/html5/websockets/>`_ allow for bidirectional |
|
communication between the browser and server. WebSockets are supported in the |
|
current versions of all major browsers. |
|
|
|
This module implements the final version of the WebSocket protocol as |
|
defined in `RFC 6455 <http://tools.ietf.org/html/rfc6455>`_. |
|
|
|
.. versionchanged:: 4.0 |
|
Removed support for the draft 76 protocol version. |
|
""" |
|
|
|
import abc |
|
import asyncio |
|
import base64 |
|
import hashlib |
|
import os |
|
import sys |
|
import struct |
|
import tornado |
|
from urllib.parse import urlparse |
|
import warnings |
|
import zlib |
|
|
|
from tornado.concurrent import Future, future_set_result_unless_cancelled |
|
from tornado.escape import utf8, native_str, to_unicode |
|
from tornado import gen, httpclient, httputil |
|
from tornado.ioloop import IOLoop, PeriodicCallback |
|
from tornado.iostream import StreamClosedError, IOStream |
|
from tornado.log import gen_log, app_log |
|
from tornado.netutil import Resolver |
|
from tornado import simple_httpclient |
|
from tornado.queues import Queue |
|
from tornado.tcpclient import TCPClient |
|
from tornado.util import _websocket_mask |
|
|
|
from typing import ( |
|
TYPE_CHECKING, |
|
cast, |
|
Any, |
|
Optional, |
|
Dict, |
|
Union, |
|
List, |
|
Awaitable, |
|
Callable, |
|
Tuple, |
|
Type, |
|
) |
|
from types import TracebackType |
|
|
|
if TYPE_CHECKING:
    # Everything in this branch exists only for static type checkers;
    # none of these names are defined at runtime.
    from typing_extensions import Protocol

    # zlib does not publicly expose the types of the objects returned by
    # compressobj()/decompressobj(), so describe just the parts this
    # module uses.
    class _Compressor(Protocol):
        def compress(self, data: bytes) -> bytes:
            ...

        def flush(self, mode: int) -> bytes:
            ...

    class _Decompressor(Protocol):
        unconsumed_tail: bytes = b""

        def decompress(self, data: bytes, max_length: int) -> bytes:
            ...

    class _WebSocketDelegate(Protocol):
        # Callbacks a WebSocketProtocol invokes on the object that owns it.
        # WebSocketHandler is one implementation: it passes itself to
        # WebSocketProtocol13 in get_websocket_protocol().
        def on_ws_connection_close(
            self, close_code: Optional[int] = None, close_reason: Optional[str] = None
        ) -> None:
            ...

        def on_message(self, message: Union[str, bytes]) -> Optional["Awaitable[None]"]:
            ...

        def on_ping(self, data: bytes) -> None:
            ...

        def on_pong(self, data: bytes) -> None:
            ...

        def log_exception(
            self,
            typ: Optional[Type[BaseException]],
            value: Optional[BaseException],
            tb: Optional[TracebackType],
        ) -> None:
            ...
|
|
|
|
|
# Default cap on a single incoming message: 10 MiB (documented on
# WebSocketHandler.max_message_size).
_default_max_message_size = 10 << 20
|
|
|
|
|
class WebSocketError(Exception):
    """Base class for the public WebSocket errors, such as WebSocketClosedError."""
|
|
|
|
|
class WebSocketClosedError(WebSocketError):
    """Raised by operations on a closed connection.

    For example, `WebSocketHandler.write_message` and `WebSocketHandler.ping`
    raise it when there is no open connection.

    .. versionadded:: 3.2
    """
|
|
|
|
|
class _DecompressTooLargeError(Exception):
    """Internal signal that a decompressed message exceeded the size limit."""
|
|
|
|
|
class _WebSocketParams:
    """Plain holder for the per-connection settings given to WebSocketProtocol13."""

    def __init__(
        self,
        ping_interval: Optional[float] = None,
        ping_timeout: Optional[float] = None,
        max_message_size: int = _default_max_message_size,
        compression_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        # Keep-alive settings, in seconds (see WebSocketHandler.ping_interval
        # and WebSocketHandler.ping_timeout).
        self.ping_interval = ping_interval
        self.ping_timeout = ping_timeout
        # Largest message accepted from the peer, in bytes.
        self.max_message_size = max_message_size
        # None disables permessage-deflate; a dict (possibly empty) enables it.
        self.compression_options = compression_options
|
|
|
|
|
class WebSocketHandler(tornado.web.RequestHandler):
    """Subclass this class to create a basic WebSocket handler.

    Override `on_message` to handle incoming messages, and use
    `write_message` to send messages to the client. You can also
    override `open` and `on_close` to handle opened and closed
    connections.

    Custom upgrade response headers can be sent by overriding
    `~tornado.web.RequestHandler.set_default_headers` or
    `~tornado.web.RequestHandler.prepare`.

    See http://dev.w3.org/html5/websockets/ for details on the
    JavaScript interface.  The protocol is specified at
    http://tools.ietf.org/html/rfc6455.

    Here is an example WebSocket handler that echoes all received messages
    back to the client:

    .. testcode::

      class EchoWebSocket(tornado.websocket.WebSocketHandler):
          def open(self):
              print("WebSocket opened")

          def on_message(self, message):
              self.write_message(u"You said: " + message)

          def on_close(self):
              print("WebSocket closed")

    .. testoutput::
       :hide:

    WebSockets are not standard HTTP connections. The "handshake" is
    HTTP, but after the handshake, the protocol is
    message-based. Consequently, most of the Tornado HTTP facilities
    are not available in handlers of this type. The only communication
    methods available to you are `write_message()`, `ping()`, and
    `close()`. Likewise, your request handler class should implement
    `open()` method rather than ``get()`` or ``post()``.

    If you map the handler above to ``/websocket`` in your application, you can
    invoke it in JavaScript with::

      var ws = new WebSocket("ws://localhost:8888/websocket");
      ws.onopen = function() {
         ws.send("Hello, world");
      };
      ws.onmessage = function (evt) {
         alert(evt.data);
      };

    This script pops up an alert box that says "You said: Hello, world".

    Web browsers allow any site to open a websocket connection to any other,
    instead of using the same-origin policy that governs other network
    access from JavaScript.  This can be surprising and is a potential
    security hole, so since Tornado 4.0 `WebSocketHandler` requires
    applications that wish to receive cross-origin websockets to opt in
    by overriding the `~WebSocketHandler.check_origin` method (see that
    method's docs for details).  Failure to do so is the most likely
    cause of 403 errors when making a websocket connection.

    When using a secure websocket connection (``wss://``) with a self-signed
    certificate, the connection from a browser may fail because it wants
    to show the "accept this certificate" dialog but has nowhere to show it.
    You must first visit a regular HTML page using the same certificate
    to accept it before the websocket connection will succeed.

    If the application setting ``websocket_ping_interval`` has a non-zero
    value, a ping will be sent periodically, and the connection will be
    closed if a response is not received before the ``websocket_ping_timeout``.

    Messages larger than the ``websocket_max_message_size`` application setting
    (default 10MiB) will not be accepted.

    .. versionchanged:: 4.5
       Added ``websocket_ping_interval``, ``websocket_ping_timeout``, and
       ``websocket_max_message_size``.
    """

    def __init__(
        self,
        application: tornado.web.Application,
        request: httputil.HTTPServerRequest,
        **kwargs: Any
    ) -> None:
        super().__init__(application, request, **kwargs)
        # Protocol object for the upgraded connection; None before the
        # handshake and again after close.
        self.ws_connection: Optional["WebSocketProtocol"] = None
        # Filled in by on_ws_connection_close when the peer sends them.
        self.close_code: Optional[int] = None
        self.close_reason: Optional[str] = None
        # Guards against calling on_close() more than once.
        self._on_close_called = False

    async def get(self, *args: Any, **kwargs: Any) -> None:
        # Remembered so they can be passed to open() after the handshake.
        self.open_args = args
        self.open_kwargs = kwargs

        # The Upgrade header must be present and equal to "websocket"
        # (case-insensitive).
        if self.request.headers.get("Upgrade", "").lower() != "websocket":
            self.set_status(400)
            log_msg = 'Can "Upgrade" only to "WebSocket".'
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        # The Connection header is a comma-separated token list. Proxies and
        # load balancers may add extra tokens, so only require that
        # "upgrade" is one of them.
        headers = self.request.headers
        connection_tokens = {
            token.strip().lower() for token in headers.get("Connection", "").split(",")
        }
        if "upgrade" not in connection_tokens:
            self.set_status(400)
            log_msg = '"Connection" must be "Upgrade".'
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        # Protocol version 8 clients send "Sec-Websocket-Origin"; version 13
        # clients send plain "Origin".
        if "Origin" in self.request.headers:
            origin = self.request.headers.get("Origin")
        else:
            origin = self.request.headers.get("Sec-Websocket-Origin", None)

        # With no origin header the request is assumed not to come from a
        # browser and is let through; otherwise check_origin decides.
        if origin is not None and not self.check_origin(origin):
            self.set_status(403)
            log_msg = "Cross origin websockets not allowed"
            self.finish(log_msg)
            gen_log.debug(log_msg)
            return

        self.ws_connection = self.get_websocket_protocol()
        if self.ws_connection:
            await self.ws_connection.accept_connection(self)
        else:
            # Unsupported Sec-WebSocket-Version: advertise the ones we speak.
            self.set_status(426, "Upgrade Required")
            self.set_header("Sec-WebSocket-Version", "7, 8, 13")

    @property
    def ping_interval(self) -> Optional[float]:
        """The interval for websocket keep-alive pings.

        Set websocket_ping_interval = 0 to disable pings.
        """
        return self.settings.get("websocket_ping_interval", None)

    @property
    def ping_timeout(self) -> Optional[float]:
        """If no ping is received in this many seconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get("websocket_ping_timeout", None)

    @property
    def max_message_size(self) -> int:
        """Maximum allowed message size.

        If the remote peer sends a message larger than this, the connection
        will be closed.

        Default is 10MiB.
        """
        return self.settings.get(
            "websocket_max_message_size", _default_max_message_size
        )

    def write_message(
        self, message: Union[bytes, str, Dict[str, Any]], binary: bool = False
    ) -> "Future[None]":
        """Sends the given message to the client of this Web Socket.

        The message may be a string, a byte string, or a dict (which will be
        encoded as json).  If the ``binary`` argument is false, the
        message will be sent as utf8; in binary mode any byte string
        is allowed.

        If the connection is already closed, raises `WebSocketClosedError`.
        Returns a `.Future` which can be used for flow control.

        .. versionchanged:: 3.2
           `WebSocketClosedError` was added (previously a closed connection
           would raise an `AttributeError`)

        .. versionchanged:: 4.3
           Returns a `.Future` which can be used for flow control.

        .. versionchanged:: 5.0
           Consistently raises `WebSocketClosedError`. Previously could
           sometimes raise `.StreamClosedError`.
        """
        if self.ws_connection is None or self.ws_connection.is_closing():
            raise WebSocketClosedError()
        if isinstance(message, dict):
            message = tornado.escape.json_encode(message)
        return self.ws_connection.write_message(message, binary=binary)

    def select_subprotocol(self, subprotocols: List[str]) -> Optional[str]:
        """Override to implement subprotocol negotiation.

        ``subprotocols`` is a list of strings identifying the
        subprotocols proposed by the client.  This method may be
        overridden to return one of those strings to select it, or
        ``None`` to not select a subprotocol.

        Failure to select a subprotocol does not automatically abort
        the connection, although clients may close the connection if
        none of their proposed subprotocols was selected.

        The list may be empty, in which case this method must return
        None. This method is always called exactly once even if no
        subprotocols were proposed so that the handler can be advised
        of this fact.

        .. versionchanged:: 5.1

           Previously, this method was called with a list containing
           an empty string instead of an empty list if no subprotocols
           were proposed by the client.
        """
        return None

    @property
    def selected_subprotocol(self) -> Optional[str]:
        """The subprotocol returned by `select_subprotocol`.

        .. versionadded:: 5.1
        """
        assert self.ws_connection is not None
        return self.ws_connection.selected_subprotocol

    def get_compression_options(self) -> Optional[Dict[str, Any]]:
        """Override to return compression options for the connection.

        If this method returns None (the default), compression will
        be disabled.  If it returns a dict (even an empty one), it
        will be enabled.  The contents of the dict may be used to
        control the following compression options:

        ``compression_level`` specifies the compression level.

        ``mem_level`` specifies the amount of memory used for the internal compression state.

         These parameters are documented in detail here:
         https://docs.python.org/3.6/library/zlib.html#zlib.compressobj

        .. versionadded:: 4.1

        .. versionchanged:: 4.5

           Added ``compression_level`` and ``mem_level``.
        """
        # TODO: Add wbits option.
        return None

    def open(self, *args: str, **kwargs: str) -> Optional[Awaitable[None]]:
        """Invoked when a new WebSocket is opened.

        The arguments to `open` are extracted from the `tornado.web.URLSpec`
        regular expression, just like the arguments to
        `tornado.web.RequestHandler.get`.

        `open` may be a coroutine. `on_message` will not be called until
        `open` has returned.

        .. versionchanged:: 5.1

           ``open`` may be a coroutine.
        """
        pass

    def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
        """Handle incoming messages on the WebSocket

        This method must be overridden.

        .. versionchanged:: 4.5

           ``on_message`` can be a coroutine.
        """
        raise NotImplementedError

    def ping(self, data: Union[str, bytes] = b"") -> None:
        """Send ping frame to the remote end.

        The data argument allows a small amount of data (up to 125
        bytes) to be sent as a part of the ping message. Note that not
        all websocket implementations expose this data to
        applications.

        Consider using the ``websocket_ping_interval`` application
        setting instead of sending pings manually.

        .. versionchanged:: 5.1

           The data argument is now optional.

        """
        data = utf8(data)
        if self.ws_connection is None or self.ws_connection.is_closing():
            raise WebSocketClosedError()
        self.ws_connection.write_ping(data)

    def on_pong(self, data: bytes) -> None:
        """Invoked when the response to a ping frame is received."""
        pass

    def on_ping(self, data: bytes) -> None:
        """Invoked when a ping frame is received."""
        pass

    def on_close(self) -> None:
        """Invoked when the WebSocket is closed.

        If the connection was closed cleanly and a status code or reason
        phrase was supplied, these values will be available as the attributes
        ``self.close_code`` and ``self.close_reason``.

        .. versionchanged:: 4.0

           Added ``close_code`` and ``close_reason`` attributes.
        """
        pass

    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Closes this Web Socket.

        Once the close handshake is successful the socket will be closed.

        ``code`` may be a numeric status code, taken from the values
        defined in `RFC 6455 section 7.4.1
        <https://tools.ietf.org/html/rfc6455#section-7.4.1>`_.
        ``reason`` may be a textual message about why the connection is
        closing.  These values are made available to the client, but are
        not otherwise interpreted by the websocket protocol.

        .. versionchanged:: 4.0

           Added the ``code`` and ``reason`` arguments.
        """
        if self.ws_connection:
            self.ws_connection.close(code, reason)
            self.ws_connection = None

    def check_origin(self, origin: str) -> bool:
        """Override to enable support for allowing alternate origins.

        The ``origin`` argument is the value of the ``Origin`` HTTP
        header, the url responsible for initiating this request.  This
        method is not called for clients that do not send this header;
        such requests are always allowed (because all browsers that
        implement WebSockets support this header, and non-browser
        clients do not have the same cross-site security concerns).

        Should return ``True`` to accept the request or ``False`` to
        reject it. By default, rejects all requests with an origin on
        a host other than this one.  Host names are compared
        case-insensitively; ports must match exactly.

        This is a security protection against cross site scripting attacks on
        browsers, since WebSockets are allowed to bypass the usual same-origin
        policies and don't use CORS headers.

        .. warning::

           This is an important security measure; don't disable it
           without understanding the security implications. In
           particular, if your authentication is cookie-based, you
           must either restrict the origins allowed by
           ``check_origin()`` or implement your own XSRF-like
           protection for websocket connections. See `these
           <https://www.christian-schneider.net/CrossSiteWebSocketHijacking.html>`_
           `articles
           <https://devcenter.heroku.com/articles/websocket-security>`_
           for more.

        To accept all cross-origin traffic (which was the default prior to
        Tornado 4.0), simply override this method to always return ``True``::

            def check_origin(self, origin):
                return True

        To allow connections from any subdomain of your site, you might
        do something like::

            def check_origin(self, origin):
                parsed_origin = urllib.parse.urlparse(origin)
                return parsed_origin.netloc.endswith(".mydomain.com")

        .. versionadded:: 4.0

        """
        origin_netloc = urlparse(origin).netloc.lower()

        host = self.request.headers.get("Host")
        if host is None:
            # Nothing to compare against: reject, as the previous
            # ``origin == None`` comparison did.
            return False

        # The origin must match the host directly, including the port. Host
        # names are case-insensitive, so normalize both sides; otherwise a
        # client sending e.g. "Host: Example.com" would be rejected even
        # though its origin "http://example.com" names the same host.
        return origin_netloc == host.lower()

    def set_nodelay(self, value: bool) -> None:
        """Set the no-delay flag for this stream.

        By default, small messages may be delayed and/or combined to minimize
        the number of packets sent.  This can sometimes cause 200-500ms delays
        due to the interaction between Nagle's algorithm and TCP delayed
        ACKs.  To reduce this delay (at the expense of possibly increasing
        bandwidth usage), call ``self.set_nodelay(True)`` once the websocket
        connection is established.

        See `.BaseIOStream.set_nodelay` for additional details.

        .. versionadded:: 3.1
        """
        assert self.ws_connection is not None
        self.ws_connection.set_nodelay(value)

    def on_connection_close(self) -> None:
        if self.ws_connection:
            self.ws_connection.on_connection_close()
            self.ws_connection = None
        # on_close() must run at most once however many close paths fire.
        if not self._on_close_called:
            self._on_close_called = True
            self.on_close()
            self._break_cycles()

    def on_ws_connection_close(
        self, close_code: Optional[int] = None, close_reason: Optional[str] = None
    ) -> None:
        """Record the peer's close code/reason, then run the close handling."""
        self.close_code = close_code
        self.close_reason = close_reason
        self.on_connection_close()

    def _break_cycles(self) -> None:
        # The handshake calls finish() early (before the websocket is done),
        # so for an upgraded connection (status 101) only break reference
        # cycles once on_close() has actually run.
        if self.get_status() != 101 or self._on_close_called:
            super()._break_cycles()

    def get_websocket_protocol(self) -> Optional["WebSocketProtocol"]:
        """Return a protocol for the requested Sec-WebSocket-Version.

        Versions 7, 8 and 13 are served by `WebSocketProtocol13`; any other
        value yields None.
        """
        websocket_version = self.request.headers.get("Sec-WebSocket-Version")
        if websocket_version in ("7", "8", "13"):
            params = _WebSocketParams(
                ping_interval=self.ping_interval,
                ping_timeout=self.ping_timeout,
                max_message_size=self.max_message_size,
                compression_options=self.get_compression_options(),
            )
            return WebSocketProtocol13(self, False, params)
        return None

    def _detach_stream(self) -> IOStream:
        # Once upgraded, the plain-HTTP output methods no longer make sense;
        # replace them with a stub that raises RuntimeError.
        for method in [
            "write",
            "redirect",
            "set_header",
            "set_cookie",
            "set_status",
            "flush",
            "finish",
        ]:
            setattr(self, method, _raise_not_supported_for_websockets)
        return self.detach()
|
|
|
|
|
def _raise_not_supported_for_websockets(*args: Any, **kwargs: Any) -> None:
    """Stub installed over plain-HTTP handler methods after the upgrade.

    Accepts and ignores any arguments so it can stand in for methods of any
    signature, and always raises RuntimeError.
    """
    message = "Method not supported for Web Sockets"
    raise RuntimeError(message)
|
|
|
|
|
class WebSocketProtocol(abc.ABC):
    """Base class for WebSocket protocol versions."""

    def __init__(self, handler: "_WebSocketDelegate") -> None:
        # Receives protocol callbacks and exception logging.
        self.handler = handler
        # Connection stream; assigned once the connection is established.
        self.stream = None  # type: Optional[IOStream]
        self.client_terminated = False
        self.server_terminated = False

    def _run_callback(
        self, callback: Callable, *args: Any, **kwargs: Any
    ) -> "Optional[Future[Any]]":
        """Invoke ``callback(*args, **kwargs)`` with exception handling.

        If the call raises, the exception is logged through the handler, the
        connection is aborted, and None is returned.  If it returns an
        awaitable, that is converted to a Future and returned.  The Future is
        also given a done-callback on the stream's IOLoop that calls
        ``f.result()``, so a failure is re-raised rather than left unobserved.
        A plain None result is returned unchanged.
        """
        try:
            outcome = callback(*args, **kwargs)
        except Exception:
            self.handler.log_exception(*sys.exc_info())
            self._abort()
            return None
        if outcome is None:
            return None
        future = gen.convert_yielded(outcome)
        assert self.stream is not None
        self.stream.io_loop.add_future(future, lambda f: f.result())
        return future

    def on_connection_close(self) -> None:
        self._abort()

    def _abort(self) -> None:
        """Instantly aborts the WebSocket connection by closing the socket"""
        self.client_terminated = True
        self.server_terminated = True
        if self.stream is not None:
            # A closed stream is a no-op for the second close below.
            self.stream.close()
        self.close()

    @abc.abstractmethod
    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Start (or finish) closing the connection."""
        raise NotImplementedError

    @abc.abstractmethod
    def is_closing(self) -> bool:
        """Return True once the connection is closing or closed."""
        raise NotImplementedError

    @abc.abstractmethod
    async def accept_connection(self, handler: WebSocketHandler) -> None:
        """Run the server side of the handshake and then the connection."""
        raise NotImplementedError

    @abc.abstractmethod
    def write_message(
        self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
    ) -> "Future[None]":
        """Send one message to the peer."""
        raise NotImplementedError

    @property
    @abc.abstractmethod
    def selected_subprotocol(self) -> Optional[str]:
        """The negotiated subprotocol, if any."""
        raise NotImplementedError

    @abc.abstractmethod
    def write_ping(self, data: bytes) -> None:
        """Send a ping frame."""
        raise NotImplementedError

    @abc.abstractmethod
    def _process_server_headers(
        self, key: Union[str, bytes], headers: httputil.HTTPHeaders
    ) -> None:
        """Validate a server's handshake response (client side)."""
        raise NotImplementedError

    @abc.abstractmethod
    def start_pinging(self) -> None:
        """Begin periodic keep-alive pings."""
        raise NotImplementedError

    @abc.abstractmethod
    async def _receive_frame_loop(self) -> None:
        """Read and dispatch incoming frames until the connection ends."""
        raise NotImplementedError

    @abc.abstractmethod
    def set_nodelay(self, x: bool) -> None:
        """Set the no-delay flag on the underlying stream."""
        raise NotImplementedError
|
|
|
|
|
class _PerMessageDeflateCompressor(object):
    """Compresses outgoing message payloads for the permessage-deflate extension.

    :param persistent: when true, one zlib compressor is kept for the
        lifetime of the connection (context takeover); otherwise a fresh one
        is created for every message.
    :param max_wbits: window size exponent (8..``zlib.MAX_WBITS``); None
        means ``zlib.MAX_WBITS``.
    :param compression_options: optional dict with ``compression_level`` and
        ``mem_level`` keys.
    :raises ValueError: if ``max_wbits`` is out of range.
    """

    def __init__(
        self,
        persistent: bool,
        max_wbits: Optional[int],
        compression_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        if max_wbits is None:
            max_wbits = zlib.MAX_WBITS
        # NOTE(review): newer zlib releases reject windowBits 8 for raw
        # deflate streams; confirm whether 8 should stay accepted here.
        if not (8 <= max_wbits <= zlib.MAX_WBITS):
            # Format the message eagerly; passing the values as extra
            # ValueError args left the %r/%d placeholders unfilled.
            raise ValueError(
                "Invalid max_wbits value %r; allowed range 8-%d"
                % (max_wbits, zlib.MAX_WBITS)
            )
        self._max_wbits = max_wbits

        options = compression_options or {}
        if "compression_level" in options:
            self._compression_level = options["compression_level"]
        else:
            self._compression_level = tornado.web.GZipContentEncoding.GZIP_LEVEL
        self._mem_level = options.get("mem_level", 8)

        if persistent:
            self._compressor = self._create_compressor()  # type: Optional[_Compressor]
        else:
            self._compressor = None

    def _create_compressor(self) -> "_Compressor":
        # A negative wbits selects a raw deflate stream (no zlib header).
        return zlib.compressobj(
            self._compression_level, zlib.DEFLATED, -self._max_wbits, self._mem_level
        )

    def compress(self, data: bytes) -> bytes:
        """Compress one message payload, without the trailing sync marker."""
        compressor = self._compressor or self._create_compressor()
        data = compressor.compress(data) + compressor.flush(zlib.Z_SYNC_FLUSH)
        # A sync flush ends with the empty block 00 00 ff ff, which is not
        # sent on the wire (the decompressor appends it again).
        assert data.endswith(b"\x00\x00\xff\xff")
        return data[:-4]
|
|
|
|
|
class _PerMessageDeflateDecompressor(object):
    """Decompresses incoming permessage-deflate payloads with a size cap.

    :param persistent: keep one zlib decompressor for the whole connection
        (context takeover) instead of one per message.
    :param max_wbits: window size exponent (8..``zlib.MAX_WBITS``); None
        means ``zlib.MAX_WBITS``.
    :param max_message_size: maximum decompressed size of one message.
    :param compression_options: accepted for signature symmetry with the
        compressor; not used here.
    :raises ValueError: if ``max_wbits`` is out of range.
    """

    def __init__(
        self,
        persistent: bool,
        max_wbits: Optional[int],
        max_message_size: int,
        compression_options: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._max_message_size = max_message_size
        if max_wbits is None:
            max_wbits = zlib.MAX_WBITS
        if not (8 <= max_wbits <= zlib.MAX_WBITS):
            # Format the message eagerly; passing the values as extra
            # ValueError args left the %r/%d placeholders unfilled.
            raise ValueError(
                "Invalid max_wbits value %r; allowed range 8-%d"
                % (max_wbits, zlib.MAX_WBITS)
            )
        self._max_wbits = max_wbits
        if persistent:
            self._decompressor = (
                self._create_decompressor()
            )  # type: Optional[_Decompressor]
        else:
            self._decompressor = None

    def _create_decompressor(self) -> "_Decompressor":
        # A negative wbits expects a raw deflate stream (no zlib header).
        return zlib.decompressobj(-self._max_wbits)

    def decompress(self, data: bytes) -> bytes:
        """Decompress one message payload.

        :raises _DecompressTooLargeError: if the output would exceed
            ``max_message_size``.
        """
        decompressor = self._decompressor or self._create_decompressor()
        # Re-append the sync-flush marker the sender stripped, and cap the
        # output size.  NOTE: zlib treats max_length == 0 as "no limit".
        result = decompressor.decompress(
            data + b"\x00\x00\xff\xff", self._max_message_size
        )
        # Input left unconsumed means the cap was hit.
        if decompressor.unconsumed_tail:
            raise _DecompressTooLargeError()
        return result
|
|
|
|
|
class WebSocketProtocol13(WebSocketProtocol): |
|
"""Implementation of the WebSocket protocol from RFC 6455. |
|
|
|
This class supports versions 7 and 8 of the protocol in addition to the |
|
final version 13. |
|
""" |
|
|
|
|
|
FIN = 0x80 |
|
RSV1 = 0x40 |
|
RSV2 = 0x20 |
|
RSV3 = 0x10 |
|
RSV_MASK = RSV1 | RSV2 | RSV3 |
|
OPCODE_MASK = 0x0F |
|
|
|
stream = None |
|
|
|
def __init__(
    self,
    handler: "_WebSocketDelegate",
    mask_outgoing: bool,
    params: _WebSocketParams,
) -> None:
    """Create an RFC 6455 protocol object.

    :param handler: the delegate that receives callbacks.
    :param mask_outgoing: whether outgoing frames are masked.  The server
        side passes False (see WebSocketHandler.get_websocket_protocol).
    :param params: ping, size-limit and compression settings.
    """
    WebSocketProtocol.__init__(self, handler)
    self.mask_outgoing = mask_outgoing
    self.params = params
    # State of the frame currently being parsed.
    self._final_frame = False
    self._frame_opcode = None
    self._masked_frame = None
    self._frame_mask = None  # type: Optional[bytes]
    self._frame_length = None
    self._fragmented_message_buffer = None  # type: Optional[bytearray]
    self._fragmented_message_opcode = None
    self._waiting = None  # type: object
    # Compression state, created in _create_compressors() when the
    # permessage-deflate extension is negotiated.
    self._compression_options = params.compression_options
    self._decompressor = None  # type: Optional[_PerMessageDeflateDecompressor]
    self._compressor = None  # type: Optional[_PerMessageDeflateCompressor]
    self._frame_compressed = None  # type: Optional[bool]
    # Metrics: uncompressed message bytes in/out (write_message counts the
    # payload before compression).
    self._message_bytes_in = 0
    self._message_bytes_out = 0
    # Metrics: raw frame bytes in/out, headers included.
    self._wire_bytes_in = 0
    self._wire_bytes_out = 0
    # Keep-alive bookkeeping.
    self.ping_callback = None  # type: Optional[PeriodicCallback]
    self.last_ping = 0.0
    self.last_pong = 0.0
    self.close_code = None  # type: Optional[int]
    self.close_reason = None  # type: Optional[str]
    # The subprotocol is only known once the handshake completes.  Start with
    # None so that reading ``selected_subprotocol`` earlier (e.g. from a
    # handler's select_subprotocol override) does not raise AttributeError.
    self._selected_subprotocol = None  # type: Optional[str]
|
|
|
|
|
@property
def selected_subprotocol(self) -> Optional[str]:
    """The subprotocol agreed during the handshake (None if none was chosen)."""
    return self._selected_subprotocol

@selected_subprotocol.setter
def selected_subprotocol(self, subprotocol: Optional[str]) -> None:
    """Record the subprotocol agreed during the handshake."""
    self._selected_subprotocol = subprotocol
|
|
|
async def accept_connection(self, handler: WebSocketHandler) -> None:
    """Validate the handshake headers, then run the accepted connection.

    Missing or empty required headers produce a 400 response.  If the
    connection coroutine is cancelled, or raises ValueError (logged as a
    malformed request), the connection is aborted.
    """
    try:
        self._handle_websocket_headers(handler)
    except ValueError:
        handler.set_status(400)
        message = "Missing/Invalid WebSocket headers"
        handler.finish(message)
        gen_log.debug(message)
        return

    try:
        await self._accept_connection(handler)
    except asyncio.CancelledError:
        self._abort()
    except ValueError:
        gen_log.debug("Malformed WebSocket request received", exc_info=True)
        self._abort()
|
|
|
def _handle_websocket_headers(self, handler: WebSocketHandler) -> None:
    """Check that the required handshake headers are present and non-empty.

    :raises ValueError: if Host, Sec-Websocket-Key or Sec-Websocket-Version
        is missing or empty.
    """
    request_headers = handler.request.headers
    for name in ("Host", "Sec-Websocket-Key", "Sec-Websocket-Version"):
        if not request_headers.get(name):
            raise ValueError("Missing/Invalid WebSocket headers")
|
|
|
@staticmethod
def compute_accept_value(key: Union[str, bytes]) -> str:
    """Computes the value for the Sec-WebSocket-Accept header,
    given the value for Sec-WebSocket-Key.

    The result is base64(SHA-1(key + the fixed RFC 6455 GUID)).
    """
    websocket_guid = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
    digest = hashlib.sha1(utf8(key) + websocket_guid).digest()
    return native_str(base64.b64encode(digest))
|
|
|
def _challenge_response(self, handler: WebSocketHandler) -> str:
    """Return the Sec-WebSocket-Accept value for this handler's request."""
    client_key = cast(str, handler.request.headers.get("Sec-Websocket-Key"))
    return WebSocketProtocol13.compute_accept_value(client_key)
|
|
|
async def _accept_connection(self, handler: WebSocketHandler) -> None:
    """Complete the server side of the opening handshake and run the connection.

    Negotiates the subprotocol and the permessage-deflate extension, sends
    the 101 response, takes over the handler's stream, calls
    ``handler.open`` and then runs the frame-receiving loop.
    """
    # The offered subprotocols arrive as a comma-separated header value.
    subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
    if subprotocol_header:
        subprotocols = [s.strip() for s in subprotocol_header.split(",")]
    else:
        subprotocols = []
    self.selected_subprotocol = handler.select_subprotocol(subprotocols)
    if self.selected_subprotocol:
        # The handler may only pick one of the subprotocols the client offered.
        assert self.selected_subprotocol in subprotocols
        handler.set_header("Sec-WebSocket-Protocol", self.selected_subprotocol)

    extensions = self._parse_extensions_header(handler.request.headers)
    for ext in extensions:
        # Compression is only enabled when the handler's
        # get_compression_options() returned a dict.
        if ext[0] == "permessage-deflate" and self._compression_options is not None:
            # NOTE(review): the client's offered parameters are accepted
            # as-is; they are not narrowed against compression_options.
            self._create_compressors("server", ext[1], self._compression_options)
            if (
                "client_max_window_bits" in ext[1]
                and ext[1]["client_max_window_bits"] is None
            ):
                # Don't echo back a client_max_window_bits parameter that
                # was offered without a value.
                del ext[1]["client_max_window_bits"]
            handler.set_header(
                "Sec-WebSocket-Extensions",
                httputil._encode_header("permessage-deflate", ext[1]),
            )
            # Only the first acceptable permessage-deflate offer is used.
            break

    # Send the 101 Switching Protocols response.
    handler.clear_header("Content-Type")
    handler.set_status(101)
    handler.set_header("Upgrade", "websocket")
    handler.set_header("Connection", "Upgrade")
    handler.set_header("Sec-WebSocket-Accept", self._challenge_response(handler))
    handler.finish()

    # From here on the raw stream belongs to this protocol object.
    self.stream = handler._detach_stream()

    self.start_pinging()
    try:
        open_result = handler.open(*handler.open_args, **handler.open_kwargs)
        if open_result is not None:
            await open_result
    except Exception:
        # An error in open() aborts the connection before any frame is read.
        handler.log_exception(*sys.exc_info())
        self._abort()
        return

    await self._receive_frame_loop()
|
|
|
def _parse_extensions_header(
    self, headers: httputil.HTTPHeaders
) -> List[Tuple[str, Dict[str, str]]]:
    """Split Sec-WebSocket-Extensions into (name, params) pairs.

    Each comma-separated entry is parsed with ``httputil._parse_header``.
    A missing or empty header yields an empty list.
    """
    raw_value = headers.get("Sec-WebSocket-Extensions", "")
    if not raw_value:
        return []
    return [httputil._parse_header(entry.strip()) for entry in raw_value.split(",")]
|
|
|
def _process_server_headers(
    self, key: Union[str, bytes], headers: httputil.HTTPHeaders
) -> None:
    """Process the headers sent by the server to this client connection.

    'key' is the websocket handshake challenge/response key.

    Checks the Upgrade, Connection and Sec-Websocket-Accept headers
    (AssertionError on mismatch), sets up compression for an accepted
    permessage-deflate extension, and records the server's subprotocol.

    :raises ValueError: if the server selected an unsupported extension.
    """
    # NOTE(review): these checks use ``assert`` and therefore disappear under
    # ``python -O``; consider raising an explicit exception instead.
    assert headers["Upgrade"].lower() == "websocket"
    assert headers["Connection"].lower() == "upgrade"
    accept = self.compute_accept_value(key)
    assert headers["Sec-Websocket-Accept"] == accept

    extensions = self._parse_extensions_header(headers)
    for ext in extensions:
        if ext[0] == "permessage-deflate" and self._compression_options is not None:
            self._create_compressors("client", ext[1])
        else:
            # ``ext`` is a tuple, so wrap it for %-formatting.  Previously it
            # was passed as a second ValueError argument and the placeholder
            # was never filled in.
            raise ValueError("unsupported extension %r" % (ext,))

    self.selected_subprotocol = headers.get("Sec-WebSocket-Protocol", None)
|
|
|
def _get_compressor_options(
    self,
    side: str,
    agreed_parameters: Dict[str, Any],
    compression_options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Converts a websocket agreed_parameters set to keyword arguments
    for our compressor objects.

    ``side`` is "server" or "client".  It selects which
    ``<side>_no_context_takeover`` / ``<side>_max_window_bits`` parameters
    apply.
    """
    window_bits = agreed_parameters.get(side + "_max_window_bits", None)
    return {
        "persistent": (side + "_no_context_takeover") not in agreed_parameters,
        "max_wbits": zlib.MAX_WBITS if window_bits is None else int(window_bits),
        "compression_options": compression_options,
    }
|
|
|
def _create_compressors(
    self,
    side: str,
    agreed_parameters: Dict[str, Any],
    compression_options: Optional[Dict[str, Any]] = None,
) -> None:
    """Build the compressor for ``side`` and the decompressor for its peer.

    :raises ValueError: if ``agreed_parameters`` contains a key other than the
        four permessage-deflate parameters handled here.
    """
    allowed_keys = {
        "server_no_context_takeover",
        "client_no_context_takeover",
        "server_max_window_bits",
        "client_max_window_bits",
    }
    unexpected = [name for name in agreed_parameters if name not in allowed_keys]
    if unexpected:
        raise ValueError("unsupported compression parameter %r" % unexpected[0])

    peer_side = "client" if side == "server" else "server"
    # Outgoing data uses our own side's parameters...
    self._compressor = _PerMessageDeflateCompressor(
        **self._get_compressor_options(side, agreed_parameters, compression_options)
    )
    # ...and incoming data uses the peer's.
    self._decompressor = _PerMessageDeflateDecompressor(
        max_message_size=self.params.max_message_size,
        **self._get_compressor_options(
            peer_side, agreed_parameters, compression_options
        )
    )
|
|
|
def _write_frame(
    self, fin: bool, opcode: int, data: bytes, flags: int = 0
) -> "Future[None]":
    """Serialize one frame (RFC 6455 section 5.2) and write it to the stream.

    :arg fin: set the FIN bit (final fragment of a message).
    :arg opcode: frame opcode. Opcodes with bit 0x8 set are control frames.
        These must be final and carry at most 125 payload bytes; otherwise
        ``ValueError`` is raised.
    :arg data: payload. It is masked with a fresh random 4-byte key when
        ``self.mask_outgoing`` is true.
    :arg flags: extra bits OR-ed into the first header byte (e.g. RSV1).
    :returns: whatever ``self.stream.write`` returns for the encoded frame.
    """
    payload_len = len(data)
    if opcode & 0x8:
        # Control frames are never fragmented and their payload is limited
        # to 125 bytes.
        if not fin:
            raise ValueError("control frames may not be fragmented")
        if payload_len > 125:
            raise ValueError("control frame payloads may not exceed 125 bytes")

    first_byte = (self.FIN if fin else 0) | opcode | flags
    mask_flag = 0x80 if self.mask_outgoing else 0
    # Payload length uses one of three encodings:
    # - 7 bits inline;
    # - marker 126 followed by a 16-bit length;
    # - marker 127 followed by a 64-bit length.
    if payload_len < 126:
        header = struct.pack("!BB", first_byte, payload_len | mask_flag)
    elif payload_len <= 0xFFFF:
        header = struct.pack("!BBH", first_byte, 126 | mask_flag, payload_len)
    else:
        header = struct.pack("!BBQ", first_byte, 127 | mask_flag, payload_len)

    if self.mask_outgoing:
        masking_key = os.urandom(4)
        body = masking_key + _websocket_mask(masking_key, data)
    else:
        body = data
    frame = header + body
    self._wire_bytes_out += len(frame)
    return self.stream.write(frame)
|
|
|
def write_message(
    self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
) -> "Future[None]":
    """Sends the given message to the peer of this Web Socket.

    ``dict`` messages are JSON-encoded first, and the result is UTF-8
    encoded. The message is sent as one final frame: opcode 0x2 when
    ``binary`` is true, otherwise 0x1. When a compressor has been negotiated,
    the payload is compressed and RSV1 is set.

    Returns a future that resolves once the frame has been written. A
    ``StreamClosedError`` surfaces as ``WebSocketClosedError``, whether it is
    raised immediately or later through the future.
    """
    opcode = 0x2 if binary else 0x1
    if isinstance(message, dict):
        message = tornado.escape.json_encode(message)
    payload = tornado.escape.utf8(message)
    assert isinstance(payload, bytes)
    self._message_bytes_out += len(payload)

    flags = 0
    if self._compressor:
        payload = self._compressor.compress(payload)
        flags |= self.RSV1

    # For historical reasons, write methods in Tornado operate in a
    # semi-synchronous mode: awaiting the returned Future is optional, but
    # errors can still be raised. StreamClosedError therefore has to be
    # translated both here and inside the returned future.
    try:
        write_future = self._write_frame(True, opcode, payload, flags=flags)
    except StreamClosedError:
        raise WebSocketClosedError()

    async def _await_write() -> None:
        try:
            await write_future
        except StreamClosedError:
            raise WebSocketClosedError()

    return asyncio.ensure_future(_await_write())
|
|
|
def write_ping(self, data: bytes) -> None:
    """Write one final ping control frame whose payload is ``data``.

    ``data`` must already be ``bytes``. The future produced by the frame
    write is not returned to the caller.
    """
    assert isinstance(data, bytes)
    ping_opcode = 0x9
    self._write_frame(True, ping_opcode, data)
|
|
|
async def _receive_frame_loop(self) -> None:
    """Read frames until the peer has closed, then notify the handler.

    Frames are read one at a time via ``self._receive_frame`` for as long as
    ``self.client_terminated`` is false. A ``StreamClosedError`` ends the loop
    early and aborts the connection. Unless some other exception escapes, the
    handler's ``on_ws_connection_close`` then receives the recorded close code
    and reason.
    """
    try:
        while True:
            if self.client_terminated:
                break
            await self._receive_frame()
    except StreamClosedError:
        self._abort()
    self.handler.on_ws_connection_close(self.close_code, self.close_reason)
|
|
|
async def _read_bytes(self, n: int) -> bytes:
    """Await ``self.stream.read_bytes(n)`` and count ``n`` as wire input."""
    received = await self.stream.read_bytes(n)
    self._wire_bytes_in = self._wire_bytes_in + n
    return received
|
|
|
async def _receive_frame(self) -> None:
    """Read one WebSocket frame from the stream and act on it.

    Fragments of a data message are accumulated in
    ``self._fragmented_message_buffer`` until the final one arrives.
    Complete messages and control frames are passed to
    ``self._handle_message``, and any future it returns is awaited.

    Violations call ``self._abort()`` and return early. They are: unknown
    reserved bits, an oversized or fragmented control frame, or an out-of-order
    continuation. A message longer than ``params.max_message_size`` is also
    closed with code 1009.

    The "compressed" flag (RSV1, RFC 7692) is remembered from the first frame
    of a data message. Control frames interleaved between that message's
    fragments neither overwrite it nor see it.
    """
    # Read the frame header.
    data = await self._read_bytes(2)
    header, mask_payloadlen = struct.unpack("BB", data)
    is_final_frame = header & self.FIN
    reserved_bits = header & self.RSV_MASK
    opcode = header & self.OPCODE_MASK
    opcode_is_control = opcode & 0x8
    if self._decompressor is not None and opcode != 0 and not opcode_is_control:
        # The compression flag is present in the first frame's header, but we
        # can't decompress until we have all the frames of the message.
        # Control frames are excluded: they may be interleaved with the
        # fragments of a data message (RFC 6455 section 5.4) and previously
        # reset this flag, so the reassembled message was never decompressed.
        # RSV1 on a control or continuation frame stays in reserved_bits and
        # fails the connection below (RFC 7692 section 6).
        self._frame_compressed = bool(reserved_bits & self.RSV1)
        reserved_bits &= ~self.RSV1
    if reserved_bits:
        # The peer is using as-yet-undefined (or un-negotiated) extensions.
        self._abort()
        return
    is_masked = bool(mask_payloadlen & 0x80)
    payloadlen = mask_payloadlen & 0x7F

    # Parse and validate the length.
    if opcode_is_control and payloadlen >= 126:
        # Control frames must have a payload shorter than 126 bytes.
        self._abort()
        return
    if payloadlen < 126:
        self._frame_length = payloadlen
    elif payloadlen == 126:
        data = await self._read_bytes(2)
        payloadlen = struct.unpack("!H", data)[0]
    elif payloadlen == 127:
        data = await self._read_bytes(8)
        payloadlen = struct.unpack("!Q", data)[0]
    new_len = payloadlen
    if self._fragmented_message_buffer is not None:
        new_len += len(self._fragmented_message_buffer)
    if new_len > self.params.max_message_size:
        self.close(1009, "message too big")
        self._abort()
        return

    # Read the payload, unmasking if necessary.
    if is_masked:
        self._frame_mask = await self._read_bytes(4)
    data = await self._read_bytes(payloadlen)
    if is_masked:
        assert self._frame_mask is not None
        data = _websocket_mask(self._frame_mask, data)

    # Decide what to do with this frame.
    if opcode_is_control:
        # Control frames may be interleaved with a series of fragmented data
        # frames, so they must not interact with self._fragmented_*.
        if not is_final_frame:
            # Control frames must not be fragmented.
            self._abort()
            return
    elif opcode == 0:  # continuation frame
        if self._fragmented_message_buffer is None:
            # Nothing to continue.
            self._abort()
            return
        self._fragmented_message_buffer.extend(data)
        if is_final_frame:
            opcode = self._fragmented_message_opcode
            data = bytes(self._fragmented_message_buffer)
            self._fragmented_message_buffer = None
    else:  # start of a new data message
        if self._fragmented_message_buffer is not None:
            # Can't start a new message until the old one is finished.
            self._abort()
            return
        if not is_final_frame:
            self._fragmented_message_opcode = opcode
            self._fragmented_message_buffer = bytearray(data)

    if not is_final_frame:
        return
    if opcode_is_control:
        # _handle_message consults self._frame_compressed, but control frames
        # are never compressed. Hide the flag of any data message still being
        # reassembled, and restore it for that message's final fragment.
        message_compressed = self._frame_compressed
        self._frame_compressed = False
        try:
            handled_future = self._handle_message(opcode, data)
        finally:
            self._frame_compressed = message_compressed
    else:
        handled_future = self._handle_message(opcode, data)
    if handled_future is not None:
        await handled_future
|
|
|
def _handle_message(self, opcode: int, data: bytes) -> "Optional[Future[None]]":
    """Act on one complete message or control frame.

    Runs on_message (or on_ping/on_pong) and returns the callback's Future
    if it is a coroutine, otherwise None. Nothing happens once the peer's
    close frame has been processed (``client_terminated``).

    ``data`` is decompressed first when ``self._frame_compressed`` is set.
    The connection is aborted in three cases:

    * text (0x1) payloads that are not valid UTF-8;
    * close-frame (0x8) reasons that are not valid UTF-8 (RFC 6455
      section 8.1);
    * unknown opcodes.

    A ping (0x9) is answered with a pong (0xA) echoing ``data``. A close
    frame records ``close_code``/``close_reason`` and echoes the code back.
    """
    if self.client_terminated:
        return None

    if self._frame_compressed:
        assert self._decompressor is not None
        try:
            data = self._decompressor.decompress(data)
        except _DecompressTooLargeError:
            self.close(1009, "message too big after decompression")
            self._abort()
            return None

    if opcode == 0x1:
        # UTF-8 data
        self._message_bytes_in += len(data)
        try:
            decoded = data.decode("utf-8")
        except UnicodeDecodeError:
            self._abort()
            return None
        return self._run_callback(self.handler.on_message, decoded)
    elif opcode == 0x2:
        # Binary data
        self._message_bytes_in += len(data)
        return self._run_callback(self.handler.on_message, data)
    elif opcode == 0x8:
        # Close
        self.client_terminated = True
        if len(data) >= 2:
            self.close_code = struct.unpack(">H", data[:2])[0]
        if len(data) > 2:
            try:
                self.close_reason = data[2:].decode("utf-8")
            except UnicodeDecodeError:
                # An invalid reason used to propagate UnicodeDecodeError to
                # the caller. Fail the connection exactly like an invalid
                # text message (opcode 0x1) instead.
                self._abort()
                return None
        # Echo the received close code, if any (RFC 6455 section 5.5.1).
        self.close(self.close_code)
    elif opcode == 0x9:
        # Ping: answer with a pong carrying the same payload.
        try:
            self._write_frame(True, 0xA, data)
        except StreamClosedError:
            self._abort()
        self._run_callback(self.handler.on_ping, data)
    elif opcode == 0xA:
        # Pong
        self.last_pong = IOLoop.current().time()
        return self._run_callback(self.handler.on_pong, data)
    else:
        self._abort()
    return None
|
|
|
def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
    """Closes the WebSocket connection.

    Sending the close frame:

    * Happens only while ``server_terminated`` is false and the stream is
      still open.
    * The frame carries ``code`` as a 2-byte big-endian integer, followed by
      the UTF-8 ``reason``.
    * ``code`` defaults to 1000 when only ``reason`` is given.
    * ``server_terminated`` is then set.

    After that, the stream is handled as follows:

    * If the peer's close frame has already been processed
      (``client_terminated``), the stream is closed right away and any
      pending grace timer is cancelled.
    * Otherwise ``self._abort`` is scheduled 5 seconds later, unless such a
      timer already exists.

    Any periodic ping callback is stopped and cleared.
    """
    if not self.server_terminated:
        if not self.stream.closed():
            if reason is not None and code is None:
                code = 1000  # "normal closure" status code
            payload = b"" if code is None else struct.pack(">H", code)
            if reason is not None:
                payload += utf8(reason)
            try:
                self._write_frame(True, 0x8, payload)
            except StreamClosedError:
                self._abort()
        self.server_terminated = True

    if self.client_terminated:
        # Both close frames have been exchanged: cancel the grace timer (if
        # any) and shut the stream down now.
        grace_timer = self._waiting
        if grace_timer is not None:
            self.stream.io_loop.remove_timeout(grace_timer)
            self._waiting = None
        self.stream.close()
    elif self._waiting is None:
        # Give the peer a few seconds to complete a clean shutdown,
        # otherwise just close the connection.
        io_loop = self.stream.io_loop
        self._waiting = io_loop.add_timeout(io_loop.time() + 5, self._abort)

    if self.ping_callback:
        self.ping_callback.stop()
        self.ping_callback = None
|
|
|
def is_closing(self) -> bool:
    """Return ``True`` if this connection is closing.

    That is the case once the stream has been shut down (cleanly or not), or
    once either endpoint has started its closing handshake
    (``client_terminated`` / ``server_terminated``).
    """
    stream_gone = self.stream.closed()
    return stream_gone or self.client_terminated or self.server_terminated
|
|
|
@property
def ping_interval(self) -> Optional[float]:
    """Configured ping interval in seconds; 0 (pinging disabled) when unset."""
    configured = self.params.ping_interval
    return 0 if configured is None else configured
|
|
|
@property
def ping_timeout(self) -> Optional[float]:
    """Seconds to wait for a pong before closing the connection.

    Uses ``params.ping_timeout`` when it is configured. Otherwise the timeout
    is three ping intervals, but never less than 30 seconds.
    """
    configured = self.params.ping_timeout
    if configured is None:
        interval = self.ping_interval
        assert interval is not None
        return max(3 * interval, 30)
    return configured
|
|
|
def start_pinging(self) -> None:
    """Start sending periodic pings to keep the connection alive.

    Does nothing unless ``self.ping_interval`` is positive. Otherwise
    ``last_ping`` and ``last_pong`` are both reset to the current IOLoop
    time. A `PeriodicCallback` is then started that runs
    ``self.periodic_ping`` every ``ping_interval`` seconds.
    """
    interval = self.ping_interval
    assert interval is not None
    if not interval > 0:
        return
    self.last_ping = self.last_pong = IOLoop.current().time()
    self.ping_callback = PeriodicCallback(self.periodic_ping, interval * 1000)
    self.ping_callback.start()
|
|
|
def periodic_ping(self) -> None:
    """Send a ping to keep the websocket alive.

    Called periodically if the websocket_ping_interval is set and non-zero.

    * Once the connection is closing, no ping is sent, and the periodic
      callback (if any) is stopped.
    * If pings have been sent recently but no pong has arrived within
      ``ping_timeout``, the connection is closed instead of pinged.
    """
    if self.is_closing():
        # Never write a ping to a closing connection, even when no periodic
        # callback is registered (e.g. a direct call after close()).
        if self.ping_callback is not None:
            self.ping_callback.stop()
        return

    # Check for timeout on pong. Make sure that we really have sent a recent
    # ping in case the machine with both server and client has been
    # suspended since the last ping.
    now = IOLoop.current().time()
    since_last_pong = now - self.last_pong
    since_last_ping = now - self.last_ping
    assert self.ping_interval is not None
    assert self.ping_timeout is not None
    if (
        since_last_ping < 2 * self.ping_interval
        and since_last_pong > self.ping_timeout
    ):
        self.close()
        return

    self.write_ping(b"")
    self.last_ping = now
|
|
|
def set_nodelay(self, x: bool) -> None:
    """Forward ``x`` to the underlying stream's ``set_nodelay``."""
    stream = self.stream
    stream.set_nodelay(x)
|
|
|
|
|
class WebSocketClientConnection(simple_httpclient._HTTPConnection):
    """WebSocket client connection.

    This class should not be instantiated directly; use the
    `websocket_connect` function instead.
    """

    # The active protocol object. It is assigned in headers_received once the
    # server answers with status 101, and reset to None by close().
    protocol = None  # type: WebSocketProtocol

    def __init__(
        self,
        request: httpclient.HTTPRequest,
        on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
        compression_options: Optional[Dict[str, Any]] = None,
        ping_interval: Optional[float] = None,
        ping_timeout: Optional[float] = None,
        max_message_size: int = _default_max_message_size,
        subprotocols: Optional[List[str]] = None,
        resolver: Optional[Resolver] = None,
    ) -> None:
        """Prepare the upgrade request and start the HTTP connection.

        ``request.url`` must use the ``ws`` or ``wss`` scheme. It is rewritten
        to ``http``/``https``; any other scheme raises ``KeyError``. The
        WebSocket handshake headers are added to ``request.headers``, and
        redirects are disabled.
        """
        self.connect_future = Future()  # type: Future[WebSocketClientConnection]
        # Holds at most one undelivered message when no on_message_callback
        # is given; read_message() takes messages from here.
        self.read_queue = Queue(1)  # type: Queue[Union[None, str, bytes]]
        # Random 16-byte nonce sent as Sec-WebSocket-Key.
        self.key = base64.b64encode(os.urandom(16))
        self._on_message_callback = on_message_callback
        self.close_code = None  # type: Optional[int]
        self.close_reason = None  # type: Optional[str]
        self.params = _WebSocketParams(
            ping_interval=ping_interval,
            ping_timeout=ping_timeout,
            max_message_size=max_message_size,
            compression_options=compression_options,
        )

        scheme, sep, rest = request.url.partition(":")
        scheme = {"ws": "http", "wss": "https"}[scheme]
        request.url = scheme + sep + rest
        request.headers.update(
            {
                "Upgrade": "websocket",
                "Connection": "Upgrade",
                "Sec-WebSocket-Key": self.key,
                "Sec-WebSocket-Version": "13",
            }
        )
        if subprotocols is not None:
            request.headers["Sec-WebSocket-Protocol"] = ",".join(subprotocols)
        if compression_options is not None:
            # Always offer to let the server set our max_wbits. Even though
            # we don't offer it, a client_no_context_takeover parameter from
            # the server is also accepted.
            # TODO: set server parameters for deflate extension
            # if requested in self.compression_options.
            request.headers["Sec-WebSocket-Extensions"] = (
                "permessage-deflate; client_max_window_bits"
            )

        # Websocket connection is currently unable to follow redirects
        request.follow_redirects = False

        self.tcp_client = TCPClient(resolver=resolver)
        # NOTE(review): the positional sizes are presumably _HTTPConnection's
        # buffer/header/body limits (100 MiB, 64 KiB, 100 MiB) -- confirm
        # against simple_httpclient.
        super().__init__(
            None,
            request,
            lambda: None,
            self._on_http_response,
            104857600,
            self.tcp_client,
            65536,
            104857600,
        )

    def __del__(self) -> None:
        # Warn when a connection that completed its handshake is garbage
        # collected without close() being called. close() resets
        # self.protocol to None.
        if self.protocol is not None:
            warnings.warn("Unclosed WebSocketClientConnection", ResourceWarning)

    def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> None:
        """Closes the websocket connection.

        ``code`` and ``reason`` are documented under
        `WebSocketHandler.close`.

        .. versionadded:: 3.2

        .. versionchanged:: 4.0

           Added the ``code`` and ``reason`` arguments.
        """
        if self.protocol is not None:
            self.protocol.close(code, reason)
            self.protocol = None  # type: ignore

    def on_connection_close(self) -> None:
        """Handle loss of the connection.

        Fails a still-pending ``connect_future`` with ``StreamClosedError``,
        delivers the ``None`` end-of-stream marker, and closes the TCP
        client.
        """
        if not self.connect_future.done():
            self.connect_future.set_exception(StreamClosedError())
        self._on_message(None)
        self.tcp_client.close()
        super().on_connection_close()

    def on_ws_connection_close(
        self, close_code: Optional[int] = None, close_reason: Optional[str] = None
    ) -> None:
        """Record the peer's close code/reason, then run `on_connection_close`."""
        self.close_code = close_code
        self.close_reason = close_reason
        self.on_connection_close()

    def _on_http_response(self, response: httpclient.HTTPResponse) -> None:
        """Fail ``connect_future`` if the server did not upgrade the connection.

        The future gets the response's error when there is one, otherwise a
        ``WebSocketError("Non-websocket response")``.
        """
        if not self.connect_future.done():
            if response.error:
                self.connect_future.set_exception(response.error)
            else:
                self.connect_future.set_exception(
                    WebSocketError("Non-websocket response")
                )

    async def headers_received(
        self,
        start_line: Union[httputil.RequestStartLine, httputil.ResponseStartLine],
        headers: httputil.HTTPHeaders,
    ) -> None:
        """Take over the connection when the server answers 101.

        Any other status is handed to the regular HTTP response handling.
        """
        assert isinstance(start_line, httputil.ResponseStartLine)
        if start_line.code != 101:
            await super().headers_received(start_line, headers)
            return

        if self._timeout is not None:
            self.io_loop.remove_timeout(self._timeout)
            self._timeout = None

        self.headers = headers
        self.protocol = self.get_websocket_protocol()
        self.protocol._process_server_headers(self.key, self.headers)
        self.protocol.stream = self.connection.detach()

        IOLoop.current().add_callback(self.protocol._receive_frame_loop)
        self.protocol.start_pinging()

        # Once we've taken over the connection, clear the final callback.
        # This is presumably the _on_http_response callback handed to
        # super().__init__; clearing it keeps simple_httpclient's error
        # handling from interfering with ours -- TODO confirm.
        self.final_callback = None  # type: ignore

        future_set_result_unless_cancelled(self.connect_future, self)

    def write_message(
        self, message: Union[str, bytes, Dict[str, Any]], binary: bool = False
    ) -> "Future[None]":
        """Sends a message to the WebSocket server.

        If the stream is closed, raises `WebSocketClosedError`.
        Returns a `.Future` which can be used for flow control.

        .. versionchanged:: 5.0
           Exception raised on a closed stream changed from `.StreamClosedError`
           to `WebSocketClosedError`.
        """
        if self.protocol is None:
            raise WebSocketClosedError("Client connection has been closed")
        return self.protocol.write_message(message, binary=binary)

    def read_message(
        self,
        callback: Optional[Callable[["Future[Union[None, str, bytes]]"], None]] = None,
    ) -> Awaitable[Union[None, str, bytes]]:
        """Reads a message from the WebSocket server.

        If on_message_callback was specified at WebSocket
        initialization, this function will never return messages.

        Returns a future whose result is the message, or None
        if the connection is closed. If a callback argument
        is given it will be called with the future when it is
        ready.
        """

        awaitable = self.read_queue.get()
        if callback is not None:
            self.io_loop.add_future(asyncio.ensure_future(awaitable), callback)
        return awaitable

    def on_message(self, message: Union[str, bytes]) -> Optional[Awaitable[None]]:
        """Deliver an incoming message; see `_on_message`."""
        return self._on_message(message)

    def _on_message(
        self, message: Union[None, str, bytes]
    ) -> Optional[Awaitable[None]]:
        """Route ``message`` (``None`` means the connection closed).

        The message goes to ``on_message_callback`` when one was given.
        Otherwise it is put on ``read_queue``, and the awaitable from
        ``put`` is returned.
        """
        if self._on_message_callback:
            self._on_message_callback(message)
            return None
        else:
            return self.read_queue.put(message)

    def ping(self, data: bytes = b"") -> None:
        """Send ping frame to the remote end.

        The data argument allows a small amount of data (up to 125
        bytes) to be sent as a part of the ping message. Note that not
        all websocket implementations expose this data to
        applications.

        Consider using the ``ping_interval`` argument to
        `websocket_connect` instead of sending pings manually.

        .. versionadded:: 5.1

        """
        data = utf8(data)
        if self.protocol is None:
            raise WebSocketClosedError()
        self.protocol.write_ping(data)

    def on_pong(self, data: bytes) -> None:
        """Called when a pong arrives; does nothing by default."""
        pass

    def on_ping(self, data: bytes) -> None:
        """Called when a ping arrives; does nothing by default."""
        pass

    def get_websocket_protocol(self) -> WebSocketProtocol:
        """Create the protocol object; clients always mask outgoing frames."""
        return WebSocketProtocol13(self, mask_outgoing=True, params=self.params)

    @property
    def selected_subprotocol(self) -> Optional[str]:
        """The subprotocol selected by the server.

        ``None`` if the server selected none, or if there is no active
        protocol (before the handshake completes or after `close`). Before
        this change, those last two cases raised ``AttributeError``.

        .. versionadded:: 5.1
        """
        if self.protocol is None:
            return None
        return self.protocol.selected_subprotocol

    def log_exception(
        self,
        typ: "Optional[Type[BaseException]]",
        value: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        """Log an uncaught exception (with traceback) to ``app_log``."""
        assert typ is not None
        assert value is not None
        app_log.error("Uncaught exception %s", value, exc_info=(typ, value, tb))
|
|
|
|
|
def websocket_connect(
    url: Union[str, httpclient.HTTPRequest],
    callback: Optional[Callable[["Future[WebSocketClientConnection]"], None]] = None,
    connect_timeout: Optional[float] = None,
    on_message_callback: Optional[Callable[[Union[None, str, bytes]], None]] = None,
    compression_options: Optional[Dict[str, Any]] = None,
    ping_interval: Optional[float] = None,
    ping_timeout: Optional[float] = None,
    max_message_size: int = _default_max_message_size,
    subprotocols: Optional[List[str]] = None,
    resolver: Optional[Resolver] = None,
) -> "Awaitable[WebSocketClientConnection]":
    """Client-side websocket support.

    Takes a url and returns a Future whose result is a
    `WebSocketClientConnection`. ``connect_timeout`` may only be given
    together with a string url; with an `.HTTPRequest`, set the timeout on
    the request itself.

    ``compression_options`` is interpreted in the same way as the
    return value of `.WebSocketHandler.get_compression_options`.

    The connection supports two styles of operation. In the coroutine
    style, the application typically calls
    `~.WebSocketClientConnection.read_message` in a loop::

        conn = await websocket_connect(url)
        while True:
            msg = await conn.read_message()
            if msg is None: break
            # Do something with msg

    In the callback style, pass an ``on_message_callback`` to
    ``websocket_connect``. In both styles, a message of ``None``
    indicates that the connection has been closed.

    ``subprotocols`` may be a list of strings specifying proposed
    subprotocols. The selected protocol may be found on the
    ``selected_subprotocol`` attribute of the connection object
    when the connection is complete.

    .. versionchanged:: 3.2
       Also accepts ``HTTPRequest`` objects in place of urls.

    .. versionchanged:: 4.1
       Added ``compression_options`` and ``on_message_callback``.

    .. versionchanged:: 4.5
       Added the ``ping_interval``, ``ping_timeout``, and ``max_message_size``
       arguments, which have the same meaning as in `WebSocketHandler`.

    .. versionchanged:: 5.0
       The ``io_loop`` argument (deprecated since version 4.1) has been removed.

    .. versionchanged:: 5.1
       Added the ``subprotocols`` argument.

    .. versionchanged:: 6.3
       Added the ``resolver`` argument.
    """
    if not isinstance(url, httpclient.HTTPRequest):
        base_request = httpclient.HTTPRequest(url, connect_timeout=connect_timeout)
    else:
        assert connect_timeout is None
        base_request = url
        # Copy and convert the headers dict/object (see comments in
        # AsyncHTTPClient.fetch); the connection adds handshake headers to it.
        base_request.headers = httputil.HTTPHeaders(base_request.headers)
    # Wrap in a _RequestProxy backed by HTTPRequest._DEFAULTS (presumably so
    # unset attributes fall back to the defaults -- see httpclient).
    request = cast(
        httpclient.HTTPRequest,
        httpclient._RequestProxy(base_request, httpclient.HTTPRequest._DEFAULTS),
    )
    connection = WebSocketClientConnection(
        request,
        on_message_callback=on_message_callback,
        compression_options=compression_options,
        ping_interval=ping_interval,
        ping_timeout=ping_timeout,
        max_message_size=max_message_size,
        subprotocols=subprotocols,
        resolver=resolver,
    )
    if callback is not None:
        IOLoop.current().add_future(connection.connect_future, callback)
    return connection.connect_future
|
|