Spaces:
Building
Building
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license | |
"""asyncio library query support""" | |
import asyncio | |
import socket | |
import sys | |
import dns._asyncbackend | |
import dns._features | |
import dns.exception | |
import dns.inet | |
# Whether we are running on Windows; Backend.make_socket uses this to bind
# UDP sockets explicitly before they are read from.
_is_win32 = (sys.platform == "win32")
def _get_running_loop():
    """Return the currently running asyncio event loop.

    ``asyncio.get_running_loop`` is used when the asyncio module provides it;
    otherwise this falls back to ``asyncio.get_event_loop``.
    """
    try:
        fetch_loop = asyncio.get_running_loop
    except AttributeError:  # pragma: no cover
        fetch_loop = asyncio.get_event_loop
    return fetch_loop()
class _DatagramProtocol(asyncio.DatagramProtocol):
    """Datagram protocol that hands the next incoming event to a single waiter.

    A reader stores a pending future in ``recvfrom``. The first datagram,
    error, or connection loss that arrives while that future is still
    pending completes it. Events that arrive when no future is pending (or
    the future is already done) are ignored.

    Subclassing ``asyncio.DatagramProtocol`` provides the remaining protocol
    callbacks (such as ``pause_writing``/``resume_writing``) as no-ops, so
    transports that invoke them find them.
    """

    def __init__(self):
        super().__init__()
        self.transport = None  # set by connection_made()
        self.recvfrom = None  # future of the current reader, if any

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        if self.recvfrom and not self.recvfrom.done():
            self.recvfrom.set_result((data, addr))

    def error_received(self, exc):  # pragma: no cover
        if self.recvfrom and not self.recvfrom.done():
            self.recvfrom.set_exception(exc)

    def connection_lost(self, exc):
        if self.recvfrom and not self.recvfrom.done():
            if exc is None:
                # EOF we triggered (exc is None). Raising and catching the
                # EOFError gives it a populated __traceback__ before it is
                # handed to the waiter.
                try:
                    raise EOFError
                except EOFError as e:
                    self.recvfrom.set_exception(e)
            else:
                self.recvfrom.set_exception(exc)

    def close(self):
        """Close the transport, if connection_made() has supplied one."""
        if self.transport is not None:
            self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, optionally bounded by *timeout* seconds.

    When *timeout* is ``None`` the awaitable is awaited with no limit.
    Otherwise it is run under ``asyncio.wait_for``, and an expiry is
    reported as ``dns.exception.Timeout`` instead of ``asyncio.TimeoutError``.
    """
    if timeout is None:
        return await awaitable
    try:
        return await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """UDP socket for the asyncio backend, built on a datagram transport.

    Received datagrams are delivered through the associated protocol's
    ``recvfrom`` future, so only one ``recvfrom()`` may be outstanding at a
    time.
    """

    def __init__(self, family, transport, protocol):
        super().__init__(family)
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        # transport.sendto() does not block, so *timeout* is unused and the
        # whole payload is reported as sent.
        self.transport.sendto(what, destination)
        return len(what)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return ``(data, address)``.

        *size* is ignored: there is no way to pass it on to the protocol.
        Raises ``dns.exception.Timeout`` if *timeout* expires first.
        """
        # Reject a concurrent reader *before* registering our future. If this
        # check ran inside the try below, the finally clause would clear the
        # other reader's future, and its datagram would then be dropped by
        # the protocol.
        assert self.protocol.recvfrom is None
        done = _get_running_loop().create_future()
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            self.protocol.recvfrom = None

    async def close(self):
        self.protocol.close()

    async def getpeername(self):
        return self.transport.get_extra_info("peername")

    async def getsockname(self):
        return self.transport.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        # No certificate is available for this UDP socket.
        raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket backed by an asyncio ``StreamReader``/``StreamWriter`` pair.

    Operations that wait take a *timeout* in seconds; ``None`` means wait
    indefinitely.
    """

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader, self.writer = reader, writer

    async def sendall(self, what, timeout):
        """Write *what* and wait, within *timeout*, for the writer to drain."""
        writer = self.writer
        writer.write(what)
        return await _maybe_wait_for(writer.drain(), timeout)

    async def recv(self, size, timeout):
        """Return up to *size* bytes read from the stream."""
        pending_read = self.reader.read(size)
        return await _maybe_wait_for(pending_read, timeout)

    async def close(self):
        """Close the stream writer (does not wait for the close to finish)."""
        self.writer.close()

    async def getpeername(self):
        """Return the transport's ``peername`` extra info."""
        lookup = self.writer.get_extra_info
        return lookup("peername")

    async def getsockname(self):
        """Return the transport's ``sockname`` extra info."""
        lookup = self.writer.get_extra_info
        return lookup("sockname")

    async def getpeercert(self, timeout):
        """Return the transport's ``peercert`` extra info (*timeout* unused)."""
        lookup = self.writer.get_extra_info
        return lookup("peercert")
if dns._features.have("doh"):
    import anyio
    import httpcore
    import httpcore._backends.anyio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that resolves host names with dnspython.

        Connections are opened with anyio. A literal IP host is used as-is.
        Otherwise ``bootstrap_address`` is used if set, else the host is
        resolved with ``resolver``.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family
            if local_port != 0:
                raise NotImplementedError(
                    "the asyncio transport for HTTPX cannot set the local port"
                )

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port*, trying each candidate address in turn.

            Each attempt is bounded by ``_expiration_for_this_attempt(2.0, ...)``
            within the overall *timeout*. *socket_options* is ignored.

            Raises:
                httpcore.ConnectError: if no address could be connected. The
                    last per-address failure, if any, is chained as its cause.
            """
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    # Resolve in the family of the requested local address.
                    family = dns.inet.af_for_address(local_address)
                answers = await self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            last_error = None
            for address in addresses:
                try:
                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                    timeout = _remaining(attempt_expiration)
                    with anyio.fail_after(timeout):
                        stream = await anyio.connect_tcp(
                            remote_host=address,
                            remote_port=port,
                            local_host=local_address,
                        )
                    return _CoreAnyIOStream(stream)
                except Exception as e:  # pylint: disable=broad-except
                    # Best effort: move on to the next address, but keep the
                    # failure so it is not silently lost.
                    last_error = e
            raise httpcore.ConnectError from last_error

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            await anyio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """httpx async transport whose connection pool uses ``_NetworkBackend``.

        If *resolver* is None, a default ``dns.asyncresolver.Resolver()`` is
        created. Remaining arguments are forwarded to
        ``httpx.AsyncHTTPTransport``.
        """

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            # Replace the pool's network backend so connections go through
            # dnspython resolution.
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

else:
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
class Backend(dns._asyncbackend.Backend):
    """dnspython async backend implemented on the standard asyncio event loop."""

    def name(self):
        return "asyncio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create a ``DatagramSocket`` (SOCK_DGRAM) or ``StreamSocket`` (SOCK_STREAM).

        For stream sockets *destination* is required, and the connection
        (optionally TLS via *ssl_context*) is bounded by *timeout*. For
        datagram sockets *timeout* is not used.

        Raises:
            ValueError: a stream socket was requested without a destination.
            NotImplementedError: *socktype* is neither SOCK_DGRAM nor SOCK_STREAM.
        """
        loop = _get_running_loop()

        if socktype == socket.SOCK_STREAM:
            if destination is None:
                # This shouldn't happen, but we check to make code analysis
                # software happier.
                raise ValueError("destination required for stream sockets")
            host, port = destination[0], destination[1]
            opening = asyncio.open_connection(
                host,
                port,
                ssl=ssl_context,
                family=af,
                proto=proto,
                local_addr=source,
                server_hostname=server_hostname,
            )
            reader, writer = await _maybe_wait_for(opening, timeout)
            return StreamSocket(af, reader, writer)

        if socktype != socket.SOCK_DGRAM:
            raise NotImplementedError(
                f"unsupported socket type {socktype}"
            )  # pragma: no cover

        if source is None and _is_win32:
            # Win32 wants explicit binding before recvfrom(). This is the
            # proper fix for [#637].
            source = (dns.inet.any_for_af(af), 0)
        transport, protocol = await loop.create_datagram_endpoint(
            _DatagramProtocol,
            local_addr=source,
            family=af,
            proto=proto,
            remote_addr=destination,
        )
        return DatagramSocket(af, transport, protocol)

    async def sleep(self, interval):
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        # asyncio datagram sockets can send to any destination unconnected.
        return False

    def get_transport_class(self):
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        return await _maybe_wait_for(awaitable, timeout)