|
from __future__ import annotations |
|
|
|
import array |
|
import asyncio |
|
import concurrent.futures |
|
import math |
|
import socket |
|
import sys |
|
import threading |
|
import weakref |
|
from asyncio import ( |
|
AbstractEventLoop, |
|
CancelledError, |
|
all_tasks, |
|
create_task, |
|
current_task, |
|
get_running_loop, |
|
sleep, |
|
) |
|
from asyncio.base_events import _run_until_complete_cb |
|
from collections import OrderedDict, deque |
|
from collections.abc import AsyncIterator, Generator, Iterable |
|
from concurrent.futures import Future |
|
from contextlib import suppress |
|
from contextvars import Context, copy_context |
|
from dataclasses import dataclass |
|
from functools import partial, wraps |
|
from inspect import ( |
|
CORO_RUNNING, |
|
CORO_SUSPENDED, |
|
getcoroutinestate, |
|
iscoroutine, |
|
) |
|
from io import IOBase |
|
from os import PathLike |
|
from queue import Queue |
|
from signal import Signals |
|
from socket import AddressFamily, SocketKind |
|
from threading import Thread |
|
from types import TracebackType |
|
from typing import ( |
|
IO, |
|
Any, |
|
AsyncGenerator, |
|
Awaitable, |
|
Callable, |
|
Collection, |
|
ContextManager, |
|
Coroutine, |
|
Mapping, |
|
Optional, |
|
Sequence, |
|
Tuple, |
|
TypeVar, |
|
cast, |
|
) |
|
from weakref import WeakKeyDictionary |
|
|
|
import sniffio |
|
|
|
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc |
|
from .._core._eventloop import claim_worker_thread, threadlocals |
|
from .._core._exceptions import ( |
|
BrokenResourceError, |
|
BusyResourceError, |
|
ClosedResourceError, |
|
EndOfStream, |
|
WouldBlock, |
|
) |
|
from .._core._sockets import convert_ipv6_sockaddr |
|
from .._core._streams import create_memory_object_stream |
|
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter |
|
from .._core._synchronization import Event as BaseEvent |
|
from .._core._synchronization import ResourceGuard |
|
from .._core._tasks import CancelScope as BaseCancelScope |
|
from ..abc import ( |
|
AsyncBackend, |
|
IPSockAddrType, |
|
SocketListener, |
|
UDPPacketType, |
|
UNIXDatagramPacketType, |
|
) |
|
from ..lowlevel import RunVar |
|
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
|
|
|
if sys.version_info >= (3, 10): |
|
from typing import ParamSpec |
|
else: |
|
from typing_extensions import ParamSpec |
|
|
|
if sys.version_info >= (3, 11): |
|
from asyncio import Runner |
|
from typing import TypeVarTuple, Unpack |
|
else: |
|
import contextvars |
|
import enum |
|
import signal |
|
from asyncio import coroutines, events, exceptions, tasks |
|
|
|
from exceptiongroup import BaseExceptionGroup |
|
from typing_extensions import TypeVarTuple, Unpack |
|
|
|
    class _State(enum.Enum):
        """Life cycle states of the backported :class:`Runner`."""

        CREATED = "created"
        INITIALIZED = "initialized"
        CLOSED = "closed"
|
|
|
    class Runner:
        """Backport of :class:`asyncio.Runner` (added in Python 3.11).

        Lazily creates an event loop on first use and tears it down in
        :meth:`close`. Mirrors the CPython implementation, including the
        SIGINT handling in :meth:`run`.
        """

        def __init__(
            self,
            *,
            debug: bool | None = None,
            loop_factory: Callable[[], AbstractEventLoop] | None = None,
        ):
            self._state = _State.CREATED
            self._debug = debug
            self._loop_factory = loop_factory
            self._loop: AbstractEventLoop | None = None
            # Context captured at init time; run() executes in (a copy of) it
            self._context = None
            # Number of SIGINTs received while run() is active
            self._interrupt_count = 0
            # True if we installed our loop as the thread's current event loop
            self._set_event_loop = False

        def __enter__(self) -> Runner:
            self._lazy_init()
            return self

        def __exit__(
            self,
            exc_type: type[BaseException],
            exc_val: BaseException,
            exc_tb: TracebackType,
        ) -> None:
            self.close()

        def close(self) -> None:
            """Shutdown and close event loop."""
            if self._state is not _State.INITIALIZED:
                return
            try:
                loop = self._loop
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                # loop.shutdown_default_executor() only exists on Python 3.9+;
                # fall back to the local reimplementation otherwise
                if hasattr(loop, "shutdown_default_executor"):
                    loop.run_until_complete(loop.shutdown_default_executor())
                else:
                    loop.run_until_complete(_shutdown_default_executor(loop))
            finally:
                if self._set_event_loop:
                    events.set_event_loop(None)

                loop.close()
                self._loop = None
                self._state = _State.CLOSED

        def get_loop(self) -> AbstractEventLoop:
            """Return embedded event loop."""
            self._lazy_init()
            return self._loop

        def run(self, coro: Coroutine[Any, Any, T_Retval], *, context=None) -> T_Retval:
            """Run a coroutine inside the embedded event loop.

            :raises ValueError: if *coro* is not a coroutine object
            :raises RuntimeError: if called from inside a running event loop
            """
            if not coroutines.iscoroutine(coro):
                raise ValueError(f"a coroutine was expected, got {coro!r}")

            if events._get_running_loop() is not None:
                # fail fast with a short traceback
                raise RuntimeError(
                    "Runner.run() cannot be called from a running event loop"
                )

            self._lazy_init()

            if context is None:
                context = self._context
            task = context.run(self._loop.create_task, coro)

            # Install a custom SIGINT handler (like asyncio.run() does) so that
            # Ctrl+C cancels the main task instead of interrupting at a random
            # point, but only if we're in the main thread and nothing else has
            # replaced the default handler
            if (
                threading.current_thread() is threading.main_thread()
                and signal.getsignal(signal.SIGINT) is signal.default_int_handler
            ):
                sigint_handler = partial(self._on_sigint, main_task=task)
                try:
                    signal.signal(signal.SIGINT, sigint_handler)
                except ValueError:
                    # signal.signal() can raise here (e.g. embedded
                    # interpreters without signal support); run without the
                    # custom handler in that case
                    sigint_handler = None
            else:
                sigint_handler = None

            self._interrupt_count = 0
            try:
                return self._loop.run_until_complete(task)
            except exceptions.CancelledError:
                # If the cancellation came from our SIGINT handler and the task
                # fully uncancelled, surface it as KeyboardInterrupt
                if self._interrupt_count > 0:
                    uncancel = getattr(task, "uncancel", None)
                    if uncancel is not None and uncancel() == 0:
                        raise KeyboardInterrupt()
                raise
            finally:
                # Restore the default handler only if ours is still installed
                if (
                    sigint_handler is not None
                    and signal.getsignal(signal.SIGINT) is sigint_handler
                ):
                    signal.signal(signal.SIGINT, signal.default_int_handler)

        def _lazy_init(self) -> None:
            # Create the loop and capture the context on first use
            if self._state is _State.CLOSED:
                raise RuntimeError("Runner is closed")
            if self._state is _State.INITIALIZED:
                return
            if self._loop_factory is None:
                self._loop = events.new_event_loop()
                if not self._set_event_loop:
                    # Register the loop as the thread's current event loop only
                    # once so close() knows to unset it
                    events.set_event_loop(self._loop)
                    self._set_event_loop = True
            else:
                self._loop = self._loop_factory()
            if self._debug is not None:
                self._loop.set_debug(self._debug)
            self._context = contextvars.copy_context()
            self._state = _State.INITIALIZED

        def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
            # First Ctrl+C: cancel the main task; subsequent ones raise
            # KeyboardInterrupt immediately
            self._interrupt_count += 1
            if self._interrupt_count == 1 and not main_task.done():
                main_task.cancel()
                # wake up the loop if it is blocked in select() with a long timeout
                self._loop.call_soon_threadsafe(lambda: None)
                return
            raise KeyboardInterrupt()
|
|
|
    def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
        """Cancel every task on *loop* and wait for all of them to finish.

        Exceptions (other than cancellation) raised by tasks during shutdown are
        reported via the loop's exception handler.
        """
        to_cancel = tasks.all_tasks(loop)
        if not to_cancel:
            return

        for task in to_cancel:
            task.cancel()

        # Let all tasks process their cancellation; swallow their exceptions here
        loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))

        # Report any task that finished with a real exception instead of
        # being cancelled
        for task in to_cancel:
            if task.cancelled():
                continue
            if task.exception() is not None:
                loop.call_exception_handler(
                    {
                        "message": "unhandled exception during asyncio.run() shutdown",
                        "exception": task.exception(),
                        "task": task,
                    }
                )
|
|
|
    async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
        """Schedule the shutdown of the default executor.

        Backport of ``loop.shutdown_default_executor()`` for Python < 3.9;
        relies on the private ``_default_executor`` / ``_executor_shutdown_called``
        attributes of the base event loop.
        """

        def _do_shutdown(future: asyncio.futures.Future) -> None:
            # Runs in a separate thread so the (blocking) shutdown doesn't
            # stall the event loop; completion is reported via the future
            try:
                loop._default_executor.shutdown(wait=True)
                loop.call_soon_threadsafe(future.set_result, None)
            except Exception as ex:
                loop.call_soon_threadsafe(future.set_exception, ex)

        # Prevent any further jobs from being submitted to the executor
        loop._executor_shutdown_called = True
        if loop._default_executor is None:
            return
        future = loop.create_future()
        thread = threading.Thread(target=_do_shutdown, args=(future,))
        thread.start()
        try:
            await future
        finally:
            thread.join()
|
|
|
|
|
T_Retval = TypeVar("T_Retval")
T_contra = TypeVar("T_contra", contravariant=True)
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

# Cached root task of the current event loop; see find_root_task()
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
|
|
|
|
|
def find_root_task() -> asyncio.Task:
    """Return the root task of the current event loop.

    The result is cached in the ``_root_task`` run variable. Detection relies
    on private Task internals (``_callbacks``): the root task is the one driven
    by ``run_until_complete()`` (or its uvloop equivalent). If no such task is
    found, falls back to the host task of the current task's outermost cancel
    scope, or the current task itself.
    """
    root_task = _root_task.get(None)
    if root_task is not None and not root_task.done():
        return root_task

    # Look for a task whose done callbacks indicate it was started via
    # run_until_complete() (vanilla asyncio or uvloop)
    for task in all_tasks():
        if task._callbacks and not task.done():
            callbacks = [cb for cb, context in task._callbacks]
            for cb in callbacks:
                if (
                    cb is _run_until_complete_cb
                    or getattr(cb, "__module__", None) == "uvloop.loop"
                ):
                    _root_task.set(task)
                    return task

    # No such task found; walk up the cancel scope chain of the current task
    # and use the outermost scope's host task instead
    task = cast(asyncio.Task, current_task())
    state = _task_states.get(task)
    if state:
        cancel_scope = state.cancel_scope
        while cancel_scope and cancel_scope._parent_scope is not None:
            cancel_scope = cancel_scope._parent_scope

        if cancel_scope is not None:
            return cast(asyncio.Task, cancel_scope._host_task)

    return task
|
|
|
|
|
def get_callable_name(func: Callable) -> str:
    """Return a dotted ``module.qualname`` style name for *func*.

    Parts that are missing or empty are simply omitted; an object with neither
    attribute yields an empty string.
    """
    parts = (getattr(func, "__module__", None), getattr(func, "__qualname__", None))
    return ".".join(part for part in parts if part)
|
|
|
|
|
|
|
|
|
|
|
|
|
# Per-event-loop storage backing RunVar; weak keys let entries die with the loop
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
|
|
|
|
|
def _task_started(task: asyncio.Task) -> bool: |
|
"""Return ``True`` if the task has been started and has not finished.""" |
|
try: |
|
return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED) |
|
except AttributeError: |
|
|
|
raise Exception(f"Cannot determine if task {task} has started or not") from None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CancelScope(BaseCancelScope):
    """Asyncio implementation of the AnyIO cancel scope.

    Tracks the tasks and nested scopes it contains and, once cancelled, keeps
    re-delivering ``CancelledError`` to eligible tasks (via ``call_soon``
    retries) until they all comply. Relies on private asyncio Task attributes
    (``_must_cancel``, ``_fut_waiter``).
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Handle for the deadline timer, if a finite deadline is set
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Handle for the next scheduled _deliver_cancellation() retry
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks directly contained in this scope (not in child scopes)
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        # How many times this scope has called task.cancel() (for uncancel())
        self._cancel_calls: int = 0
        # host task's cancelling() count at entry (Python 3.11+ only)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Splice this scope between the host task and its previous scope
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.remove(host_task)

        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before
        # entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Hand the host task back to the parent scope
        self._tasks.remove(self._host_task)
        if self._parent_scope is not None:
            self._parent_scope._child_scopes.remove(self)
            self._parent_scope._tasks.add(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the closest directly cancelled
        # parent scope if this one was shielding the host task from it
        self._restart_cancellation_in_parent()

        # Swallow the CancelledError if it was caused by this scope's own
        # cancellation
        if self._cancel_called and exc_val is not None:
            for exc in iterate_exceptions(exc_val):
                if isinstance(exc, CancelledError):
                    self._cancelled_caught = self._uncancel(exc)
                    if self._cancelled_caught:
                        break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        # Determine whether the given cancellation originated from this scope,
        # undoing this scope's cancel() calls on the host task in the process
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Python 3.11+: balance our cancel() calls against uncancel()
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        # Python 3.9/3.10: match against the cancellation message instead
        self._cancel_calls = 0
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        # Cancel immediately if the deadline has passed, otherwise (re)arm the
        # deadline timer
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks that already have a pending cancellation
            if task._must_cancel:
                continue

            should_retry = True
            # Only cancel tasks that have actually started running (the host
            # task always counts), and never the currently running task
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    origin._cancel_calls += 1
                    if sys.version_info >= (3, 9):
                        task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    else:
                        task.cancel()

        # Deliver cancellation to child scopes that aren't shielded and haven't
        # been cancelled on their own
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another delivery round if any tasks are still pending
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # A shielded scope blocks cancellation from scopes further up
            if scope._shield:
                break

            scope = scope._parent_scope

    def _parent_cancelled(self) -> bool:
        # Return True if any unshielded ancestor scope has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            # Only deliver if the scope has been entered already
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timer against the new deadline
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Dropping the shield re-exposes this scope to parent cancellation
            if not value:
                self._restart_cancellation_in_parent()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TaskState:
    """
    Encapsulates auxiliary task information that cannot be added to the Task instance
    itself because there are no guarantees about its implementation.
    """

    __slots__ = "parent_id", "cancel_scope", "__weakref__"

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        # id() of the task that spawned this one (None for unparented tasks)
        self.parent_id = parent_id
        # The cancel scope the task currently resides in, if any
        self.cancel_scope = cancel_scope
|
|
|
|
|
# AnyIO bookkeeping per task; weak keys so entries disappear with their task
_task_states: WeakKeyDictionary[asyncio.Task, TaskState] = WeakKeyDictionary()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AsyncioTaskStatus(abc.TaskStatus):
    """Task status object passed to tasks spawned via ``TaskGroup.start()``."""

    def __init__(self, future: asyncio.Future, parent_id: int):
        # Future the starting task is awaiting on; resolved by started()
        self._future = future
        self._parent_id = parent_id

    def started(self, value: T_contra | None = None) -> None:
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            # A cancelled future is tolerated; a completed one means started()
            # was already called
            if not self._future.cancelled():
                raise RuntimeError(
                    "called 'started' twice on the same task status"
                ) from None

        # Reparent the calling task to the task group's host task
        task = cast(asyncio.Task, current_task())
        _task_states[task].parent_id = self._parent_id
|
|
|
|
|
def iterate_exceptions( |
|
exception: BaseException, |
|
) -> Generator[BaseException, None, None]: |
|
if isinstance(exception, BaseExceptionGroup): |
|
for exc in exception.exceptions: |
|
yield from iterate_exceptions(exc) |
|
else: |
|
yield exception |
|
|
|
|
|
class TaskGroup(abc.TaskGroup):
    """Asyncio implementation of the AnyIO task group.

    Child tasks share the group's cancel scope; exceptions from children are
    collected and raised together as a :exc:`BaseExceptionGroup` on exit.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        self._exceptions: list[BaseException] = []
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The host block raised: cancel all child tasks and remember the
            # exception (unless it was a cancellation)
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        # Wait until all child tasks have finished, remembering any
        # cancellation delivered to us while we wait
        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        while self._tasks:
            try:
                await asyncio.wait(self._tasks)
            except CancelledError as exc:
                # We got cancelled while waiting: propagate the cancellation to
                # the children and keep waiting
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Re-raise the cancellation received while waiting for the children,
        # unless the host block exited with an exception that wasn't swallowed
        # by the cancel scope
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        def task_done(_task: asyncio.Task) -> None:
            # Done callback: detach the task's bookkeeping and collect its
            # outcome
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Dig out the innermost chained cancellation
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                # A cancelled task status future means the host task itself was
                # cancelled; don't record anything in that case
                if task_status_future is not None and task_status_future.cancelled():
                    return

                if task_status_future is None or task_status_future.done():
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    # Cancel the whole group unless a parent scope is already
                    # doing the cancelling
                    if not self.cancel_scope._parent_cancelled():
                        self.cancel_scope.cancel()
                else:
                    # The task failed before calling task_status.started():
                    # route the exception to the waiting start() call
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Register the task under the group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # Wait until the task signals readiness via task_status.started();
        # if we get cancelled meanwhile, cancel the task and wait (shielded)
        # for it to exit before re-raising
        try:
            return await future
        except CancelledError:
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
|
|
|
|
|
|
|
|
|
|
|
|
|
# (return value, exception) pair used when marshalling results between threads
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
|
|
|
|
|
class WorkerThread(Thread):
    """A reusable thread that runs synchronous jobs received via its queue.

    Results are reported back to the owning event loop via
    ``call_soon_threadsafe()``.
    """

    # Idle time threshold in seconds; presumably used by callers to decide when
    # to stop surplus workers (not referenced within this class)
    MAX_IDLE_TIME = 10

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        # Work items; ``None`` is the sentinel telling the thread to exit.
        # Bounded to 2 entries.
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        # Runs in the event loop thread (scheduled via call_soon_threadsafe)
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                if isinstance(exc, StopIteration):
                    # StopIteration can't propagate out of a Future/coroutine;
                    # wrap it in a RuntimeError instead
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        with claim_worker_thread(AsyncIOBackend, self.loop):
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future, cancel_scope = item
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    # Expose the cancel scope to code running in this thread
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    # Report back unless the event loop has already been closed
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        # Signal the thread to exit and remove it from the worker registries
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            pass
|
|
|
|
|
# Per-event-loop deque of worker threads that are currently idle
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# Per-event-loop set of all live worker threads
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
|
|
|
|
|
class BlockingPortal(abc.BlockingPortal):
    """Asyncio implementation of the AnyIO blocking portal."""

    def __new__(cls) -> BlockingPortal:
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        # The loop calls from other threads will be marshalled to
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        # Hop into the portal's event loop and start the task there; the
        # result is delivered via *future*
        AsyncIOBackend.run_sync_from_thread(
            partial(self._task_group.start_soon, name=name),
            (self._call_func, func, args, kwargs, future),
            self._loop,
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Wraps an :class:`asyncio.StreamReader` as an AnyIO byte receive stream."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        # An empty read means the peer has closed its end of the stream
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        self._stream.feed_eof()
        await AsyncIOBackend.checkpoint()
|
|
|
|
|
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Wraps an :class:`asyncio.StreamWriter` as an AnyIO byte send stream."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        # Buffer the bytes, then wait for the transport to drain
        self._stream.write(item)
        await self._stream.drain()

    async def aclose(self) -> None:
        self._stream.close()
        await AsyncIOBackend.checkpoint()
|
|
|
|
|
@dataclass(eq=False)
class Process(abc.Process):
    """AnyIO wrapper around an :class:`asyncio.subprocess.Process`."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        # Close the pipes first; shielded so cancellation can't interrupt the
        # cleanup midway
        with CancelScope(shield=True):
            if self._stdin:
                await self._stdin.aclose()
            if self._stdout:
                await self._stdout.aclose()
            if self._stderr:
                await self._stderr.aclose()

        try:
            await self.wait()
        except BaseException:
            # Interrupted while waiting (e.g. cancelled): kill the process and
            # reap it (shielded) before re-raising
            self.kill()
            with CancelScope(shield=True):
                await self.wait()

            raise

    async def wait(self) -> int:
        return await self._process.wait()

    def terminate(self) -> None:
        self._process.terminate()

    def kill(self) -> None:
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        return self._stderr
|
|
|
|
|
def _forcibly_shutdown_process_pool_on_exit(
    workers: set[Process], _task: object
) -> None:
    """
    Forcibly shuts down worker processes belonging to this event loop."""
    child_watcher: asyncio.AbstractChildWatcher | None = None
    if sys.version_info < (3, 12):
        # Child watchers were removed in Python 3.12
        try:
            child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
        except NotImplementedError:
            pass

    # Close as much as possible (w/o async/await) to avoid warnings
    for process in workers:
        if process.returncode is None:
            # NOTE(review): this skips processes that are still running and
            # only cleans up ones that have already exited — verify that's
            # intended for a "forcible" shutdown
            continue

        process._stdin._stream._transport.close()
        process._stdout._stream._transport.close()
        process._stderr._stream._transport.close()
        process.kill()
        if child_watcher:
            child_watcher.remove_child_handler(process.pid)
|
|
|
|
|
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
    """
    Shuts down worker processes belonging to this event loop.

    NOTE: this only works when the event loop was started using asyncio.run() or
    anyio.run().

    """
    process: abc.Process
    try:
        # Sleep forever; we only wake up when the event loop is shutting down
        # and cancels this task
        await sleep(math.inf)
    except asyncio.CancelledError:
        # Kill any processes still running, then reap all of them
        for process in workers:
            if process.returncode is None:
                process.kill()

        for process in workers:
            await process.aclose()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class StreamProtocol(asyncio.Protocol):
    """asyncio protocol that buffers incoming data for :class:`SocketStream`."""

    read_queue: deque[bytes]  # received chunks not yet consumed
    read_event: asyncio.Event  # set when data, EOF or connection loss arrives
    write_event: asyncio.Event  # set when the transport can accept more data
    exception: Exception | None = None  # set on abnormal connection loss
    is_at_eof: bool = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()
        # With a zero buffer limit, pause_writing()/resume_writing() are
        # invoked around every write, driving the write_event flow control
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            self.exception = BrokenResourceError()
            self.exception.__cause__ = exc

        # Wake up any tasks waiting to read or write
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.is_at_eof = True
        self.read_event.set()
        # Returning True keeps the transport open for writing
        return True

    def pause_writing(self) -> None:
        # NOTE(review): a fresh (unset) Event is assigned here instead of
        # clearing the existing one (compare DatagramProtocol.pause_writing) —
        # presumably so waiters on the old event aren't stalled; confirm
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
|
|
|
|
|
class DatagramProtocol(asyncio.DatagramProtocol):
    """asyncio protocol that buffers incoming datagrams for UDP sockets."""

    read_queue: deque[tuple[bytes, IPSockAddrType]]  # bounded (see maxlen below)
    read_event: asyncio.Event  # set when datagrams are available
    write_event: asyncio.Event  # set when the transport can accept more data
    exception: Exception | None = None  # last error reported by the transport

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        # Bound the queue so a flood of datagrams pushes out the oldest ones
        # instead of growing without limit
        self.read_queue = deque(maxlen=100)
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up any waiters so they can notice the closed transport
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        addr = convert_ipv6_sockaddr(addr)
        self.read_queue.append((data, addr))
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
|
|
|
|
|
class SocketStream(abc.SocketStream):
    """Transport/protocol based implementation of an AnyIO socket stream."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        with self._receive_guard:
            # Only resume reading and wait if nothing is buffered and the
            # stream hasn't reached EOF or started closing; otherwise just
            # yield control once (checkpoint)
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
                and not self._protocol.is_at_eof
            ):
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()
            else:
                await AsyncIOBackend.checkpoint()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                # Nothing buffered: figure out why and raise accordingly
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Return only max_bytes; push the leftover back for the next
                # receive() call
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is now empty, clear the read event so the next
            # call blocks until more data arrives
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

            return chunk

    async def send(self, item: bytes) -> None:
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                # A write on a closing transport surfaces as RuntimeError;
                # translate it into BrokenResourceError
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            # Wait until the transport has drained (flow control; see
            # StreamProtocol's zero write buffer limit)
            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            # Give the transport one event loop cycle to flush, then abort
            self._transport.close()
            await sleep(0)
            self._transport.abort()
|
|
|
|
|
class _RawSocketMixin:
    """Shared plumbing for socket types driven directly through the event
    loop's add_reader()/add_writer() callbacks rather than a transport.
    """

    # At most one readiness future per direction exists at a time; the class
    # attributes double as the "no waiter" default
    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes when the socket becomes readable."""

        def callback(f: object) -> None:
            # Runs whether the future was resolved or cancelled: drop the
            # stored future (reverting to the class-level None) and detach
            # the reader callback
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes when the socket becomes writable."""

        def callback(f: object) -> None:
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def aclose(self) -> None:
        """Close the socket and wake any tasks waiting on readiness (they will
        then observe ``_closing`` and raise ClosedResourceError)."""
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            if self._receive_future:
                self._receive_future.set_result(None)
            if self._send_future:
                self._send_future.set_result(None)
|
|
|
|
|
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """Byte stream over a non-blocking UNIX domain socket, driven directly
    through event loop reader/writer callbacks (no transport/protocol pair).
    """

    async def send_eof(self) -> None:
        """Shut down the write side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """Read up to ``max_bytes``, waiting for readability as needed."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    payload = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    # Nothing buffered yet; park until the fd is readable
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    if not payload:
                        raise EndOfStream

                    return payload

    async def send(self, item: bytes) -> None:
        """Write all of ``item``, waiting for writability as needed."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            remaining = memoryview(item)
            while remaining:
                try:
                    sent = self._raw_socket.send(remaining)
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    # Drop the part that made it out; loop for the rest
                    remaining = remaining[sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """Receive up to ``msglen`` bytes plus up to ``maxfds`` file
        descriptors passed as SCM_RIGHTS ancillary data."""
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        event_loop = get_running_loop()
        received_fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * received_fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Discard any trailing partial integer from a truncated cmsg
            usable = len(cmsg_data) - (len(cmsg_data) % received_fds.itemsize)
            received_fds.frombytes(cmsg_data[:usable])

        return message, list(received_fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """Send ``message`` together with the given file descriptors via
        SCM_RIGHTS ancillary data."""
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        event_loop = get_running_loop()
        # Accept raw integer fds and file-like objects; ignore anything else
        numeric_fds = [
            fd if isinstance(fd, int) else fd.fileno()
            for fd in fds
            if isinstance(fd, (int, IOBase))
        ]
        fd_payload = array.array("i", numeric_fds)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fd_payload)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
|
|
|
|
|
class TCPSocketListener(abc.SocketListener):
    """Listener accepting TCP connections via the loop's sock_accept()."""

    # Cancel scope wrapping an in-progress accept() so aclose() can cancel it
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """Wait for the next incoming connection and wrap it in a SocketStream.

        Raises ClosedResourceError if the listener is (or becomes) closed.
        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Detach the reader that sock_accept() may have left behind
                    # (not all loops support/need this)
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling any accept() in progress."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Detach the reader left behind by a pending sock_accept()
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Let the cancellation propagate into the pending accept()
            await sleep(0)

        self._raw_socket.close()
|
|
|
|
|
class UNIXSocketListener(abc.SocketListener):
    """Listener accepting connections on a non-blocking UNIX domain socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """Wait for and return the next incoming connection."""
        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    connection, _ = self.__raw_socket.accept()
                    connection.setblocking(False)
                    return UNIXSocketStream(connection)
                except BlockingIOError:
                    # No connection pending; wait for the fd to become readable
                    ready: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, ready.set_result, None)
                    ready.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    await ready
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Close the listener socket; pending accept() calls will then fail
        with ClosedResourceError."""
        self._closed = True
        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        return self.__raw_socket
|
|
|
|
|
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by an asyncio datagram transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        # One task at a time per direction
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Mark the socket closed and shut down the transport."""
        if not self._transport.is_closing():
            self._closed = True
            self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """Return the next (data, address) pair received on the socket."""
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Block for an incoming datagram unless one is already queued or
            # the transport is shutting down
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """Send a (data, address) datagram, honoring transport flow control."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
|
|
|
|
|
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """UDP socket connected to a fixed peer, backed by an asyncio datagram
    transport."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        # One task at a time per direction
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Mark the socket closed and shut down the transport."""
        if not self._transport.is_closing():
            self._closed = True
            self._transport.close()

    async def receive(self) -> bytes:
        """Return the payload of the next datagram from the connected peer."""
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Block for an incoming datagram unless one is already queued or
            # the transport is shutting down
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                packet = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None

                raise BrokenResourceError from None

            # Drop the sender address; only the payload is relevant here
            return packet[0]

    async def send(self, item: bytes) -> None:
        """Send a datagram to the connected peer, honoring flow control."""
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
|
|
|
|
|
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket driven through loop reader/writer
    callbacks."""

    async def receive(self) -> UNIXDatagramPacketType:
        """Return the next (data, address) datagram received on the socket."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    packet = self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    return packet

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """Send a (data, address) datagram on the socket."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    return
|
|
|
|
|
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket driven through loop reader/writer
    callbacks."""

    async def receive(self) -> bytes:
        """Return the next datagram payload from the connected peer."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    payload = self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    return payload

    async def send(self, item: bytes) -> None:
        """Send ``item`` as a single datagram to the connected peer."""
        event_loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                except BlockingIOError:
                    await self._wait_until_writable(event_loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
                else:
                    return
|
|
|
|
|
# Per-event-loop registries mapping sockets/fds to readiness events
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Event(BaseEvent):
    """asyncio implementation of the AnyIO ``Event`` primitive."""

    def __new__(cls) -> Event:
        # Bypass BaseEvent.__new__, which would delegate back to the backend
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Flag the event, waking all current waiters."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the event has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Wait until the event is set; always yields to the event loop."""
        if not self.is_set():
            await self._event.wait()
        else:
            # Already set: still execute a checkpoint for cancellation/fairness
            await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Report how many tasks are currently blocked in wait()."""
        return EventStatistics(len(self._event._waiters))
|
|
|
|
|
class CapacityLimiter(BaseCapacityLimiter):
    """asyncio implementation of the AnyIO capacity limiter primitive."""

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Bypass BaseCapacityLimiter.__new__, which delegates to the backend
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        # Objects currently holding a token
        self._borrowers: set[Any] = set()
        # FIFO of borrower -> wakeup event for tasks waiting on a token
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # If the limit was raised, wake as many waiters as new tokens appeared
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without blocking."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """Acquire a token for ``borrower``; raise WouldBlock if none are free
        or other tasks are already queued."""
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, blocking if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Acquire a token for ``borrower``, waiting in FIFO order if needed."""
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Cancelled while waiting: withdraw from the queue
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # Don't hold the token if the checkpoint blew up
                self.release()
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """Release the token held by ``borrower`` and wake the next waiter."""
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's tokens"
            ) from None

        # Hand the freed token to the next task in line, if any
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
|
|
|
|
|
# Lazily created, per-event-loop default limiter for worker threads
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _SignalReceiver:
    """Async iterator yielding signals delivered while the context is active."""

    def __init__(self, signals: tuple[Signals, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[Signals] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[Signals] = set()

    def _deliver(self, signum: Signals) -> None:
        # Runs as the loop's signal callback: queue the signal and wake the
        # iterator if it is currently waiting
        self._signal_queue.append(signum)
        if not self._future.done():
            self._future.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Register each distinct signal exactly once
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)
        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Nothing queued: arm a fresh future and wait for _deliver()
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AsyncIOTaskInfo(TaskInfo):
    """TaskInfo variant that can also report pending cancellation status."""

    def __init__(self, task: asyncio.Task):
        task_state = _task_states.get(task)
        if task_state is None:
            parent_id = None
        else:
            parent_id = task_state.parent_id

        super().__init__(id(task), parent_id, task.get_name(), task.get_coro())
        # Weak reference so this info object doesn't keep the task alive
        self._task = weakref.ref(task)

    def has_pending_cancellation(self) -> bool:
        """Return True if a cancellation is pending delivery to the task."""
        if not (task := self._task()):
            # If the task isn't around anymore, it has no pending cancellation
            return False

        if sys.version_info >= (3, 11):
            if task.cancelling():
                return True
        elif (
            # Pre-3.11 fallback: peek at the task's internal future waiter
            isinstance(task._fut_waiter, asyncio.Future)
            and task._fut_waiter.cancelled()
        ):
            return True

        # Also report cancellation requested through an enclosing AnyIO scope
        if task_state := _task_states.get(task):
            if cancel_scope := task_state.cancel_scope:
                return cancel_scope.cancel_called or cancel_scope._parent_cancelled()

        return False
|
|
|
|
|
class TestRunner(abc.TestRunner):
    """Test runner that executes coroutine tests and fixtures in a single
    long-lived runner task on a private event loop."""

    # Stream used to ship (coroutine, result future) pairs to the runner task
    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        if use_uvloop and loop_factory is None:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        self.get_loop().set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        # Collect ordinary exceptions for later re-raising; defer everything
        # else to the loop's default handler
        if isinstance(context.get("exception"), Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        # Re-raise any exceptions captured by the loop's exception handler
        if self._exceptions:
            exceptions, self._exceptions = self._exceptions, []
            if len(exceptions) == 1:
                raise exceptions[0]
            elif exceptions:
                raise BaseExceptionGroup(
                    "Multiple exceptions occurred in asynchronous callbacks", exceptions
                )

    async def _run_tests_and_fixtures(
        self,
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """Long-lived task: await each submitted coroutine and resolve the
        paired future with its result or exception."""
        with receive_stream, self._send_stream:
            async for coro, future in receive_stream:
                try:
                    retval = await coro
                except BaseException as exc:
                    if not future.cancelled():
                        future.set_exception(exc)
                else:
                    if not future.cancelled():
                        future.set_result(retval)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        """Submit ``func(*args, **kwargs)`` to the runner task (starting it on
        first use) and await its result."""
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            self._runner_task = self.get_loop().create_task(
                self._run_tests_and_fixtures(receive_stream)
            )

        coro = func(*args, **kwargs)
        future: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((coro, future))
        return await future

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """Drive an async generator fixture: yield its value, then finalize
        it and verify it stopped after a single yield."""
        asyncgen = fixture_func(**kwargs)
        fixturevalue: T_Retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(asyncgen.asend, None)
        )
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(asyncgen.asend, None)
            )
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture to completion and return its value."""
        retval = self.get_loop().run_until_complete(
            self._call_in_runner_task(fixture_func, **kwargs)
        )
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a coroutine test function, collecting async callback exceptions
        alongside the test's own failure."""
        try:
            self.get_loop().run_until_complete(
                self._call_in_runner_task(test_func, **kwargs)
            )
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
|
|
|
|
|
class AsyncIOBackend(AsyncBackend): |
|
    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """Run ``func(*args)`` in a fresh event loop and return its result.

        Recognized options: ``debug``, ``use_uvloop`` and ``loop_factory``.
        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            # Register the root task so AnyIO primitives can find its state
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", None)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())
|
|
|
    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop as this backend's portal token."""
        return get_running_loop()
|
|
|
    @classmethod
    def current_time(cls) -> float:
        """Return the event loop's monotonic clock value."""
        return get_running_loop().time()
|
|
|
    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class asyncio raises on task cancellation."""
        return CancelledError
|
|
|
    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once, allowing cancellation."""
        await sleep(0)
|
|
|
    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """Yield to the event loop only if a surrounding (unshielded) cancel
        scope has already been cancelled."""
        task = current_task()
        if task is None:
            # Not inside a task; nothing to check
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            # Task not managed by AnyIO; no scopes apply
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                # Sleeping gives the pending cancellation a chance to land
                await sleep(0)
            elif cancel_scope.shield:
                # Shield hides any cancelled outer scopes
                break
            else:
                cancel_scope = cancel_scope._parent_scope
|
|
|
    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield to the event loop while shielded from cancellation."""
        with CancelScope(shield=True):
            await sleep(0)
|
|
|
    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Suspend the current task for ``delay`` seconds."""
        await sleep(delay)
|
|
|
    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a new cancel scope with the given deadline and shielding."""
        return CancelScope(deadline=deadline, shield=shield)
|
|
|
@classmethod |
|
def current_effective_deadline(cls) -> float: |
|
try: |
|
cancel_scope = _task_states[ |
|
current_task() |
|
].cancel_scope |
|
except KeyError: |
|
return math.inf |
|
|
|
deadline = math.inf |
|
while cancel_scope: |
|
deadline = min(deadline, cancel_scope.deadline) |
|
if cancel_scope._cancel_called: |
|
deadline = -math.inf |
|
break |
|
elif cancel_scope.shield: |
|
break |
|
else: |
|
cancel_scope = cancel_scope._parent_scope |
|
|
|
return deadline |
|
|
|
    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new asyncio-backed task group."""
        return TaskGroup()
|
|
|
    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new asyncio-backed event."""
        return Event()
|
|
|
    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a new asyncio-backed capacity limiter."""
        return CapacityLimiter(total_tokens)
|
|
|
    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """Run ``func(*args)`` in a pooled worker thread and return its result.

        Threads are pooled per event loop; idle workers are reaped after
        WorkerThread.MAX_IDLE_TIME. Unless ``abandon_on_cancel`` is True, the
        call is shielded so the thread's result is always awaited.
        """
        await cls.checkpoint()

        # If this is the first run in this event loop, set up the thread pool
        # state (idle deque + full worker set)
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    # No idle worker available: spawn one, tied to the root
                    # task's lifetime
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for
                    # MAX_IDLE_TIME seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # Run the sync function in a context where sniffio reports no
                # async library (it's a plain thread)
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future
|
|
|
@classmethod |
|
def check_cancelled(cls) -> None: |
|
scope: CancelScope | None = threadlocals.current_cancel_scope |
|
while scope is not None: |
|
if scope.cancel_called: |
|
raise CancelledError(f"Cancelled by cancel scope {id(scope):x}") |
|
|
|
if scope.shield: |
|
return |
|
|
|
scope = scope._parent_scope |
|
|
|
    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """Run ``func(*args)`` on the event loop identified by ``token`` from
        a worker thread, blocking until the result is available."""

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            task = cast(asyncio.Task, current_task())
            # Attach the new task to the worker thread's host cancel scope
            _task_states[task] = TaskState(None, scope)
            scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                # Map asyncio cancellation to the concurrent.futures flavor so
                # f.result() below raises the right exception type
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(task)

        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        wrapper = task_wrapper(threadlocals.current_cancel_scope)
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, wrapper, loop
        )
        return f.result()
|
|
|
    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """Call ``func(*args)`` on the event loop thread identified by
        ``token``, blocking the calling thread until it completes."""

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                if not isinstance(exc, Exception):
                    # Also surface fatal exceptions in the event loop thread
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()
|
|
|
    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new asyncio-backed blocking portal."""
        return BlockingPortal()
|
|
|
    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """Spawn a subprocess (via shell or exec) and wrap its standard pipes
        in AnyIO byte streams."""
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        # Only wrap the pipes that were actually requested
        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)
|
|
|
    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """Arrange for pooled worker processes to be shut down with the loop."""
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        # Fallback: if the root task finishes abruptly, forcibly kill workers
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )
|
|
|
    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """Open a TCP connection and return it as a SocketStream."""
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        # Reading stays paused until SocketStream.receive() asks for data
        transport.pause_reading()
        return SocketStream(transport, protocol)
|
|
|
    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """Connect to the UNIX domain socket at ``path`` and return a stream."""
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                # Connect in progress: wait until the socket becomes writable,
                # then retry
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            except BaseException:
                # Don't leak the socket on failure or cancellation
                raw_socket.close()
                raise
            else:
                return UNIXSocketStream(raw_socket)
|
|
|
    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap an already-bound TCP socket in a listener."""
        return TCPSocketListener(sock)
|
|
|
    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap an already-bound UNIX domain socket in a listener."""
        return UNIXSocketListener(sock)
|
|
|
    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """Create a UDP socket, connected if ``remote_address`` is given."""
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        # Surface any error the protocol recorded during endpoint creation
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)
|
|
|
@classmethod |
|
async def create_unix_datagram_socket( |
|
cls, raw_socket: socket.socket, remote_path: str | bytes | None |
|
) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket: |
|
await cls.checkpoint() |
|
loop = get_running_loop() |
|
|
|
if remote_path: |
|
while True: |
|
try: |
|
raw_socket.connect(remote_path) |
|
except BlockingIOError: |
|
f: asyncio.Future = asyncio.Future() |
|
loop.add_writer(raw_socket, f.set_result, None) |
|
f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) |
|
await f |
|
except BaseException: |
|
raw_socket.close() |
|
raise |
|
else: |
|
return ConnectedUNIXDatagramSocket(raw_socket) |
|
else: |
|
return UNIXDatagramSocket(raw_socket) |
|
|
|
@classmethod |
|
async def getaddrinfo( |
|
cls, |
|
host: bytes | str | None, |
|
port: str | int | None, |
|
*, |
|
family: int | AddressFamily = 0, |
|
type: int | SocketKind = 0, |
|
proto: int = 0, |
|
flags: int = 0, |
|
) -> list[ |
|
tuple[ |
|
AddressFamily, |
|
SocketKind, |
|
int, |
|
str, |
|
tuple[str, int] | tuple[str, int, int, int], |
|
] |
|
]: |
|
return await get_running_loop().getaddrinfo( |
|
host, port, family=family, type=type, proto=proto, flags=flags |
|
) |
|
|
|
@classmethod |
|
async def getnameinfo( |
|
cls, sockaddr: IPSockAddrType, flags: int = 0 |
|
) -> tuple[str, str]: |
|
return await get_running_loop().getnameinfo(sockaddr, flags) |
|
|
|
    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """Suspend until *sock* is reported readable by the event loop.

        Raises BusyResourceError if another task already has a pending read
        wait on the same socket, and ClosedResourceError if the wait ended
        without this coroutine owning the map entry (the entry was removed
        elsewhere — presumably by a close path — so the wakeup did not mean
        readability; TODO confirm against the close implementation).
        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            # Lazily create the socket -> Event map on first use in this context
            read_events = {}
            _read_events.set(read_events)

        # An existing entry means another task is already waiting on this socket
        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            if read_events.pop(sock, None) is not None:
                # We removed our own entry, so the reader callback is still
                # registered and must be deregistered here
                loop.remove_reader(sock)
                readable = True
            else:
                # Entry already removed by someone else; treat as aborted
                readable = False

        if not readable:
            raise ClosedResourceError
|
|
|
@classmethod |
|
async def wait_socket_writable(cls, sock: socket.socket) -> None: |
|
await cls.checkpoint() |
|
try: |
|
write_events = _write_events.get() |
|
except LookupError: |
|
write_events = {} |
|
_write_events.set(write_events) |
|
|
|
if write_events.get(sock): |
|
raise BusyResourceError("writing to") from None |
|
|
|
loop = get_running_loop() |
|
event = write_events[sock] = asyncio.Event() |
|
loop.add_writer(sock.fileno(), event.set) |
|
try: |
|
await event.wait() |
|
finally: |
|
if write_events.pop(sock, None) is not None: |
|
loop.remove_writer(sock) |
|
writable = True |
|
else: |
|
writable = False |
|
|
|
if not writable: |
|
raise ClosedResourceError |
|
|
|
@classmethod |
|
def current_default_thread_limiter(cls) -> CapacityLimiter: |
|
try: |
|
return _default_thread_limiter.get() |
|
except LookupError: |
|
limiter = CapacityLimiter(40) |
|
_default_thread_limiter.set(limiter) |
|
return limiter |
|
|
|
@classmethod |
|
def open_signal_receiver( |
|
cls, *signals: Signals |
|
) -> ContextManager[AsyncIterator[Signals]]: |
|
return _SignalReceiver(signals) |
|
|
|
@classmethod |
|
def get_current_task(cls) -> TaskInfo: |
|
return AsyncIOTaskInfo(current_task()) |
|
|
|
@classmethod |
|
def get_running_tasks(cls) -> Sequence[TaskInfo]: |
|
return [AsyncIOTaskInfo(task) for task in all_tasks() if not task.done()] |
|
|
|
    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """Poll until every task other than the current one is waiting on an
        unfinished future.

        Relies on the private ``Task._fut_waiter`` attribute (asyncio
        internal): a task whose waiter is None or already done is considered
        runnable, in which case we sleep briefly and rescan from scratch.
        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                # The future this task is currently awaiting, if any
                waiter = task._fut_waiter
                if waiter is None or waiter.done():
                    # At least one task is still runnable; back off and rescan
                    await sleep(0.1)
                    break
            else:
                # Full scan found every other task blocked
                return
|
|
|
@classmethod |
|
def create_test_runner(cls, options: dict[str, Any]) -> TestRunner: |
|
return TestRunner(**options) |
|
|
|
|
|
# Module-level alias; NOTE(review): presumably resolved by name when this
# backend module is loaded - confirm against the backend lookup code
backend_class = AsyncIOBackend
|
|