Spaces:
Building
Building
File size: 8,269 Bytes
5a2de49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 |
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""trio async I/O library query support"""
import socket
import trio
import trio.socket # type: ignore
import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet
# Refuse to import this backend when the "trio" feature is reported as
# unavailable, so callers get an ImportError up front.
if not dns._features.have("trio"):
    raise ImportError("trio not found or too old")
def _maybe_timeout(timeout):
    """Return the context manager used to bound an operation by *timeout*.

    With ``timeout`` set to ``None`` a ``dns._asyncbackend.NullContext`` is
    returned; for any other value the result is
    ``trio.move_on_after(timeout)``.
    """
    if timeout is None:
        return dns._asyncbackend.NullContext()
    return trio.move_on_after(timeout)
# Short alias for dns.inet.low_level_address_tuple; this module calls it as
# _lltuple(address_tuple, af) on the source and destination tuples before
# handing them to bind() and connect().
_lltuple = dns.inet.low_level_address_tuple
# pylint: disable=redefined-outer-name
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket adapter that delegates to a wrapped socket object.

    The wrapped object is stored as ``self.socket`` and its ``family`` is
    passed to the base class constructor.
    """

    def __init__(self, socket):
        super().__init__(socket.family)
        self.socket = socket

    async def sendto(self, what, destination, timeout):
        """Send *what* to *destination* and return the socket's result.

        Raises ``dns.exception.Timeout`` if the timeout context ends the
        block before the send returns.
        """
        with _maybe_timeout(timeout):
            sent = await self.socket.sendto(what, destination)
            return sent
        # Only reached when the ``with`` block was left without returning.
        raise dns.exception.Timeout(
            timeout=timeout
        )  # pragma: no cover  lgtm[py/unreachable-statement]

    async def recvfrom(self, size, timeout):
        """Receive up to *size* bytes and return the socket's result.

        Raises ``dns.exception.Timeout`` if the timeout context ends the
        block before the receive returns.
        """
        with _maybe_timeout(timeout):
            received = await self.socket.recvfrom(size)
            return received
        # Only reached when the ``with`` block was left without returning.
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def close(self):
        """Close the wrapped socket."""
        self.socket.close()

    async def getpeername(self):
        """Return ``getpeername()`` of the wrapped socket."""
        peer = self.socket.getpeername()
        return peer

    async def getsockname(self):
        """Return ``getsockname()`` of the wrapped socket."""
        local = self.socket.getsockname()
        return local

    async def getpeercert(self, timeout):
        """Not available for datagram sockets; always raises."""
        raise NotImplementedError
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket adapter that delegates to a wrapped stream object.

    ``tls`` records whether ``stream`` is a TLS wrapper; when it is, the
    underlying socket is reached through ``stream.transport_stream``.
    """

    def __init__(self, family, stream, tls=False):
        self.family = family
        self.stream = stream
        self.tls = tls

    def _socket_stream(self):
        # The object whose ``.socket`` attribute is the real socket: the
        # transport stream under TLS, otherwise the stream itself.
        return self.stream.transport_stream if self.tls else self.stream

    async def sendall(self, what, timeout):
        """Send all of *what*; raise ``dns.exception.Timeout`` if the
        timeout context ends the block before the send returns."""
        with _maybe_timeout(timeout):
            result = await self.stream.send_all(what)
            return result
        # Only reached when the ``with`` block was left without returning.
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def recv(self, size, timeout):
        """Receive up to *size* bytes; raise ``dns.exception.Timeout`` if
        the timeout context ends the block before the receive returns."""
        with _maybe_timeout(timeout):
            data = await self.stream.receive_some(size)
            return data
        # Only reached when the ``with`` block was left without returning.
        raise dns.exception.Timeout(timeout=timeout)  # lgtm[py/unreachable-statement]

    async def close(self):
        """Close the wrapped stream."""
        await self.stream.aclose()

    async def getpeername(self):
        """Return ``getpeername()`` of the underlying socket."""
        return self._socket_stream().socket.getpeername()

    async def getsockname(self):
        """Return ``getsockname()`` of the underlying socket."""
        return self._socket_stream().socket.getsockname()

    async def getpeercert(self, timeout):
        """Run the TLS handshake, then return the stream's ``getpeercert()``.

        Raises ``NotImplementedError`` when the stream is not TLS.
        """
        if not self.tls:
            raise NotImplementedError
        with _maybe_timeout(timeout):
            await self.stream.do_handshake()
        # NOTE(review): if the timeout context cuts the handshake short,
        # getpeercert() is still called below -- confirm that is intended.
        return self.stream.getpeercert()
if dns._features.have("doh"):
    import httpcore
    import httpcore._backends.trio
    import httpx

    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreTrioStream = httpcore._backends.trio.TrioStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that opens its TCP connections through
        this module's ``Backend.make_socket``."""

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port* and return a ``TrioStream``.

            Candidate addresses are, in order of preference: *host* itself
            when it is already an address, the configured bootstrap address,
            or the addresses returned by the resolver.  Each candidate is
            tried in turn; if none succeeds ``httpcore.ConnectError`` is
            raised.
            """
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                candidates = [host]
            elif self._bootstrap_address is not None:
                candidates = [self._bootstrap_address]
            else:
                lifetime = _remaining(expiration)
                if local_address:
                    # Resolve within the address family of the local address.
                    lookup_family = dns.inet.af_for_address(local_address)
                else:
                    lookup_family = self._family
                answers = await self._resolver.resolve_name(
                    host, family=lookup_family, lifetime=lifetime
                )
                candidates = answers.addresses()

            # A source tuple is only passed when a local address or a
            # non-zero local port was requested.
            wants_source = local_address is not None or self._local_port != 0
            source = (local_address, self._local_port) if wants_source else None

            for candidate in candidates:
                try:
                    af = dns.inet.af_for_address(candidate)
                    destination = (candidate, port)
                    attempt_expiration = _expiration_for_this_attempt(
                        2.0, expiration
                    )
                    attempt_timeout = _remaining(attempt_expiration)
                    sock = await Backend().make_socket(
                        af,
                        socket.SOCK_STREAM,
                        0,
                        source,
                        destination,
                        attempt_timeout,
                    )
                    return _CoreTrioStream(sock.stream)
                except Exception:
                    # Best effort: any failure moves on to the next candidate.
                    continue
            raise httpcore.ConnectError

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            """Unix sockets are not supported by this backend."""
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            """Sleep for *seconds* using ``trio.sleep``."""
            await trio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """``httpx.AsyncHTTPTransport`` whose connection pool is given a
        ``_NetworkBackend`` as its network backend."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )
            self._pool._network_backend = network_backend

else:
    # Without the "doh" feature, use the placeholder transport instead.
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore
class Backend(dns._asyncbackend.Backend):
    """The trio implementation of the dnspython async backend interface."""

    def name(self):
        """Return the backend name, ``"trio"``."""
        return "trio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create, optionally bind, and (for streams) connect a socket.

        Returns a ``DatagramSocket`` for ``SOCK_DGRAM`` and a
        ``StreamSocket`` for ``SOCK_STREAM``; the stream is wrapped in
        ``trio.SSLStream`` when *ssl_context* is given.

        Raises ``dns.exception.Timeout`` if the connect does not finish
        within *timeout*, and ``NotImplementedError`` for any other socket
        type.  On every failure path the newly created socket is closed
        before the exception propagates.
        """
        s = trio.socket.socket(af, socktype, proto)
        try:
            if source:
                await s.bind(_lltuple(source, af))
            if socktype == socket.SOCK_STREAM:
                connected = False
                with _maybe_timeout(timeout):
                    await s.connect(_lltuple(destination, af))
                    connected = True
                if not connected:
                    raise dns.exception.Timeout(
                        timeout=timeout
                    )  # lgtm[py/unreachable-statement]
        except BaseException:  # pragma: no cover
            # BaseException rather than Exception so that cancellation
            # (trio.Cancelled derives from BaseException) does not leak the
            # socket; the exception is always re-raised unchanged.
            s.close()
            raise
        if socktype == socket.SOCK_DGRAM:
            return DatagramSocket(s)
        if socktype == socket.SOCK_STREAM:
            stream = trio.SocketStream(s)
            tls = False
            if ssl_context:
                tls = True
                try:
                    stream = trio.SSLStream(
                        stream, ssl_context, server_hostname=server_hostname
                    )
                except Exception:  # pragma: no cover
                    # ``stream`` is still the plain SocketStream here.
                    await stream.aclose()
                    raise
            return StreamSocket(af, stream, tls)
        # Unsupported type: release the socket instead of leaking it.
        s.close()
        raise NotImplementedError(
            f"unsupported socket type {socktype}"
        )  # pragma: no cover

    async def sleep(self, interval):
        """Sleep for *interval* using ``trio.sleep``."""
        await trio.sleep(interval)

    def get_transport_class(self):
        """Return the HTTP transport class selected at import time."""
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        """Await *awaitable* and return its result.

        Raises ``dns.exception.Timeout`` if the timeout context ends the
        block before the awaitable returns.
        """
        with _maybe_timeout(timeout):
            return await awaitable
        # Only reached when the ``with`` block was left without returning.
        raise dns.exception.Timeout(
            timeout=timeout
        )  # pragma: no cover  lgtm[py/unreachable-statement]
|