mbuali's picture
Upload folder using huggingface_hub
d1ceb73 verified
""" Language Server stdio-mode readers
Parts of this code are derived from:
> https://github.com/palantir/python-jsonrpc-server/blob/0.2.0/pyls_jsonrpc/streams.py#L83 # noqa
> https://github.com/palantir/python-jsonrpc-server/blob/45ed1931e4b2e5100cc61b3992c16d6f68af2e80/pyls_jsonrpc/streams.py # noqa
> > MIT License https://github.com/palantir/python-jsonrpc-server/blob/0.2.0/LICENSE
> > Copyright 2018 Palantir Technologies, Inc.
"""
# pylint: disable=broad-except
import asyncio
import io
import os
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Text
from tornado.concurrent import run_on_executor
from tornado.gen import convert_yielded
from tornado.httputil import HTTPHeaders
from tornado.ioloop import IOLoop
from tornado.queues import Queue
from traitlets import Float, Instance, default
from traitlets.config import LoggingConfigurable
from .non_blocking import make_non_blocking
class LspStdIoBase(LoggingConfigurable):
"""Non-blocking, queued base for communicating with stdio Language Servers"""
executor = None
stream = Instance( # type:ignore[assignment]
io.RawIOBase, help="the stream to read/write"
) # type: io.RawIOBase
queue = Instance(Queue, help="queue to get/put")
def __repr__(self): # pragma: no cover
return "<{}(parent={})>".format(self.__class__.__name__, self.parent)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.log.debug("%s initialized", self)
self.executor = ThreadPoolExecutor(max_workers=1)
def close(self):
self.stream.close()
self.log.debug("%s closed", self)
class LspStdIoReader(LspStdIoBase):
"""Language Server stdio Reader
Because non-blocking (but still synchronous) IO is used, rudimentary
exponential backoff is used.
"""
max_wait = Float(help="maximum time to wait on idle stream").tag(config=True)
min_wait = Float(0.05, help="minimum time to wait on idle stream").tag(config=True)
next_wait = Float(0.05, help="next time to wait on idle stream").tag(config=True)
@default("max_wait")
def _default_max_wait(self):
return 0.1 if os.name == "nt" else self.min_wait * 2
async def sleep(self):
"""Simple exponential backoff for sleeping"""
if self.stream.closed: # pragma: no cover
return
self.next_wait = min(self.next_wait * 2, self.max_wait)
try:
await asyncio.sleep(self.next_wait)
except Exception: # pragma: no cover
pass
def wake(self):
"""Reset the wait time"""
self.wait = self.min_wait
async def read(self) -> None:
"""Read from a Language Server until it is closed"""
make_non_blocking(self.stream)
while not self.stream.closed:
message = None
try:
message = await self.read_one()
if not message:
await self.sleep()
continue
else:
self.wake()
IOLoop.current().add_callback(self.queue.put_nowait, message)
except Exception as e: # pragma: no cover
self.log.exception(
"%s couldn't enqueue message: %s (%s)", self, message, e
)
await self.sleep()
async def _read_content(
self, length: int, max_parts=1000, max_empties=200
) -> Optional[bytes]:
"""Read the full length of the message unless exceeding max_parts or
max_empties empty reads occur.
See https://github.com/jupyter-lsp/jupyterlab-lsp/issues/450
Crucial docs or read():
"If the argument is positive, and the underlying raw
stream is not interactive, multiple raw reads may be issued
to satisfy the byte count (unless EOF is reached first)"
Args:
- length: the content length
- max_parts: prevent absurdly long messages (1000 parts is several MBs):
1 part is usually sufficient but not enough for some long
messages 2 or 3 parts are often needed.
"""
raw = None
raw_parts: List[bytes] = []
received_size = 0
while received_size < length and len(raw_parts) < max_parts and max_empties > 0:
part = None
try:
part = self.stream.read(length - received_size)
except OSError: # pragma: no cover
pass
if part is None:
max_empties -= 1
await self.sleep()
continue
received_size += len(part)
raw_parts.append(part)
if raw_parts:
raw = b"".join(raw_parts)
if len(raw) != length: # pragma: no cover
self.log.warning(
f"Readout and content-length mismatch: {len(raw)} vs {length};"
f"remaining empties: {max_empties}; remaining parts: {max_parts}"
)
return raw
async def read_one(self) -> Text:
"""Read a single message"""
message = ""
headers = HTTPHeaders()
line = await convert_yielded(self._readline())
if line:
while line and line.strip():
headers.parse_line(line)
line = await convert_yielded(self._readline())
content_length = int(headers.get("content-length", "0"))
if content_length:
raw = await self._read_content(length=content_length)
if raw is not None:
message = raw.decode("utf-8").strip()
else: # pragma: no cover
self.log.warning(
"%s failed to read message of length %s",
self,
content_length,
)
return message
@run_on_executor
def _readline(self) -> Text:
"""Read a line (or immediately return None)"""
try:
return self.stream.readline().decode("utf-8").strip()
except OSError: # pragma: no cover
return ""
class LspStdIoWriter(LspStdIoBase):
"""Language Server stdio Writer"""
async def write(self) -> None:
"""Write to a Language Server until it closes"""
while not self.stream.closed:
message = await self.queue.get()
try:
body = message.encode("utf-8")
response = "Content-Length: {}\r\n\r\n{}".format(len(body), message)
await convert_yielded(self._write_one(response.encode("utf-8")))
except Exception: # pragma: no cover
self.log.exception("%s couldn't write message: %s", self, response)
finally:
self.queue.task_done()
@run_on_executor
def _write_one(self, message) -> None:
self.stream.write(message)
self.stream.flush()