|
""" Language Server stdio-mode readers and writers
|
|
|
Parts of this code are derived from: |
|
|
|
> https://github.com/palantir/python-jsonrpc-server/blob/0.2.0/pyls_jsonrpc/streams.py#L83 # noqa |
|
> https://github.com/palantir/python-jsonrpc-server/blob/45ed1931e4b2e5100cc61b3992c16d6f68af2e80/pyls_jsonrpc/streams.py # noqa |
|
> > MIT License https://github.com/palantir/python-jsonrpc-server/blob/0.2.0/LICENSE |
|
> > Copyright 2018 Palantir Technologies, Inc. |
|
""" |
|
|
|
|
|
import asyncio |
|
import io |
|
import os |
|
from concurrent.futures import ThreadPoolExecutor |
|
from typing import List, Optional, Text |
|
|
|
from tornado.concurrent import run_on_executor |
|
from tornado.gen import convert_yielded |
|
from tornado.httputil import HTTPHeaders |
|
from tornado.ioloop import IOLoop |
|
from tornado.queues import Queue |
|
from traitlets import Float, Instance, default |
|
from traitlets.config import LoggingConfigurable |
|
|
|
from .non_blocking import make_non_blocking |
|
|
|
|
|
class LspStdIoBase(LoggingConfigurable):
    """Non-blocking, queued base for communicating with stdio Language Servers

    Subclasses shuttle messages between ``stream`` and ``queue``. Blocking
    stream calls are meant to run on ``executor`` (see tornado's
    ``run_on_executor``), which is a one-worker thread pool made per instance.
    """

    # Looked up by tornado's ``run_on_executor`` on subclass methods; each
    # instance replaces it with its own pool in ``__init__``.
    executor = None

    stream = Instance(io.RawIOBase, help="the stream to read/write")
    queue = Instance(Queue, help="queue to get/put")

    def __repr__(self):
        return f"<{self.__class__.__name__}(parent={self.parent})>"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # With a single worker, work submitted to the pool runs one at a time.
        self.executor = ThreadPoolExecutor(max_workers=1)
        self.log.debug("%s initialized", self)

    def close(self):
        """Close the underlying stream and note it in the debug log."""
        stream = self.stream
        stream.close()
        self.log.debug("%s closed", self)
|
|
|
|
|
class LspStdIoReader(LspStdIoBase):
    """Language Server stdio Reader

    Because non-blocking (but still synchronous) IO is used, rudimentary
    exponential backoff is used.
    """

    max_wait = Float(help="maximum time to wait on idle stream").tag(config=True)
    min_wait = Float(0.05, help="minimum time to wait on idle stream").tag(config=True)
    next_wait = Float(0.05, help="next time to wait on idle stream").tag(config=True)

    @default("max_wait")
    def _default_max_wait(self):
        return 0.1 if os.name == "nt" else self.min_wait * 2

    async def sleep(self):
        """Simple exponential backoff for sleeping

        Doubles ``next_wait`` (capped at ``max_wait``) and sleeps that long.
        Does nothing once the stream is closed.
        """
        if self.stream.closed:
            return
        self.next_wait = min(self.next_wait * 2, self.max_wait)
        try:
            await asyncio.sleep(self.next_wait)
        except Exception:
            # best-effort pause; a failed sleep should not stop the reader
            pass

    def wake(self):
        """Reset the backoff so the next idle sleep starts from ``min_wait``."""
        # Must reset ``next_wait``, the attribute ``sleep`` actually reads;
        # setting any other attribute would leave the backoff stuck at max.
        self.next_wait = self.min_wait

    async def read(self) -> None:
        """Read from a Language Server until it is closed

        Every non-empty message is scheduled onto ``self.queue`` via the
        current IOLoop. Idle polls and errors back off via ``sleep``.
        """
        make_non_blocking(self.stream)

        while not self.stream.closed:
            message = None
            try:
                message = await self.read_one()

                if not message:
                    await self.sleep()
                    continue
                else:
                    self.wake()

                IOLoop.current().add_callback(self.queue.put_nowait, message)
            except Exception as e:
                self.log.exception(
                    "%s couldn't enqueue message: %s (%s)", self, message, e
                )
                await self.sleep()

    async def _read_content(
        self, length: int, max_parts=1000, max_empties=200
    ) -> Optional[bytes]:
        """Read the full length of the message unless exceeding max_parts or
        max_empties empty reads occur.

        See https://github.com/jupyter-lsp/jupyterlab-lsp/issues/450

        Crucial docs of read():
            "If the argument is positive, and the underlying raw
            stream is not interactive, multiple raw reads may be issued
            to satisfy the byte count (unless EOF is reached first)"

        Args:
            - length: the content length
            - max_parts: prevent absurdly long messages (1000 parts is several MBs):
              1 part is usually sufficient but not enough for some long
              messages 2 or 3 parts are often needed.
            - max_empties: how many "no data yet" reads to tolerate (each
              followed by a backoff sleep) before giving up.

        Returns:
            The joined bytes read (possibly shorter than ``length``), or
            ``None`` if nothing was read at all.
        """
        raw = None
        raw_parts: List[bytes] = []
        received_size = 0

        while received_size < length and len(raw_parts) < max_parts and max_empties > 0:
            part = None
            try:
                part = self.stream.read(length - received_size)
            except OSError:
                # e.g. BlockingIOError: no data available at the moment
                pass

            if part is None:
                # a non-blocking stream has nothing for us yet: back off
                max_empties -= 1
                await self.sleep()
                continue

            if not part:
                # b"" signals EOF; looping further would only spin without
                # ever receiving more data
                break

            received_size += len(part)
            raw_parts.append(part)

        if raw_parts:
            raw = b"".join(raw_parts)
            if len(raw) != length:
                self.log.warning(
                    "Readout and content-length mismatch: %s vs %s; "
                    "remaining empties: %s; remaining parts: %s",
                    len(raw),
                    length,
                    max_empties,
                    max_parts - len(raw_parts),
                )

        return raw

    async def read_one(self) -> Text:
        """Read a single message

        Parses header lines until the first blank line, then reads
        ``Content-Length`` bytes of body and decodes them as UTF-8.
        Returns ``""`` when no message could be read.
        """
        message = ""
        headers = HTTPHeaders()

        line = await convert_yielded(self._readline())

        if line:
            # NOTE(review): ``_readline`` also returns "" when no data is
            # available yet, which ends header parsing early — confirm that
            # a header block is always available in one burst.
            while line and line.strip():
                headers.parse_line(line)
                line = await convert_yielded(self._readline())

            content_length = int(headers.get("content-length", "0"))

            if content_length:
                raw = await self._read_content(length=content_length)
                if raw is not None:
                    message = raw.decode("utf-8").strip()
                else:
                    self.log.warning(
                        "%s failed to read message of length %s",
                        self,
                        content_length,
                    )

        return message

    @run_on_executor
    def _readline(self) -> Text:
        """Read a line on the executor thread (or immediately return "").

        The line is decoded as UTF-8 and stripped. Any ``OSError`` (including
        ``BlockingIOError`` from a non-blocking stream) yields ``""``.
        """
        try:
            return self.stream.readline().decode("utf-8").strip()
        except OSError:
            return ""
|
|
|
|
|
class LspStdIoWriter(LspStdIoBase):
    """Language Server stdio Writer"""

    async def write(self) -> None:
        """Write to a Language Server until it closes

        Each queued message is framed with a ``Content-Length`` header giving
        the byte length of its UTF-8 encoding, then written on the executor.
        Failures are logged and the loop keeps serving the queue.
        """
        while not self.stream.closed:
            message = await self.queue.get()
            try:
                body = message.encode("utf-8")
                # build the frame as bytes directly: same output as formatting
                # a str and re-encoding it, without encoding the body twice
                header = b"Content-Length: %d\r\n\r\n" % len(body)
                await convert_yielded(self._write_one(header + body))
            except Exception:
                # Log the queued message itself: anything derived from it may
                # not exist yet if encoding failed, and referencing it here
                # would raise inside the handler and kill this loop.
                self.log.exception("%s couldn't write message: %s", self, message)
            finally:
                self.queue.task_done()

    @run_on_executor
    def _write_one(self, message) -> None:
        """Write already-encoded bytes to the stream and flush them."""
        self.stream.write(message)
        self.stream.flush()
|
|