File size: 8,971 Bytes
5a2de49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""asyncio library query support"""

import asyncio
import socket
import sys

import dns._asyncbackend
import dns._features
import dns.exception
import dns.inet

# True when running on Windows; Backend.make_socket() uses this to decide
# whether an unbound datagram socket must be given an explicit local address.
_is_win32 = sys.platform == "win32"


def _get_running_loop():
    """Return the asyncio event loop that is currently running.

    Prefers ``asyncio.get_running_loop()``; if the asyncio module does not
    provide that function (the lookup raises ``AttributeError``), falls back
    to ``asyncio.get_event_loop()``.
    """
    try:
        loop = asyncio.get_running_loop()
    except AttributeError:  # pragma: no cover
        loop = asyncio.get_event_loop()
    return loop


class _DatagramProtocol:
    """Datagram protocol that hands events to a single pending future.

    ``recvfrom`` holds the future (if any) that is waiting for the next
    datagram.  Every callback is a no-op when no future is installed or when
    the installed future has already completed.
    """

    def __init__(self):
        # Set by connection_made().
        self.transport = None
        # Future waiting for the next datagram, or None.
        self.recvfrom = None

    def connection_made(self, transport):
        """Remember the transport so that close() can shut it down."""
        self.transport = transport

    def datagram_received(self, data, addr):
        """Complete the pending future with ``(data, addr)``."""
        waiter = self.recvfrom
        if not waiter or waiter.done():
            return
        waiter.set_result((data, addr))

    def error_received(self, exc):  # pragma: no cover
        """Fail the pending future with *exc*."""
        waiter = self.recvfrom
        if not waiter or waiter.done():
            return
        waiter.set_exception(exc)

    def connection_lost(self, exc):
        """Fail the pending future with *exc*, or with EOFError if *exc* is None."""
        waiter = self.recvfrom
        if not waiter or waiter.done():
            return
        if exc is not None:
            waiter.set_exception(exc)
            return
        # EOF we triggered.  Is there a better way to do this?
        # Raising and catching gives the exception instance a traceback.
        try:
            raise EOFError
        except EOFError as eof:
            waiter.set_exception(eof)

    def close(self):
        """Close the underlying transport."""
        self.transport.close()


async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, optionally bounded by *timeout*.

    With ``timeout`` set to ``None`` the awaitable is awaited directly.
    Otherwise it is run under ``asyncio.wait_for``, and an
    ``asyncio.TimeoutError`` is converted into ``dns.exception.Timeout``
    constructed with the same timeout value.
    """
    if timeout is None:
        return await awaitable
    try:
        result = await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
    return result


class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket built on an asyncio transport/protocol pair.

    *transport* is used for sending and for address lookups; *protocol* is
    the object whose ``recvfrom`` attribute holds the future that the next
    received datagram completes.
    """

    def __init__(self, family, transport, protocol):
        super().__init__(family)
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        """Send *what* to *destination* and return the number of bytes passed in.

        *timeout* is accepted for interface compatibility but is not used.
        """
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)
        return len(what)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return the value the protocol delivers.

        *size* is ignored.  *timeout* is forwarded to ``_maybe_wait_for``.
        Only one receive may be outstanding at a time; a second concurrent
        call fails an assertion.
        """
        # ignore size as there's no way I know to tell protocol about it
        #
        # The single-receiver check must run before the try/finally below:
        # if it ran inside the try, a failed check would still execute the
        # cleanup and discard the future owned by the receiver that is
        # already waiting, leaving that receiver without its datagram.
        assert self.protocol.recvfrom is None
        done = _get_running_loop().create_future()
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Always uninstall our own future, including on timeout or error.
            self.protocol.recvfrom = None

    async def close(self):
        """Close the socket by closing the protocol."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's "peername" extra info."""
        return self.transport.get_extra_info("peername")

    async def getsockname(self):
        """Return the transport's "sockname" extra info."""
        return self.transport.get_extra_info("sockname")

    async def getpeercert(self, timeout):
        """Not supported for datagram sockets; always raises NotImplementedError."""
        raise NotImplementedError


class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapping a reader object and a writer object."""

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Write *what* and wait, up to *timeout*, for the writer to drain."""
        self.writer.write(what)
        flushed = self.writer.drain()
        return await _maybe_wait_for(flushed, timeout)

    async def recv(self, size, timeout):
        """Read up to *size* bytes from the reader, waiting at most *timeout*."""
        pending = self.reader.read(size)
        return await _maybe_wait_for(pending, timeout)

    async def close(self):
        """Close the writer."""
        self.writer.close()

    async def getpeername(self):
        """Return the writer's "peername" extra info."""
        peer = self.writer.get_extra_info("peername")
        return peer

    async def getsockname(self):
        """Return the writer's "sockname" extra info."""
        local = self.writer.get_extra_info("sockname")
        return local

    async def getpeercert(self, timeout):
        """Return the writer's "peercert" extra info; *timeout* is unused."""
        cert = self.writer.get_extra_info("peercert")
        return cert


if dns._features.have("doh"):
    import anyio
    import httpcore
    import httpcore._backends.anyio
    import httpx

    # Local aliases for the httpcore base class and stream wrapper used by
    # the _NetworkBackend class defined in this branch.
    _CoreAsyncNetworkBackend = httpcore.AsyncNetworkBackend
    _CoreAnyIOStream = httpcore._backends.anyio.AnyIOStream

    from dns.query import _compute_times, _expiration_for_this_attempt, _remaining

    class _NetworkBackend(_CoreAsyncNetworkBackend):
        """httpcore network backend that chooses connection addresses itself.

        The address to connect to is the host itself when it is already an
        address literal, otherwise the configured bootstrap address if there
        is one, otherwise whatever the supplied resolver returns for the host.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family
            if local_port != 0:
                raise NotImplementedError(
                    "the asyncio transport for HTTPX cannot set the local port"
                )

        async def connect_tcp(
            self, host, port, timeout, local_address, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port* and return the wrapped stream.

            Each candidate address is tried in order; any exception from an
            attempt moves on to the next candidate.  Raises
            ``httpcore.ConnectError`` when no candidate succeeds.
            *socket_options* is accepted but not used.
            """
            _, deadline = _compute_times(timeout)
            if dns.inet.is_address(host):
                candidates = [host]
            elif self._bootstrap_address is not None:
                candidates = [self._bootstrap_address]
            else:
                lifetime = _remaining(deadline)
                # A local address, when given, determines the address family
                # to resolve for; otherwise use the configured family.
                if local_address:
                    family = dns.inet.af_for_address(local_address)
                else:
                    family = self._family
                answers = await self._resolver.resolve_name(
                    host, family=family, lifetime=lifetime
                )
                candidates = answers.addresses()
            for candidate in candidates:
                try:
                    # NOTE(review): presumably limits a single attempt to
                    # 2.0 seconds within the overall deadline -- confirm
                    # against dns.query._expiration_for_this_attempt.
                    attempt_deadline = _expiration_for_this_attempt(2.0, deadline)
                    attempt_timeout = _remaining(attempt_deadline)
                    with anyio.fail_after(attempt_timeout):
                        stream = await anyio.connect_tcp(
                            remote_host=candidate,
                            remote_port=port,
                            local_host=local_address,
                        )
                    return _CoreAnyIOStream(stream)
                except Exception:
                    # Best effort: a failed attempt just means trying the
                    # next candidate address.
                    continue
            raise httpcore.ConnectError

        async def connect_unix_socket(
            self, path, timeout, socket_options=None
        ):  # pylint: disable=signature-differs
            """Unix sockets are not supported; always raises NotImplementedError."""
            raise NotImplementedError

        async def sleep(self, seconds):  # pylint: disable=signature-differs
            """Sleep for *seconds* using anyio."""
            await anyio.sleep(seconds)

    class _HTTPTransport(httpx.AsyncHTTPTransport):
        """httpx async transport whose connection pool uses _NetworkBackend."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            """Build the transport and install the custom network backend.

            When *resolver* is None a default ``dns.asyncresolver.Resolver``
            is created.  Remaining positional and keyword arguments are
            passed through to ``httpx.AsyncHTTPTransport``.
            """
            if resolver is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.asyncresolver

                resolver = dns.asyncresolver.Resolver()
            super().__init__(*args, **kwargs)
            backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )
            # Replace the pool's network backend with ours.
            self._pool._network_backend = backend

else:
    # The "doh" feature is not available, so bind the same name to the
    # NullTransport class from dns._asyncbackend instead.
    _HTTPTransport = dns._asyncbackend.NullTransport  # type: ignore


class Backend(dns._asyncbackend.Backend):
    """Async backend implemented with asyncio."""

    def name(self):
        """Return the backend name, "asyncio"."""
        return "asyncio"

    async def make_socket(
        self,
        af,
        socktype,
        proto=0,
        source=None,
        destination=None,
        timeout=None,
        ssl_context=None,
        server_hostname=None,
    ):
        """Create a DatagramSocket or StreamSocket for *socktype*.

        For ``SOCK_DGRAM`` a datagram endpoint is created on the running
        loop; *timeout*, *ssl_context* and *server_hostname* are not used.
        For ``SOCK_STREAM`` a connection to *destination* is opened, bounded
        by *timeout*; ``ValueError`` is raised if *destination* is None.
        Any other socket type raises ``NotImplementedError``.
        """
        loop = _get_running_loop()
        if socktype == socket.SOCK_DGRAM:
            bind_addr = source
            if bind_addr is None and _is_win32:
                # Win32 wants explicit binding before recvfrom().  This is the
                # proper fix for [#637].
                bind_addr = (dns.inet.any_for_af(af), 0)
            endpoint = await loop.create_datagram_endpoint(
                _DatagramProtocol,
                bind_addr,
                family=af,
                proto=proto,
                remote_addr=destination,
            )
            transport, protocol = endpoint
            return DatagramSocket(af, transport, protocol)
        if socktype == socket.SOCK_STREAM:
            if destination is None:
                # This shouldn't happen, but we check to make code analysis software
                # happier.
                raise ValueError("destination required for stream sockets")
            host, port = destination[0], destination[1]
            connecting = asyncio.open_connection(
                host,
                port,
                ssl=ssl_context,
                family=af,
                proto=proto,
                local_addr=source,
                server_hostname=server_hostname,
            )
            reader, writer = await _maybe_wait_for(connecting, timeout)
            return StreamSocket(af, reader, writer)
        raise NotImplementedError(
            f"unsupported socket type {socktype}"
        )  # pragma: no cover

    async def sleep(self, interval):
        """Sleep for *interval* seconds using asyncio."""
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        """Return False."""
        return False

    def get_transport_class(self):
        """Return the transport class bound to ``_HTTPTransport``."""
        return _HTTPTransport

    async def wait_for(self, awaitable, timeout):
        """Await *awaitable* with an optional *timeout* via _maybe_wait_for."""
        outcome = await _maybe_wait_for(awaitable, timeout)
        return outcome