summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-06-11 18:50:30 -0700
committerBob Halley <halley@dnspython.org>2020-06-11 18:50:30 -0700
commit98b344d625190c3b536682775fb0e0340e715063 (patch)
tree5882b861b4579633467106a1571503a586037be3
parentfa1d4938a8296b4ae88f7bbc5dfa67377995b9e2 (diff)
downloaddnspython-98b344d625190c3b536682775fb0e0340e715063.tar.gz
Support trio, curio, and asyncio with one API!
-rw-r--r--dns/__init__.py2
-rw-r--r--dns/_asyncbackend.py77
-rw-r--r--dns/_asyncio_backend.py118
-rw-r--r--dns/_curio_backend.py92
-rw-r--r--dns/_trio_backend.py92
-rw-r--r--dns/asyncbackend.py43
-rw-r--r--dns/asyncquery.py422
-rw-r--r--dns/asyncresolver.py (renamed from dns/trio/resolver.py)72
-rw-r--r--dns/trio/__init__.py8
-rw-r--r--dns/trio/query.py374
-rw-r--r--dns/trio/query.pyi33
-rw-r--r--dns/trio/resolver.pyi26
-rw-r--r--pyproject.toml3
-rwxr-xr-xsetup.py2
-rw-r--r--tests/test_async.py219
-rw-r--r--tests/test_trio.py189
16 files changed, 1106 insertions, 666 deletions
diff --git a/dns/__init__.py b/dns/__init__.py
index d5cadb8..6412fb5 100644
--- a/dns/__init__.py
+++ b/dns/__init__.py
@@ -18,6 +18,8 @@
"""dnspython DNS toolkit"""
__all__ = [
+ 'asyncquery.py',
+ 'asyncresolver.py',
'dnssec',
'e164',
'edns',
diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py
new file mode 100644
index 0000000..9bfdaba
--- /dev/null
+++ b/dns/_asyncbackend.py
@@ -0,0 +1,77 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+
+import dns.inet
+
+
+# This is a nullcontext for both sync and async
+
+class NullContext:
+ def __init__(self, enter_result=None):
+ self.enter_result = enter_result
+
+ def __enter__(self):
+ return self.enter_result
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ pass
+
+ async def __aenter__(self):
+ return self.enter_result
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ pass
+
+
+# This is handy, but should probably move somewhere else!
+
+def low_level_address_tuple(af, high_level_address_tuple):
+ address, port = high_level_address_tuple
+ if af == dns.inet.AF_INET:
+ return (address, port)
+ elif af == dns.inet.AF_INET6:
+ ai_flags = socket.AI_NUMERICHOST
+ ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags)
+ return tup
+ else:
+ raise NotImplementedError(f'unknown address family {af}')
+
+
+# These are declared here so backends can import them without creating
+# circular dependencies with dns.asyncbackend.
+
+class Socket:
+ async def close(self):
+ pass
+
+ async def __aenter__(self):
+ pass
+
+ async def __aexit__(self, exc_type, exc_value, traceback):
+ await self.close()
+
+
+class DatagramSocket(Socket):
+ async def sendto(self, what, destination, timeout):
+ pass
+
+ async def recvfrom(self, size, timeout):
+ pass
+
+
+class StreamSocket(Socket):
+ async def sendall(self, what, destination, timeout):
+ pass
+
+ async def recv(self, size, timeout):
+ pass
+
+
+class Backend:
+ def name(self):
+ return 'unknown'
+
+ async def make_socket(self, af, socktype, proto=0,
+ source=None, raw_source=None,
+ ssl_context=None, server_hostname=None):
+ raise NotImplementedError
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py
new file mode 100644
index 0000000..42c6e66
--- /dev/null
+++ b/dns/_asyncio_backend.py
@@ -0,0 +1,118 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""asyncio library query support"""
+
+import socket
+import asyncio
+
+import dns._asyncbackend
+import dns.exception
+
+class _DatagramProtocol:
+ def __init__(self):
+ self.transport = None
+ self.recvfrom = None
+
+ def connection_made(self, transport):
+ self.transport = transport
+
+ def datagram_received(self, data, addr):
+ if self.recvfrom:
+ self.recvfrom.set_result((data, addr))
+ self.recvfrom = None
+
+ def error_received(self, exc):
+ if self.recvfrom:
+ self.recvfrom.set_exception(exc)
+
+ def connection_lost(self, exc):
+ if self.recvfrom:
+ self.recvfrom.set_exception(exc)
+
+ def close(self):
+ self.transport.close()
+
+
+async def _maybe_wait_for(awaitable, timeout):
+ if timeout:
+ try:
+ return await asyncio.wait_for(awaitable, timeout)
+ except asyncio.TimeoutError:
+ raise dns.exception.Timeout(timeout=timeout)
+ else:
+ return await awaitable
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
+ def __init__(self, family, transport, protocol):
+ self.family = family
+ self.transport = transport
+ self.protocol = protocol
+
+ async def sendto(self, what, destination, timeout):
+ # no timeout for asyncio sendto
+ self.transport.sendto(what, destination)
+
+ async def recvfrom(self, timeout):
+ done = asyncio.get_running_loop().create_future()
+ assert self.protocol.recvfrom is None
+ self.protocol.recvfrom = done
+ await _maybe_wait_for(done, timeout)
+ return done.result()
+
+ async def close(self):
+ self.protocol.close()
+
+ async def getpeername(self):
+ return self.transport.get_extra_info('peername')
+
+
+class StreamSocket(dns._asyncbackend.DatagramSocket):
+ def __init__(self, af, reader, writer):
+ self.family = af
+ self.reader = reader
+ self.writer = writer
+
+ async def sendall(self, what, timeout):
+ self.writer.write(what),
+ return await _maybe_wait_for(self.writer.drain(), timeout)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def recv(self, count, timeout):
+ return await _maybe_wait_for(self.reader.read(count),
+ timeout)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def close(self):
+ self.writer.close()
+ await self.writer.wait_closed()
+
+ async def getpeername(self):
+ return self.reader.get_extra_info('peername')
+
+
+class Backend(dns._asyncbackend.Backend):
+ def name(self):
+ return 'asyncio'
+
+ async def make_socket(self, af, socktype, proto=0,
+ source=None, destination=None, timeout=None,
+ ssl_context=None, server_hostname=None):
+ loop = asyncio.get_running_loop()
+ if socktype == socket.SOCK_DGRAM:
+ transport, protocol = await loop.create_datagram_endpoint(
+ _DatagramProtocol, source, family=af,
+ proto=proto)
+ return DatagramSocket(af, transport, protocol)
+ elif socktype == socket.SOCK_STREAM:
+ (r, w) = await _maybe_wait_for(
+ asyncio.open_connection(destination[0],
+ destination[1],
+ family=af,
+ proto=proto,
+ local_addr=source),
+ timeout)
+ return StreamSocket(af, r, w)
+ raise NotImplementedError(f'unsupported socket type {socktype}')
+
+ async def sleep(self, interval):
+ await asyncio.sleep(interval)
diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py
new file mode 100644
index 0000000..e37fea3
--- /dev/null
+++ b/dns/_curio_backend.py
@@ -0,0 +1,92 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""curio async I/O library query support"""
+
+import socket
+import curio
+import curio.socket # type: ignore
+
+import dns._asyncbackend
+import dns.exception
+
+
+def _maybe_timeout(timeout):
+ if timeout:
+ return curio.ignore_after(timeout)
+ else:
+ return dns._asyncbackend.NullContext()
+
+
+# for brevity
+_lltuple = dns._asyncbackend.low_level_address_tuple
+
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
+ def __init__(self, socket):
+ self.socket = socket
+ self.family = socket.family
+
+ async def sendto(self, what, destination, timeout):
+ async with _maybe_timeout(timeout):
+ return await self.socket.sendto(what, destination)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def recvfrom(self, timeout):
+ async with _maybe_timeout(timeout):
+ return await self.socket.recvfrom(65535)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def close(self):
+ await self.socket.close()
+
+ async def getpeername(self):
+ return self.socket.getpeername()
+
+
+class StreamSocket(dns._asyncbackend.DatagramSocket):
+ def __init__(self, socket):
+ self.socket = socket
+ self.family = socket.family
+
+ async def sendall(self, what, timeout):
+ async with _maybe_timeout(timeout):
+ return await self.socket.sendall(what)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def recv(self, count, timeout):
+ async with _maybe_timeout(timeout):
+ return await self.socket.recv(count)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def close(self):
+ await self.socket.close()
+
+ async def getpeername(self):
+ return self.socket.getpeername()
+
+
+class Backend(dns._asyncbackend.Backend):
+ def name(self):
+ return 'curio'
+
+ async def make_socket(self, af, socktype, proto=0,
+ source=None, destination=None, timeout=None,
+ ssl_context=None, server_hostname=None):
+ s = curio.socket.socket(af, socktype, proto)
+ try:
+ if source:
+ s.bind(_lltuple(af, source))
+ if socktype == socket.SOCK_STREAM:
+ with _maybe_timeout(timeout):
+ await s.connect(_lltuple(af, destination))
+ except Exception:
+ await s.close()
+ raise
+ if socktype == socket.SOCK_DGRAM:
+ return DatagramSocket(s)
+ elif socktype == socket.SOCK_STREAM:
+ return StreamSocket(s)
+ raise NotImplementedError(f'unsupported socket type {socktype}')
+
+ async def sleep(self, interval):
+ await curio.sleep(interval)
diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py
new file mode 100644
index 0000000..bcaddcc
--- /dev/null
+++ b/dns/_trio_backend.py
@@ -0,0 +1,92 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""trio async I/O library query support"""
+
+import socket
+import trio
+import trio.socket # type: ignore
+
+import dns._asyncbackend
+import dns.exception
+
+
+def _maybe_timeout(timeout):
+ if timeout:
+ return trio.move_on_after(timeout)
+ else:
+ return dns._asyncbackend.NullContext()
+
+
+# for brevity
+_lltuple = dns._asyncbackend.low_level_address_tuple
+
+
+class DatagramSocket(dns._asyncbackend.DatagramSocket):
+ def __init__(self, socket):
+ self.socket = socket
+ self.family = socket.family
+
+ async def sendto(self, what, destination, timeout):
+ with _maybe_timeout(timeout):
+ return await self.socket.sendto(what, destination)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def recvfrom(self, timeout):
+ with _maybe_timeout(timeout):
+ return await self.socket.recvfrom(65535)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def close(self):
+ self.socket.close()
+
+ async def getpeername(self):
+ return self.socket.getpeername()
+
+
+class StreamSocket(dns._asyncbackend.DatagramSocket):
+ def __init__(self, family, stream):
+ self.family = family
+ self.stream = stream
+
+ async def sendall(self, what, timeout):
+ with _maybe_timeout(timeout):
+ return await self.stream.send_all(what)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def recv(self, count, timeout):
+ with _maybe_timeout(timeout):
+ return await self.stream.receive_some(count)
+ raise dns.exception.Timeout(timeout=timeout)
+
+ async def close(self):
+ await self.stream.aclose()
+
+ async def getpeername(self):
+ return self.stream.socket.getpeername()
+
+
+class Backend(dns._asyncbackend.Backend):
+ def name(self):
+ return 'trio'
+
+ async def make_socket(self, af, socktype, proto=0, source=None,
+ destination=None, timeout=None,
+ ssl_context=None, server_hostname=None):
+ s = trio.socket.socket(af, socktype, proto)
+ try:
+ if source:
+ await s.bind(_lltuple(af, source))
+ if socktype == socket.SOCK_STREAM:
+ with _maybe_timeout(timeout):
+ await s.connect(_lltuple(af, destination))
+ except Exception:
+ s.close()
+ raise
+ if socktype == socket.SOCK_DGRAM:
+ return DatagramSocket(s)
+ elif socktype == socket.SOCK_STREAM:
+ return StreamSocket(af, trio.SocketStream(s))
+ raise NotImplementedError(f'unsupported socket type {socktype}')
+
+ async def sleep(self, interval):
+ await trio.sleep(interval)
diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py
new file mode 100644
index 0000000..92a1ae3
--- /dev/null
+++ b/dns/asyncbackend.py
@@ -0,0 +1,43 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+
+from dns._asyncbackend import Socket, DatagramSocket, \
+ StreamSocket, Backend, low_level_address_tuple
+
+
+_default_backend = None
+
+
+def get_default_backend():
+ if _default_backend:
+ return _default_backend
+
+ return set_default_backend(sniff())
+
+
+def sniff():
+ name = 'asyncio'
+ try:
+ import sniffio
+ name = sniffio.current_async_library()
+ except Exception:
+ pass
+ return name
+
+
+def set_default_backend(name):
+ global _default_backend
+
+ if name == 'trio':
+ import dns._trio_backend
+ _default_backend = dns._trio_backend.Backend()
+ elif name == 'curio':
+ import dns._curio_backend
+ _default_backend = dns._curio_backend.Backend()
+ elif name == 'asyncio':
+ import dns._asyncio_backend
+ _default_backend = dns._asyncio_backend.Backend()
+ else:
+ raise NotImplementedException(f'unimplemented async backend {name}')
+
+ return _default_backend
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
new file mode 100644
index 0000000..ed51fdc
--- /dev/null
+++ b/dns/asyncquery.py
@@ -0,0 +1,422 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""Talk to a DNS server."""
+
+import os
+import socket
+import struct
+import time
+import base64
+import ipaddress
+
+import dns.asyncbackend
+import dns.exception
+import dns.inet
+import dns.name
+import dns.message
+import dns.rcode
+import dns.rdataclass
+import dns.rdatatype
+
+from dns.query import _addresses_equal, _destination_and_source, \
+ _compute_times, UnexpectedSource
+
+
+# for brevity
+_lltuple = dns.asyncbackend.low_level_address_tuple
+
+
+def _source_tuple(af, address, port):
+ # Make a high level source tuple, or return None if address and port
+ # are both None
+ if address or port:
+ if address is None:
+ if af == socket.AF_INET:
+ address = '0.0.0.0'
+ elif af == socket.AF_INET6:
+ address = '::'
+ else:
+ raise NotImplementedError(f'unknown address family {af}')
+ return (address, port)
+ else:
+ return None
+
+
+def _timeout(expiration, now=None):
+ if expiration:
+ if not now:
+ now = time.time()
+ return max(expiration - now, 0)
+ else:
+ return None
+
+
+async def send_udp(sock, what, destination, expiration=None):
+ """Send a DNS message to the specified UDP socket.
+
+ *sock*, a ``dns.asyncbackend.DatagramSocket``.
+
+ *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+ *destination*, a destination tuple appropriate for the address family
+ of the socket, specifying where to send the query.
+
+ *expiration*, a ``float`` or ``None``, the absolute time at which
+ a timeout exception should be raised. If ``None``, no timeout will
+ occur.
+
+ Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+ """
+
+ if isinstance(what, dns.message.Message):
+ what = what.to_wire()
+ sent_time = time.time()
+ n = await sock.sendto(what, destination, _timeout(expiration, sent_time))
+ return (n, sent_time)
+
+
+async def receive_udp(sock, destination, expiration=None,
+ ignore_unexpected=False, one_rr_per_rrset=False,
+ keyring=None, request_mac=b'', ignore_trailing=False,
+ raise_on_truncation=False):
+ """Read a DNS message from a UDP socket.
+
+ *sock*, a ``dns.asyncbackend.DatagramSocket``.
+
+ *destination*, a destination tuple appropriate for the address family
+ of the socket, specifying where the associated query was sent.
+
+ *expiration*, a ``float`` or ``None``, the absolute time at which
+ a timeout exception should be raised. If ``None``, no timeout will
+ occur.
+
+ *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+ unexpected sources.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *keyring*, a ``dict``, the keyring to use for TSIG.
+
+ *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
+ Raises if the message is malformed, if network errors occur, of if
+ there is a timeout.
+
+ Returns a ``dns.message.Message`` object.
+ """
+
+ wire = b''
+ while 1:
+ (wire, from_address) = await sock.recvfrom(65535)
+ if _addresses_equal(sock.family, from_address, destination) or \
+ (dns.inet.is_multicast(destination[0]) and
+ from_address[1:] == destination[1:]):
+ break
+ if not ignore_unexpected:
+ raise UnexpectedSource('got a response from '
+ '%s instead of %s' % (from_address,
+ destination))
+ received_time = time.time()
+ r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing,
+ raise_on_truncation=raise_on_truncation)
+ return (r, received_time)
+
+async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
+ ignore_unexpected=False, one_rr_per_rrset=False,
+ ignore_trailing=False, raise_on_truncation=False, sock=None,
+ backend=None):
+ """Return the response obtained after sending a query via UDP.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+ query times out. If ``None``, the default, wait forever.
+
+ *port*, an ``int``, the port send the message to. The default is 53.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+ unexpected sources.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
+ the TC bit is set.
+
+ *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
+ the socket to use for the query. If ``None``, the default, a
+ socket is created. Note that if a socket is provided, the
+ *source* and *source_port* are ignored.
+
+ *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+ the default, then dnspython will use the default backend.
+
+ Returns a ``dns.message.Message``.
+ """
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
+ s = None
+ try:
+ if sock:
+ s = sock
+ else:
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
+ destination = _lltuple(af, (where, port))
+ await send_udp(s, wire, destination, expiration)
+ (r, received_time) = await receive_udp(s, destination, expiration,
+ ignore_unexpected,
+ one_rr_per_rrset,
+ q.keyring, q.mac,
+ ignore_trailing,
+ raise_on_truncation)
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+ finally:
+ if not sock and s:
+ await s.close()
+
+async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
+ source_port=0, ignore_unexpected=False,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ udp_sock=None, tcp_sock=None):
+ """Return the response to the query, trying UDP first and falling back
+ to TCP if UDP results in a truncated response.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+ query times out. If ``None``, the default, wait forever.
+
+ *port*, an ``int``, the port send the message to. The default is 53.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
+ unexpected sources.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``,
+ the socket to use for the UDP query. If ``None``, the default, a
+ socket is created. Note that if a socket is provided the *source*
+ and *source_port* are ignored for the UDP query.
+
+ *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the
+ socket to use for the TCP query. If ``None``, the default, a
+ socket is created. Note that if a socket is provided *where*,
+ *source* and *source_port* are ignored for the TCP query.
+
+ Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
+ if and only if TCP was used.
+ """
+ try:
+ response = await udp(q, where, timeout, port, source, source_port,
+ ignore_unexpected, one_rr_per_rrset,
+ ignore_trailing, True, udp_sock)
+ return (response, False)
+ except dns.message.Truncated:
+ response = await tcp(q, where, timeout, port, source, source_port,
+ one_rr_per_rrset, ignore_trailing, tcp_sock)
+ return (response, True)
+
+
+
+async def send_tcp(sock, what, expiration=None):
+ """Send a DNS message to the specified TCP socket.
+
+ *sock*, a ``socket``.
+
+ *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
+
+ *expiration*, a ``float`` or ``None``, the absolute time at which
+ a timeout exception should be raised. If ``None``, no timeout will
+ occur.
+
+ Returns an ``(int, float)`` tuple of bytes sent and the sent time.
+ """
+
+ if isinstance(what, dns.message.Message):
+ what = what.to_wire()
+ l = len(what)
+ # copying the wire into tcpmsg is inefficient, but lets us
+ # avoid writev() or doing a short write that would get pushed
+ # onto the net
+ tcpmsg = struct.pack("!H", l) + what
+ sent_time = time.time()
+ await sock.sendall(tcpmsg, expiration)
+ return (len(tcpmsg), sent_time)
+
+
+async def read_exactly(sock, count, expiration):
+ """Read the specified number of bytes from stream. Keep trying until we
+ either get the desired amount, or we hit EOF.
+ """
+ s = b''
+ while count > 0:
+ n = await sock.recv(count, _timeout(expiration))
+ if n == b'':
+ raise EOFError
+ count = count - len(n)
+ s = s + n
+ return s
+
+
+async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False,
+ keyring=None, request_mac=b'', ignore_trailing=False):
+ """Read a DNS message from a TCP socket.
+
+ *sock*, a ``socket``.
+
+ *expiration*, a ``float`` or ``None``, the absolute time at which
+ a timeout exception should be raised. If ``None``, no timeout will
+ occur.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *keyring*, a ``dict``, the keyring to use for TSIG.
+
+ *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ Raises if the message is malformed, if network errors occur, of if
+ there is a timeout.
+
+ Returns a ``dns.message.Message`` object.
+ """
+
+ ldata = await read_exactly(sock, 2, expiration)
+ (l,) = struct.unpack("!H", ldata)
+ wire = await read_exactly(sock, l, expiration)
+ received_time = time.time()
+ r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
+ one_rr_per_rrset=one_rr_per_rrset,
+ ignore_trailing=ignore_trailing)
+ return (r, received_time)
+
+
+async def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False, sock=None,
+ backend=None):
+ """Return the response obtained after sending a query via TCP.
+
+ *q*, a ``dns.message.Message``, the query to send
+
+ *where*, a ``str`` containing an IPv4 or IPv6 address, where
+ to send the message.
+
+ *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
+ query times out. If ``None``, the default, wait forever.
+
+ *port*, an ``int``, the port send the message to. The default is 53.
+
+ *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
+ the source address. The default is the wildcard address.
+
+ *source_port*, an ``int``, the port from which to send the message.
+ The default is 0.
+
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
+
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
+ junk at end of the received message.
+
+ *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the
+ socket to use for the query. If ``None``, the default, a socket
+ is created. Note that if a socket is provided
+ *where*, *port*, *source* and *source_port* are ignored.
+
+ *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+ the default, then dnspython will use the default backend.
+
+ Returns a ``dns.message.Message``.
+ """
+
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ wire = q.to_wire()
+ (begin_time, expiration) = _compute_times(timeout)
+ s = None
+ try:
+ if sock:
+ # Verify that the socket is connected, as if it's not connected,
+ # it's not writable, and the polling in send_tcp() will time out or
+ # hang forever.
+ await sock.getpeername()
+ s = sock
+ else:
+ # These are simple (address, port) pairs, not
+ # family-dependent tuples you pass to lowlevel socket
+ # code.
+ af = dns.inet.af_for_address(where)
+ stuple = _source_tuple(af, source, source_port)
+ dtuple = (where, port)
+ s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple,
+ dtuple, timeout)
+ await send_tcp(s, wire, expiration)
+ (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset,
+ q.keyring, q.mac,
+ ignore_trailing)
+ r.time = received_time - begin_time
+ if not q.is_response(r):
+ raise BadResponse
+ return r
+ finally:
+ if not sock and s:
+ await s.close()
diff --git a/dns/trio/resolver.py b/dns/asyncresolver.py
index 07e70f9..b45a35b 100644
--- a/dns/trio/resolver.py
+++ b/dns/asyncresolver.py
@@ -15,31 +15,32 @@
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-"""trio async I/O library DNS stub resolver."""
+"""Asynchronous DNS stub resolver."""
-import trio
+import time
+import dns.asyncbackend
+import dns.asyncquery
import dns.exception
import dns.query
import dns.resolver
-import dns.trio.query
# import some resolver symbols for brevity
from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
-# we do this for indentation reasons below
-_udp = dns.trio.query.udp
-_stream = dns.trio.query.stream
-class TooManyAttempts(dns.exception.DNSException):
- """A resolution had too many unsuccessful attempts."""
+# for identation purposes below
+_udp = dns.asyncquery.udp
+_tcp = dns.asyncquery.tcp
+
class Resolver(dns.resolver.Resolver):
async def resolve(self, qname, rdtype=dns.rdatatype.A,
rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, search=None):
+ source_port=0, lifetime=None, search=None,
+ backend=None):
"""Query nameservers asynchronously to find the answer to the question.
The *qname*, *rdtype*, and *rdclass* parameters may be objects
@@ -62,6 +63,9 @@ class Resolver(dns.resolver.Resolver):
*source_port*, an ``int``, the port from which to send the message.
+ *lifetime*, a ``float``, how many seconds a query should run
+ before timing out.
+
*search*, a ``bool`` or ``None``, determines whether the
search list configured in the system's resolver configuration
are used for relative names, and whether the resolver's domain
@@ -69,6 +73,9 @@ class Resolver(dns.resolver.Resolver):
which causes the value of the resolver's
``use_search_by_default`` attribute to be used.
+ *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``,
+ the default, then dnspython will use the default backend.
+
Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist.
Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after
@@ -87,6 +94,9 @@ class Resolver(dns.resolver.Resolver):
resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
raise_on_no_answer, search)
+ if not backend:
+ backend = dns.asyncbackend.get_default_backend()
+ start = time.time()
while True:
(request, answer) = resolution.next_request()
# Note we need to say "if answer is not None" and not just
@@ -101,30 +111,24 @@ class Resolver(dns.resolver.Resolver):
while not done:
(nameserver, port, tcp, backoff) = resolution.next_nameserver()
if backoff:
- loops += 1
- if loops >= 5:
- raise TooManyAttempts
- await trio.sleep(backoff)
+ await backend.sleep(backoff)
+ timeout = self._compute_timeout(start, lifetime)
try:
- with trio.fail_after(self.timeout):
- if dns.inet.is_address(nameserver):
- if tcp:
- response = await \
- _stream(request, nameserver,
- port=port,
- source=source,
- source_port=source_port)
- else:
- response = await \
- _udp(request,
- nameserver,
- port=port,
- source=source,
- source_port=source_port,
- raise_on_truncation=True)
+ if dns.inet.is_address(nameserver):
+ if tcp:
+ response = await _tcp(request, nameserver,
+ timeout, port,
+ source, source_port,
+ backend=backend)
else:
- # We don't do DoH yet.
- raise NotImplementedError
+ response = await _udp(request, nameserver,
+ timeout, port,
+ source, source_port,
+ raise_on_truncation=True,
+ backend=backend)
+ else:
+ # We don't do DoH yet.
+ raise NotImplementedError
except Exception as ex:
(_, done) = resolution.query_result(None, ex)
continue
@@ -191,7 +195,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
This is a convenience function that uses the default resolver
object to make the query.
- See ``dns.trio.resolver.Resolver.resolve`` for more information on the
+ See ``dns.asyncresolver.Resolver.resolve`` for more information on the
parameters.
"""
@@ -203,7 +207,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
async def resolve_address(ipaddr, *args, **kwargs):
"""Use a resolver to run a reverse query for PTR records.
- See ``dns.trio.resolver.Resolver.resolve_address`` for more
+ See ``dns.asyncresolver.Resolver.resolve_address`` for more
information on the parameters.
"""
@@ -220,7 +224,7 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False,
*tcp*, a ``bool``. If ``True``, use TCP to make the query.
- *resolver*, a ``dns.trio.resolver.Resolver`` or ``None``, the
+ *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the
resolver to use. If ``None``, the default resolver is used.
Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS
diff --git a/dns/trio/__init__.py b/dns/trio/__init__.py
deleted file mode 100644
index 744f880..0000000
--- a/dns/trio/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-"""trio async I/O library helpers"""
-
-__all__ = [
- 'query',
- 'resolver',
-]
diff --git a/dns/trio/query.py b/dns/trio/query.py
deleted file mode 100644
index a3a28fe..0000000
--- a/dns/trio/query.py
+++ /dev/null
@@ -1,374 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-"""trio async I/O library query support"""
-
-import contextlib
-import socket
-import struct
-import time
-import trio
-import trio.socket # type: ignore
-
-import dns.exception
-import dns.inet
-import dns.name
-import dns.message
-import dns.query
-import dns.rcode
-import dns.rdataclass
-import dns.rdatatype
-
-# import query symbols for compatibility and brevity
-from dns.query import ssl, UnexpectedSource, BadResponse
-
-# Function used to create a socket. Can be overridden if needed in special
-# situations.
-socket_factory = trio.socket.socket
-
-async def send_udp(sock, what, destination):
- """Asynchronously send a DNS message to the specified UDP socket.
-
- *sock*, a ``trio.socket.socket``.
-
- *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
-
- *destination*, a destination tuple appropriate for the address family
- of the socket, specifying where to send the query.
-
- Returns an ``(int, float)`` tuple of bytes sent and the sent time.
- """
-
- if isinstance(what, dns.message.Message):
- what = what.to_wire()
- sent_time = time.time()
- n = await sock.sendto(what, destination)
- return (n, sent_time)
-
-
-async def receive_udp(sock, destination, ignore_unexpected=False,
- one_rr_per_rrset=False, keyring=None, request_mac=b'',
- ignore_trailing=False, raise_on_truncation=False):
- """Asynchronously read a DNS message from a UDP socket.
-
- *sock*, a ``trio.socket.socket``.
-
- *destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
-
- *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
- unexpected sources.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *keyring*, a ``dict``, the keyring to use for TSIG.
-
- *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
- the TC bit is set.
-
- Raises if the message is malformed, if network errors occur, of if
- there is a timeout.
-
- Returns a ``dns.message.Message`` object.
- """
-
- wire = b''
- while True:
- (wire, from_address) = await sock.recvfrom(65535)
- if dns.query._addresses_equal(sock.family, from_address,
- destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
- break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
- received_time = time.time()
- r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
- one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing,
- raise_on_truncation=raise_on_truncation)
- return (r, received_time)
-
-async def udp(q, where, port=53, source=None, source_port=0,
- ignore_unexpected=False, one_rr_per_rrset=False,
- ignore_trailing=False, raise_on_truncation=False,
- sock=None):
- """Asynchronously return the response obtained after sending a query
- via UDP.
-
- *q*, a ``dns.message.Message``, the query to send
-
- *where*, a ``str`` containing an IPv4 or IPv6 address, where
- to send the message.
-
- *port*, an ``int``, the port send the message to. The default is 53.
-
- *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
- unexpected sources.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
- the TC bit is set.
-
- *sock*, a ``trio.socket.socket``, or ``None``, the socket to use
- for the query. If ``None``, the default, a socket is created. if
- a socket is provided, the *source* and *source_port* are ignored.
-
- Returns a ``dns.message.Message``.
-
- """
-
- wire = q.to_wire()
- (af, destination, source) = \
- dns.query._destination_and_source(None, where, port, source,
- source_port)
- # We can use an ExitStack here as exiting a trio.socket.socket does
- # not await.
- with contextlib.ExitStack() as stack:
- if sock:
- s = sock
- else:
- s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0))
- if source is not None:
- await s.bind(source)
- (_, sent_time) = await send_udp(s, wire, destination)
- (r, received_time) = await receive_udp(s, destination,
- ignore_unexpected,
- one_rr_per_rrset, q.keyring,
- q.mac, ignore_trailing,
- raise_on_truncation)
- if not q.is_response(r):
- raise BadResponse
- r.time = received_time - sent_time
- return r
-
-async def udp_with_fallback(q, where, timeout=None, port=53, source=None,
- source_port=0, ignore_unexpected=False,
- one_rr_per_rrset=False, ignore_trailing=False):
- """Return the response to the query, trying UDP first and falling back
- to TCP if UDP results in a truncated response.
-
- *q*, a ``dns.message.Message``, the query to send
-
- *where*, a ``str`` containing an IPv4 or IPv6 address, where
- to send the message.
-
- *port*, an ``int``, the port send the message to. The default is 53.
-
- *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from
- unexpected sources.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True``
- if and only if TCP was used.
- """
- try:
- response = await udp(q, where, port, source, source_port,
- ignore_unexpected, one_rr_per_rrset,
- ignore_trailing, True)
- return (response, False)
- except dns.message.Truncated:
- response = await stream(q, where, False, port, source, source_port,
- one_rr_per_rrset, ignore_trailing)
-
- return (response, True)
-
-# pylint: disable=redefined-outer-name
-
-async def send_stream(stream, what):
- """Asynchronously send a DNS message to the specified stream.
-
- *stream*, a ``trio.abc.Stream``.
-
- *what*, a ``bytes`` or ``dns.message.Message``, the message to send.
-
- Returns an ``(int, float)`` tuple of bytes sent and the sent time.
- """
-
- if isinstance(what, dns.message.Message):
- what = what.to_wire()
- l = len(what)
- # copying the wire into tcpmsg is inefficient, but lets us
- # avoid writev() or doing a short write that would get pushed
- # onto the net
- stream_message = struct.pack("!H", l) + what
- sent_time = time.time()
- await stream.send_all(stream_message)
- return (len(stream_message), sent_time)
-
-async def read_exactly(stream, count):
- """Read the specified number of bytes from stream. Keep trying until we
- either get the desired amount, or we hit EOF.
- """
- s = b''
- while count > 0:
- n = await stream.receive_some(count)
- if n == b'':
- raise EOFError
- count = count - len(n)
- s = s + n
- return s
-
-async def receive_stream(stream, one_rr_per_rrset=False, keyring=None,
- request_mac=b'', ignore_trailing=False):
- """Read a DNS message from a stream.
-
- *stream*, a ``trio.abc.Stream``.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *keyring*, a ``dict``, the keyring to use for TSIG.
-
- *request_mac*, a ``bytes``, the MAC of the request (for TSIG).
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- Raises if the message is malformed, if network errors occur, of if
- there is a timeout.
-
- Returns a ``dns.message.Message`` object.
- """
-
- ldata = await read_exactly(stream, 2)
- (l,) = struct.unpack("!H", ldata)
- wire = await read_exactly(stream, l)
- received_time = time.time()
- r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
- one_rr_per_rrset=one_rr_per_rrset,
- ignore_trailing=ignore_trailing)
- return (r, received_time)
-
-async def stream(q, where, tls=False, port=None, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False,
- stream=None, ssl_context=None, server_hostname=None):
- """Return the response obtained after sending a query using TCP or TLS.
-
- *q*, a ``dns.message.Message``, the query to send.
-
- *where*, a ``str`` containing an IPv4 or IPv6 address, where
- to send the message.
-
- *tls*, a ``bool``. If ``False``, the default, the query will be
- sent using TCP and *port* will default to 53. If ``True``, the
- query is sent using TLS, and *port* will default to 853.
-
- *port*, an ``int``, the port send the message to. The default is as
- specified in the description for *tls*.
-
- *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
- the source address. The default is the wildcard address.
-
- *source_port*, an ``int``, the port from which to send the message.
- The default is 0.
-
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
- RRset.
-
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the received message.
-
- *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for
- the query. If ``None``, the default, a stream is created. if a
- socket is provided, it must be connected, and the *where*, *port*,
- *tls*, *source*, *source_port*, *ssl_context*, and
- *server_hostname* parameters are ignored.
-
- *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
- a TLS connection. If ``None``, the default, creates one with the default
- configuration. If this value is not ``None``, then the *tls* parameter
- is treated as if it were ``True`` regardless of its value.
-
- *server_hostname*, a ``str`` containing the server's hostname. The
- default is ``None``, which means that no hostname is known, and if an
- SSL context is created, hostname checking will be disabled.
-
- Returns a ``dns.message.Message``.
-
- """
- if ssl_context is not None:
- tls = True
- if port is None:
- if tls:
- port = 853
- else:
- port = 53
- wire = q.to_wire()
- # We'd like to be able to use an AsyncExitStack here, because
- # unlike closing a socket, closing a stream requires an await, but
- # that's a 3.7 feature, so we are forced to try ... finally.
- sock = None
- s = None
- begin_time = time.time()
- try:
- if stream:
- #
- # Verify that the socket is connected, as if it's not connected,
- # it's not writable, and the polling in send_tcp() will time out or
- # hang forever.
- if isinstance(stream, trio.SSLStream):
- tsock = stream.transport_stream.socket
- else:
- tsock = stream.socket
- tsock.getpeername()
- s = stream
- else:
- (af, destination, source) = \
- dns.query._destination_and_source(None, where, port, source,
- source_port)
- sock = socket_factory(af, socket.SOCK_STREAM, 0)
- if source is not None:
- await sock.bind(source)
- await sock.connect(destination)
- s = trio.SocketStream(sock)
- sock = None
- if tls and ssl_context is None:
- ssl_context = ssl.create_default_context()
- if server_hostname is None:
- ssl_context.check_hostname = False
- if ssl_context:
- s = trio.SSLStream(s, ssl_context,
- server_hostname=server_hostname)
- await send_stream(s, wire)
- (r, received_time) = await receive_stream(s, one_rr_per_rrset,
- q.keyring, q.mac,
- ignore_trailing)
- if not q.is_response(r):
- raise BadResponse
- r.time = received_time - begin_time
- return r
- finally:
- if sock:
- sock.close()
- if s and s != stream:
- await s.aclose()
diff --git a/dns/trio/query.pyi b/dns/trio/query.pyi
deleted file mode 100644
index 0a5ab92..0000000
--- a/dns/trio/query.pyi
+++ /dev/null
@@ -1,33 +0,0 @@
-from typing import Optional, Dict, Any
-from . import rdatatype, rdataclass, name, message
-
-# If the ssl import works, then
-#
-# error: Name 'ssl' already defined (by an import)
-#
-# is expected and can be ignored.
-try:
- import ssl
-except ImportError:
- class ssl: # type: ignore
- SSLContext : Dict = {}
-
-import trio
-
-def udp(q : message.Message, where : str, port=53,
- source : Optional[str] = None, source_port : Optional[int] = 0,
- ignore_unexpected : Optional[bool] = False,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- sock : Optional[trio.socket.socket] = None) -> message.Message:
- ...
-
-def stream(q : message.Message, where : str, tls : Optional[bool] = False,
- port=53, source : Optional[str] = None,
- source_port : Optional[int] = 0,
- one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False,
- stream : Optional[trio.abc.Stream] = None,
- ssl_context: Optional[ssl.SSLContext] = None,
- server_hostname: Optional[str] = None) -> message.Message:
- ...
diff --git a/dns/trio/resolver.pyi b/dns/trio/resolver.pyi
deleted file mode 100644
index d84419b..0000000
--- a/dns/trio/resolver.pyi
+++ /dev/null
@@ -1,26 +0,0 @@
-from typing import Union, Optional, List, Any, Dict
-from .. import exception, rdataclass, name, rdatatype
-
-def resolve(qname : str, rdtype : Union[int,str] = 0,
- rdclass : Union[int,str] = 0,
- tcp=False, source=None, raise_on_no_answer=True,
- source_port=0, search : Optional[bool]=None):
- ...
-
-def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Optional[Dict]):
- ...
-
-def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False,
- resolver : Optional[Resolver] = None):
- ...
-
-class Resolver:
- def __init__(self, filename : Optional[str] = '/etc/resolv.conf',
- configure : Optional[bool] = True):
- self.nameservers : List[str]
- def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A,
- rdclass : Union[int,str] = rdataclass.IN,
- tcp : bool = False, source : Optional[str] = None,
- raise_on_no_answer=True, source_port : int = 0,
- search : Optional[bool]=None):
- ...
diff --git a/pyproject.toml b/pyproject.toml
index 44a1dd7..b33fe00 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,7 @@ requests-toolbelt = {version="^0.9.1", optional=true}
requests = {version="^2.23.0", optional=true}
idna = {version="^2.1", optional=true}
cryptography = {version="^2.6", optional=true}
-trio = {version="^0.14.0", optional=true}
+trio = {version="^0.14", optional=true}
[tool.poetry.dev-dependencies]
mypy = "^0.770"
@@ -28,6 +28,7 @@ doh = ['requests', 'requests-toolbelt']
idna = ['idna']
dnssec = ['cryptography']
trio = ['trio']
+curio = ['curio', 'sniffio']
[build-system]
requires = ["poetry>=0.12"]
diff --git a/setup.py b/setup.py
index 50a3da1..5dfc61e 100755
--- a/setup.py
+++ b/setup.py
@@ -50,7 +50,7 @@ direct manipulation of DNS zones, messages, names, and records.""",
'license' : 'ISC',
'url' : 'http://www.dnspython.org',
'packages' : ['dns', 'dns.rdtypes', 'dns.rdtypes.IN', 'dns.rdtypes.ANY',
- 'dns.rdtypes.CH', 'dns.trio'],
+ 'dns.rdtypes.CH'],
'package_data' : {'dns': ['py.typed']},
'download_url' : \
'http://www.dnspython.org/kits/{}/dnspython-{}.tar.gz'.format(version, version),
diff --git a/tests/test_async.py b/tests/test_async.py
new file mode 100644
index 0000000..c09941b
--- /dev/null
+++ b/tests/test_async.py
@@ -0,0 +1,219 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import asyncio
+import socket
+import unittest
+
+import dns.asyncbackend
+import dns.asyncquery
+import dns.asyncresolver
+import dns.message
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.resolver
+
+# Some tests require the internet to be available to run, so let's
+# skip those if it's not there.
+_network_available = True
+try:
+ socket.gethostbyname('dnspython.org')
+except socket.gaierror:
+ _network_available = False
+
+@unittest.skipIf(not _network_available, "Internet not reachable")
+class AsyncTests(unittest.TestCase):
+
+ def setUp(self):
+ self.backend = dns.asyncbackend.set_default_backend('asyncio')
+
+ def async_run(self, afunc):
+ return asyncio.run(afunc())
+
+ def testResolve(self):
+ async def run():
+ answer = await dns.asyncresolver.resolve('dns.google.', 'A')
+ return set([rdata.address for rdata in answer])
+ seen = self.async_run(run)
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testResolveAddress(self):
+ async def run():
+ return await dns.asyncresolver.resolve_address('8.8.8.8')
+ answer = self.async_run(run)
+ dnsgoogle = dns.name.from_text('dns.google.')
+ self.assertEqual(answer[0].target, dnsgoogle)
+
+ def testZoneForName1(self):
+ async def run():
+ name = dns.name.from_text('www.dnspython.org.')
+ return await dns.asyncresolver.zone_for_name(name)
+ ezname = dns.name.from_text('dnspython.org.')
+ zname = self.async_run(run)
+ self.assertEqual(zname, ezname)
+
+ def testZoneForName2(self):
+ async def run():
+ name = dns.name.from_text('a.b.www.dnspython.org.')
+ return await dns.asyncresolver.zone_for_name(name)
+ ezname = dns.name.from_text('dnspython.org.')
+ zname = self.async_run(run)
+ self.assertEqual(zname, ezname)
+
+ def testZoneForName3(self):
+ async def run():
+ name = dns.name.from_text('dnspython.org.')
+ return await dns.asyncresolver.zone_for_name(name)
+ ezname = dns.name.from_text('dnspython.org.')
+ zname = self.async_run(run)
+ self.assertEqual(zname, ezname)
+
+ def testZoneForName4(self):
+ def bad():
+ name = dns.name.from_text('dnspython.org', None)
+ async def run():
+ return await dns.asyncresolver.zone_for_name(name)
+ self.async_run(run)
+ self.assertRaises(dns.resolver.NotAbsolute, bad)
+
+ def testQueryUDP(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.asyncquery.udp(q, '8.8.8.8')
+ response = self.async_run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryUDPWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ async with await self.backend.make_socket(socket.AF_INET,
+ socket.SOCK_DGRAM) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.asyncquery.udp(q, '8.8.8.8', sock=s)
+ response = self.async_run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryTCP(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.asyncquery.tcp(q, '8.8.8.8')
+ response = self.async_run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryTCPWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ async with await self.backend.make_socket(socket.AF_INET,
+ socket.SOCK_STREAM, 0,
+ None,
+ ('8.8.8.8', 53)) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.asyncquery.tcp(q, '8.8.8.8', sock=s)
+ response = self.async_run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
+ # def testQueryTLS(self):
+ # qname = dns.name.from_text('dns.google.')
+ # async def run():
+ # q = dns.message.make_query(qname, dns.rdatatype.A)
+ # return await dns.asyncquery.stream(q, '8.8.8.8', True)
+ # response = self.async_run(run)
+ # rrs = response.get_rrset(response.answer, qname,
+ # dns.rdataclass.IN, dns.rdatatype.A)
+ # self.assertTrue(rrs is not None)
+ # seen = set([rdata.address for rdata in rrs])
+ # self.assertTrue('8.8.8.8' in seen)
+ # self.assertTrue('8.8.4.4' in seen)
+
+ # def testQueryTLSWithSocket(self):
+ # qname = dns.name.from_text('dns.google.')
+ # async def run():
+ # async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
+ # 853) as s:
+ # q = dns.message.make_query(qname, dns.rdatatype.A)
+ # return await dns.asyncquery.stream(q, '8.8.8.8', stream=s)
+ # response = self.async_run(run)
+ # rrs = response.get_rrset(response.answer, qname,
+ # dns.rdataclass.IN, dns.rdatatype.A)
+ # self.assertTrue(rrs is not None)
+ # seen = set([rdata.address for rdata in rrs])
+ # self.assertTrue('8.8.8.8' in seen)
+ # self.assertTrue('8.8.4.4' in seen)
+
+ def testQueryUDPFallback(self):
+ qname = dns.name.from_text('.')
+ async def run():
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ return await dns.asyncquery.udp_with_fallback(q, '8.8.8.8')
+ (_, tcp) = self.async_run(run)
+ self.assertTrue(tcp)
+
+ def testQueryUDPFallbackNoFallback(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.asyncquery.udp_with_fallback(q, '8.8.8.8')
+ (_, tcp) = self.async_run(run)
+ self.assertFalse(tcp)
+
+try:
+ import trio
+
+ class TrioAsyncTests(AsyncTests):
+ def setUp(self):
+ self.backend = dns.asyncbackend.set_default_backend('trio')
+
+ def async_run(self, afunc):
+ return trio.run(afunc)
+except ImportError:
+ pass
+
+try:
+ import curio
+
+ class CurioAsyncTests(AsyncTests):
+ def setUp(self):
+ self.backend = dns.asyncbackend.set_default_backend('curio')
+
+ def async_run(self, afunc):
+ return curio.run(afunc)
+except ImportError:
+ pass
diff --git a/tests/test_trio.py b/tests/test_trio.py
deleted file mode 100644
index 8304a1f..0000000
--- a/tests/test_trio.py
+++ /dev/null
@@ -1,189 +0,0 @@
-# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
-
-# Copyright (C) 2003-2017 Nominum, Inc.
-#
-# Permission to use, copy, modify, and distribute this software and its
-# documentation for any purpose with or without fee is hereby granted,
-# provided that the above copyright notice and this permission notice
-# appear in all copies.
-#
-# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
-# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
-# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
-# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
-
-import socket
-import unittest
-
-try:
- import trio
- import trio.socket
-
- import dns.message
- import dns.name
- import dns.rdataclass
- import dns.rdatatype
- import dns.trio.query
- import dns.trio.resolver
-
- # Some tests require the internet to be available to run, so let's
- # skip those if it's not there.
- _network_available = True
- try:
- socket.gethostbyname('dnspython.org')
- except socket.gaierror:
- _network_available = False
-
- @unittest.skipIf(not _network_available, "Internet not reachable")
- class TrioTests(unittest.TestCase):
-
- def testResolve(self):
- async def run():
- answer = await dns.trio.resolver.resolve('dns.google.', 'A')
- return set([rdata.address for rdata in answer])
- seen = trio.run(run)
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testResolveAddress(self):
- async def run():
- return await dns.trio.resolver.resolve_address('8.8.8.8')
- answer = trio.run(run)
- dnsgoogle = dns.name.from_text('dns.google.')
- self.assertEqual(answer[0].target, dnsgoogle)
-
- def testZoneForName1(self):
- async def run():
- name = dns.name.from_text('www.dnspython.org.')
- return await dns.trio.resolver.zone_for_name(name)
- ezname = dns.name.from_text('dnspython.org.')
- zname = trio.run(run)
- self.assertEqual(zname, ezname)
-
- def testZoneForName2(self):
- async def run():
- name = dns.name.from_text('a.b.www.dnspython.org.')
- return await dns.trio.resolver.zone_for_name(name)
- ezname = dns.name.from_text('dnspython.org.')
- zname = trio.run(run)
- self.assertEqual(zname, ezname)
-
- def testZoneForName3(self):
- async def run():
- name = dns.name.from_text('dnspython.org.')
- return await dns.trio.resolver.zone_for_name(name)
- ezname = dns.name.from_text('dnspython.org.')
- zname = trio.run(run)
- self.assertEqual(zname, ezname)
-
- def testZoneForName4(self):
- def bad():
- name = dns.name.from_text('dnspython.org', None)
- async def run():
- return await dns.trio.resolver.zone_for_name(name)
- trio.run(run)
- self.assertRaises(dns.resolver.NotAbsolute, bad)
-
- def testQueryUDP(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.udp(q, '8.8.8.8')
- response = trio.run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
- self.assertTrue(rrs is not None)
- seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testQueryUDPWithSocket(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.udp(q, '8.8.8.8', sock=s)
- response = trio.run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
- self.assertTrue(rrs is not None)
- seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testQueryTCP(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.stream(q, '8.8.8.8')
- response = trio.run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
- self.assertTrue(rrs is not None)
- seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testQueryTCPWithSocket(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- async with await trio.open_tcp_stream('8.8.8.8', 53) as s:
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
- response = trio.run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
- self.assertTrue(rrs is not None)
- seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testQueryTLS(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.stream(q, '8.8.8.8', True)
- response = trio.run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
- self.assertTrue(rrs is not None)
- seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testQueryTLSWithSocket(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
- 853) as s:
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
- response = trio.run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
- self.assertTrue(rrs is not None)
- seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
-
- def testQueryUDPFallback(self):
- qname = dns.name.from_text('.')
- async def run():
- q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
- return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
- (_, tcp) = trio.run(run)
- self.assertTrue(tcp)
-
- def testQueryUDPFallbackNoFallback(self):
- qname = dns.name.from_text('dns.google.')
- async def run():
- q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.trio.query.udp_with_fallback(q, '8.8.8.8')
- (_, tcp) = trio.run(run)
- self.assertFalse(tcp)
-
-except ModuleNotFoundError:
- pass