diff options
author | Bob Halley <halley@dnspython.org> | 2020-06-11 18:50:30 -0700 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2020-06-11 18:50:30 -0700 |
commit | 98b344d625190c3b536682775fb0e0340e715063 (patch) | |
tree | 5882b861b4579633467106a1571503a586037be3 | |
parent | fa1d4938a8296b4ae88f7bbc5dfa67377995b9e2 (diff) | |
download | dnspython-98b344d625190c3b536682775fb0e0340e715063.tar.gz |
Support trio, curio, and asyncio with one API!
-rw-r--r-- | dns/__init__.py | 2 | ||||
-rw-r--r-- | dns/_asyncbackend.py | 77 | ||||
-rw-r--r-- | dns/_asyncio_backend.py | 118 | ||||
-rw-r--r-- | dns/_curio_backend.py | 92 | ||||
-rw-r--r-- | dns/_trio_backend.py | 92 | ||||
-rw-r--r-- | dns/asyncbackend.py | 43 | ||||
-rw-r--r-- | dns/asyncquery.py | 422 | ||||
-rw-r--r-- | dns/asyncresolver.py (renamed from dns/trio/resolver.py) | 72 | ||||
-rw-r--r-- | dns/trio/__init__.py | 8 | ||||
-rw-r--r-- | dns/trio/query.py | 374 | ||||
-rw-r--r-- | dns/trio/query.pyi | 33 | ||||
-rw-r--r-- | dns/trio/resolver.pyi | 26 | ||||
-rw-r--r-- | pyproject.toml | 3 | ||||
-rwxr-xr-x | setup.py | 2 | ||||
-rw-r--r-- | tests/test_async.py | 219 | ||||
-rw-r--r-- | tests/test_trio.py | 189 |
16 files changed, 1106 insertions, 666 deletions
diff --git a/dns/__init__.py b/dns/__init__.py index d5cadb8..6412fb5 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -18,6 +18,8 @@ """dnspython DNS toolkit""" __all__ = [ + 'asyncquery.py', + 'asyncresolver.py', 'dnssec', 'e164', 'edns', diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py new file mode 100644 index 0000000..9bfdaba --- /dev/null +++ b/dns/_asyncbackend.py @@ -0,0 +1,77 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + + +import dns.inet + + +# This is a nullcontext for both sync and async + +class NullContext: + def __init__(self, enter_result=None): + self.enter_result = enter_result + + def __enter__(self): + return self.enter_result + + def __exit__(self, exc_type, exc_value, traceback): + pass + + async def __aenter__(self): + return self.enter_result + + async def __aexit__(self, exc_type, exc_value, traceback): + pass + + +# This is handy, but should probably move somewhere else! + +def low_level_address_tuple(af, high_level_address_tuple): + address, port = high_level_address_tuple + if af == dns.inet.AF_INET: + return (address, port) + elif af == dns.inet.AF_INET6: + ai_flags = socket.AI_NUMERICHOST + ((*_, tup), *_) = socket.getaddrinfo(address, port, flags=ai_flags) + return tup + else: + raise NotImplementedError(f'unknown address family {af}') + + +# These are declared here so backends can import them without creating +# circular dependencies with dns.asyncbackend. + +class Socket: + async def close(self): + pass + + async def __aenter__(self): + pass + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + +class DatagramSocket(Socket): + async def sendto(self, what, destination, timeout): + pass + + async def recvfrom(self, size, timeout): + pass + + +class StreamSocket(Socket): + async def sendall(self, what, destination, timeout): + pass + + async def recv(self, size, timeout): + pass + + +class Backend: + def name(self): + return 'unknown' + + async def make_socket(self, af, socktype, proto=0, + source=None, raw_source=None, + ssl_context=None, server_hostname=None): + raise NotImplementedError diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py new file mode 100644 index 0000000..42c6e66 --- /dev/null +++ b/dns/_asyncio_backend.py @@ -0,0 +1,118 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""asyncio library query support""" + +import socket +import asyncio + +import dns._asyncbackend +import dns.exception + +class _DatagramProtocol: + def __init__(self): + self.transport = None + self.recvfrom = None + + def connection_made(self, transport): + self.transport = transport + + def datagram_received(self, data, addr): + if self.recvfrom: + self.recvfrom.set_result((data, addr)) + self.recvfrom = None + + def error_received(self, exc): + if self.recvfrom: + self.recvfrom.set_exception(exc) + + def connection_lost(self, exc): + if self.recvfrom: + self.recvfrom.set_exception(exc) + + def close(self): + self.transport.close() + + +async def _maybe_wait_for(awaitable, timeout): + if timeout: + try: + return await asyncio.wait_for(awaitable, timeout) + except asyncio.TimeoutError: + raise dns.exception.Timeout(timeout=timeout) + else: + return await awaitable + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, family, transport, protocol): + self.family = family + self.transport = transport + self.protocol = protocol + + async def sendto(self, what, destination, timeout): + # no timeout for asyncio sendto + self.transport.sendto(what, destination) + + async def recvfrom(self, timeout): + done = asyncio.get_running_loop().create_future() + assert self.protocol.recvfrom is None + self.protocol.recvfrom = done + await _maybe_wait_for(done, timeout) + return done.result() + + async def close(self): + self.protocol.close() + + async def getpeername(self): + return self.transport.get_extra_info('peername') + + +class StreamSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, af, reader, writer): + self.family = af + self.reader = reader + self.writer = writer + + async def sendall(self, what, timeout): + self.writer.write(what), + return await _maybe_wait_for(self.writer.drain(), timeout) + raise dns.exception.Timeout(timeout=timeout) + + async def recv(self, count, timeout): + return await _maybe_wait_for(self.reader.read(count), + timeout) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + self.writer.close() + await self.writer.wait_closed() + + async def getpeername(self): + return self.reader.get_extra_info('peername') + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return 'asyncio' + + async def make_socket(self, af, socktype, proto=0, + source=None, destination=None, timeout=None, + ssl_context=None, server_hostname=None): + loop = asyncio.get_running_loop() + if socktype == socket.SOCK_DGRAM: + transport, protocol = await loop.create_datagram_endpoint( + _DatagramProtocol, source, family=af, + proto=proto) + return DatagramSocket(af, transport, protocol) + elif socktype == socket.SOCK_STREAM: + (r, w) = await _maybe_wait_for( + asyncio.open_connection(destination[0], + destination[1], + family=af, + proto=proto, + local_addr=source), + timeout) + return StreamSocket(af, r, w) + raise NotImplementedError(f'unsupported socket type {socktype}') + + async def sleep(self, interval): + await asyncio.sleep(interval) diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py new file mode 100644 index 0000000..e37fea3 --- /dev/null +++ b/dns/_curio_backend.py @@ -0,0 +1,92 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""curio async I/O library query support""" + +import socket +import curio +import curio.socket # type: ignore + +import dns._asyncbackend +import dns.exception + + +def _maybe_timeout(timeout): + if timeout: + return curio.ignore_after(timeout) + else: + return dns._asyncbackend.NullContext() + + +# for brevity +_lltuple = dns._asyncbackend.low_level_address_tuple + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + self.socket = socket + self.family = socket.family + + async def sendto(self, what, destination, timeout): + async with _maybe_timeout(timeout): + return await self.socket.sendto(what, destination) + raise dns.exception.Timeout(timeout=timeout) + + async def recvfrom(self, timeout): + async with _maybe_timeout(timeout): + return await self.socket.recvfrom(65535) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + await self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + +class StreamSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + self.socket = socket + self.family = socket.family + + async def sendall(self, what, timeout): + async with _maybe_timeout(timeout): + return await self.socket.sendall(what) + raise dns.exception.Timeout(timeout=timeout) + + async def recv(self, count, timeout): + async with _maybe_timeout(timeout): + return await self.socket.recv(count) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + await self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return 'curio' + + async def make_socket(self, af, socktype, proto=0, + source=None, destination=None, timeout=None, + ssl_context=None, server_hostname=None): + s = curio.socket.socket(af, socktype, proto) + try: + if source: + s.bind(_lltuple(af, source)) + if socktype == socket.SOCK_STREAM: + with _maybe_timeout(timeout): + await s.connect(_lltuple(af, destination)) + except Exception: + await s.close() + raise + if socktype == socket.SOCK_DGRAM: + return DatagramSocket(s) + elif socktype == socket.SOCK_STREAM: + return StreamSocket(s) + raise NotImplementedError(f'unsupported socket type {socktype}') + + async def sleep(self, interval): + await curio.sleep(interval) diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py new file mode 100644 index 0000000..bcaddcc --- /dev/null +++ b/dns/_trio_backend.py @@ -0,0 +1,92 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""trio async I/O library query support""" + +import socket +import trio +import trio.socket # type: ignore + +import dns._asyncbackend +import dns.exception + + +def _maybe_timeout(timeout): + if timeout: + return trio.move_on_after(timeout) + else: + return dns._asyncbackend.NullContext() + + +# for brevity +_lltuple = dns._asyncbackend.low_level_address_tuple + + +class DatagramSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, socket): + self.socket = socket + self.family = socket.family + + async def sendto(self, what, destination, timeout): + with _maybe_timeout(timeout): + return await self.socket.sendto(what, destination) + raise dns.exception.Timeout(timeout=timeout) + + async def recvfrom(self, timeout): + with _maybe_timeout(timeout): + return await self.socket.recvfrom(65535) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + self.socket.close() + + async def getpeername(self): + return self.socket.getpeername() + + +class StreamSocket(dns._asyncbackend.DatagramSocket): + def __init__(self, family, stream): + self.family = family + self.stream = stream + + async def sendall(self, what, timeout): + with _maybe_timeout(timeout): + return await self.stream.send_all(what) + raise dns.exception.Timeout(timeout=timeout) + + async def recv(self, count, timeout): + with _maybe_timeout(timeout): + return await self.stream.receive_some(count) + raise dns.exception.Timeout(timeout=timeout) + + async def close(self): + await self.stream.aclose() + + async def getpeername(self): + return self.stream.socket.getpeername() + + +class Backend(dns._asyncbackend.Backend): + def name(self): + return 'trio' + + async def make_socket(self, af, socktype, proto=0, source=None, + destination=None, timeout=None, + ssl_context=None, server_hostname=None): + s = trio.socket.socket(af, socktype, proto) + try: + if source: + await s.bind(_lltuple(af, source)) + if socktype == socket.SOCK_STREAM: + with _maybe_timeout(timeout): + await s.connect(_lltuple(af, destination)) + except Exception: + s.close() + raise + if socktype == socket.SOCK_DGRAM: + return DatagramSocket(s) + elif socktype == socket.SOCK_STREAM: + return StreamSocket(af, trio.SocketStream(s)) + raise NotImplementedError(f'unsupported socket type {socktype}') + + async def sleep(self, interval): + await trio.sleep(interval) diff --git a/dns/asyncbackend.py b/dns/asyncbackend.py new file mode 100644 index 0000000..92a1ae3 --- /dev/null +++ b/dns/asyncbackend.py @@ -0,0 +1,43 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + + +from dns._asyncbackend import Socket, DatagramSocket, \ + StreamSocket, Backend, low_level_address_tuple + + +_default_backend = None + + +def get_default_backend(): + if _default_backend: + return _default_backend + + return set_default_backend(sniff()) + + +def sniff(): + name = 'asyncio' + try: + import sniffio + name = sniffio.current_async_library() + except Exception: + pass + return name + + +def set_default_backend(name): + global _default_backend + + if name == 'trio': + import dns._trio_backend + _default_backend = dns._trio_backend.Backend() + elif name == 'curio': + import dns._curio_backend + _default_backend = dns._curio_backend.Backend() + elif name == 'asyncio': + import dns._asyncio_backend + _default_backend = dns._asyncio_backend.Backend() + else: + raise NotImplementedException(f'unimplemented async backend {name}') + + return _default_backend diff --git a/dns/asyncquery.py b/dns/asyncquery.py new file mode 100644 index 0000000..ed51fdc --- /dev/null +++ b/dns/asyncquery.py @@ -0,0 +1,422 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""Talk to a DNS server.""" + +import os +import socket +import struct +import time +import base64 +import ipaddress + +import dns.asyncbackend +import dns.exception +import dns.inet +import dns.name +import dns.message +import dns.rcode +import dns.rdataclass +import dns.rdatatype + +from dns.query import _addresses_equal, _destination_and_source, \ + _compute_times, UnexpectedSource + + +# for brevity +_lltuple = dns.asyncbackend.low_level_address_tuple + + +def _source_tuple(af, address, port): + # Make a high level source tuple, or return None if address and port + # are both None + if address or port: + if address is None: + if af == socket.AF_INET: + address = '0.0.0.0' + elif af == socket.AF_INET6: + address = '::' + else: + raise NotImplementedError(f'unknown address family {af}') + return (address, port) + else: + return None + + +def _timeout(expiration, now=None): + if expiration: + if not now: + now = time.time() + return max(expiration - now, 0) + else: + return None + + +async def send_udp(sock, what, destination, expiration=None): + """Send a DNS message to the specified UDP socket. + + *sock*, a ``dns.asyncbackend.DatagramSocket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where to send the query. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + sent_time = time.time() + n = await sock.sendto(what, destination, _timeout(expiration, sent_time)) + return (n, sent_time) + + +async def receive_udp(sock, destination, expiration=None, + ignore_unexpected=False, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False, + raise_on_truncation=False): + """Read a DNS message from a UDP socket. + + *sock*, a ``dns.asyncbackend.DatagramSocket``. + + *destination*, a destination tuple appropriate for the address family + of the socket, specifying where the associated query was sent. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``dns.message.Message`` object. + """ + + wire = b'' + while 1: + (wire, from_address) = await sock.recvfrom(65535) + if _addresses_equal(sock.family, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + break + if not ignore_unexpected: + raise UnexpectedSource('got a response from ' + '%s instead of %s' % (from_address, + destination)) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing, + raise_on_truncation=raise_on_truncation) + return (r, received_time) + +async def udp(q, where, timeout=None, port=53, source=None, source_port=0, + ignore_unexpected=False, one_rr_per_rrset=False, + ignore_trailing=False, raise_on_truncation=False, sock=None, + backend=None): + """Return the response obtained after sending a query via UDP. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if + the TC bit is set. + + *sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, + the socket to use for the query. If ``None``, the default, a + socket is created. Note that if a socket is provided, the + *source* and *source_port* are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Returns a ``dns.message.Message``. + """ + if not backend: + backend = dns.asyncbackend.get_default_backend() + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + s = None + try: + if sock: + s = sock + else: + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) + destination = _lltuple(af, (where, port)) + await send_udp(s, wire, destination, expiration) + (r, received_time) = await receive_udp(s, destination, expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing, + raise_on_truncation) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + finally: + if not sock and s: + await s.close() + +async def udp_with_fallback(q, where, timeout=None, port=53, source=None, + source_port=0, ignore_unexpected=False, + one_rr_per_rrset=False, ignore_trailing=False, + udp_sock=None, tcp_sock=None): + """Return the response to the query, trying UDP first and falling back + to TCP if UDP results in a truncated response. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from + unexpected sources. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *udp_sock*, a ``dns.asyncbackend.DatagramSocket``, or ``None``, + the socket to use for the UDP query. If ``None``, the default, a + socket is created. Note that if a socket is provided the *source* + and *source_port* are ignored for the UDP query. + + *tcp_sock*, a ``dns.asyncbackend.StreamSocket``, or ``None``, the + socket to use for the TCP query. If ``None``, the default, a + socket is created. Note that if a socket is provided *where*, + *source* and *source_port* are ignored for the TCP query. + + Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` + if and only if TCP was used. + """ + try: + response = await udp(q, where, timeout, port, source, source_port, + ignore_unexpected, one_rr_per_rrset, + ignore_trailing, True, udp_sock) + return (response, False) + except dns.message.Truncated: + response = await tcp(q, where, timeout, port, source, source_port, + one_rr_per_rrset, ignore_trailing, tcp_sock) + return (response, True) + + + +async def send_tcp(sock, what, expiration=None): + """Send a DNS message to the specified TCP socket. + + *sock*, a ``socket``. + + *what*, a ``bytes`` or ``dns.message.Message``, the message to send. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + Returns an ``(int, float)`` tuple of bytes sent and the sent time. + """ + + if isinstance(what, dns.message.Message): + what = what.to_wire() + l = len(what) + # copying the wire into tcpmsg is inefficient, but lets us + # avoid writev() or doing a short write that would get pushed + # onto the net + tcpmsg = struct.pack("!H", l) + what + sent_time = time.time() + await sock.sendall(tcpmsg, expiration) + return (len(tcpmsg), sent_time) + + +async def read_exactly(sock, count, expiration): + """Read the specified number of bytes from stream. Keep trying until we + either get the desired amount, or we hit EOF. + """ + s = b'' + while count > 0: + n = await sock.recv(count, _timeout(expiration)) + if n == b'': + raise EOFError + count = count - len(n) + s = s + n + return s + + +async def receive_tcp(sock, expiration=None, one_rr_per_rrset=False, + keyring=None, request_mac=b'', ignore_trailing=False): + """Read a DNS message from a TCP socket. + + *sock*, a ``socket``. + + *expiration*, a ``float`` or ``None``, the absolute time at which + a timeout exception should be raised. If ``None``, no timeout will + occur. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *request_mac*, a ``bytes``, the MAC of the request (for TSIG). + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + Raises if the message is malformed, if network errors occur, of if + there is a timeout. + + Returns a ``dns.message.Message`` object. + """ + + ldata = await read_exactly(sock, 2, expiration) + (l,) = struct.unpack("!H", ldata) + wire = await read_exactly(sock, l, expiration) + received_time = time.time() + r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, + one_rr_per_rrset=one_rr_per_rrset, + ignore_trailing=ignore_trailing) + return (r, received_time) + + +async def tcp(q, where, timeout=None, port=53, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, sock=None, + backend=None): + """Return the response obtained after sending a query via TCP. + + *q*, a ``dns.message.Message``, the query to send + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *timeout*, a ``float`` or ``None``, the number of seconds to wait before the + query times out. If ``None``, the default, wait forever. + + *port*, an ``int``, the port send the message to. The default is 53. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. + + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing + junk at end of the received message. + + *sock*, a ``dns.asyncbacket.StreamSocket``, or ``None``, the + socket to use for the query. If ``None``, the default, a socket + is created. Note that if a socket is provided + *where*, *port*, *source* and *source_port* are ignored. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Returns a ``dns.message.Message``. + """ + + if not backend: + backend = dns.asyncbackend.get_default_backend() + wire = q.to_wire() + (begin_time, expiration) = _compute_times(timeout) + s = None + try: + if sock: + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + await sock.getpeername() + s = sock + else: + # These are simple (address, port) pairs, not + # family-dependent tuples you pass to lowlevel socket + # code. + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + s = await backend.make_socket(af, socket.SOCK_STREAM, 0, stuple, + dtuple, timeout) + await send_tcp(s, wire, expiration) + (r, received_time) = await receive_tcp(s, expiration, one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing) + r.time = received_time - begin_time + if not q.is_response(r): + raise BadResponse + return r + finally: + if not sock and s: + await s.close() diff --git a/dns/trio/resolver.py b/dns/asyncresolver.py index 07e70f9..b45a35b 100644 --- a/dns/trio/resolver.py +++ b/dns/asyncresolver.py @@ -15,31 +15,32 @@ # ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT # OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. -"""trio async I/O library DNS stub resolver.""" +"""Asynchronous DNS stub resolver.""" -import trio +import time +import dns.asyncbackend +import dns.asyncquery import dns.exception import dns.query import dns.resolver -import dns.trio.query # import some resolver symbols for brevity from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA -# we do this for indentation reasons below -_udp = dns.trio.query.udp -_stream = dns.trio.query.stream -class TooManyAttempts(dns.exception.DNSException): - """A resolution had too many unsuccessful attempts.""" +# for identation purposes below +_udp = dns.asyncquery.udp +_tcp = dns.asyncquery.tcp + class Resolver(dns.resolver.Resolver): async def resolve(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, tcp=False, source=None, raise_on_no_answer=True, - source_port=0, search=None): + source_port=0, lifetime=None, search=None, + backend=None): """Query nameservers asynchronously to find the answer to the question. The *qname*, *rdtype*, and *rdclass* parameters may be objects @@ -62,6 +63,9 @@ class Resolver(dns.resolver.Resolver): *source_port*, an ``int``, the port from which to send the message. + *lifetime*, a ``float``, how many seconds a query should run + before timing out. + *search*, a ``bool`` or ``None``, determines whether the search list configured in the system's resolver configuration are used for relative names, and whether the resolver's domain @@ -69,6 +73,9 @@ class Resolver(dns.resolver.Resolver): which causes the value of the resolver's ``use_search_by_default`` attribute to be used. + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + Raises ``dns.resolver.NXDOMAIN`` if the query name does not exist. Raises ``dns.resolver.YXDOMAIN`` if the query name is too long after @@ -87,6 +94,9 @@ class Resolver(dns.resolver.Resolver): resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, raise_on_no_answer, search) + if not backend: + backend = dns.asyncbackend.get_default_backend() + start = time.time() while True: (request, answer) = resolution.next_request() # Note we need to say "if answer is not None" and not just @@ -101,30 +111,24 @@ class Resolver(dns.resolver.Resolver): while not done: (nameserver, port, tcp, backoff) = resolution.next_nameserver() if backoff: - loops += 1 - if loops >= 5: - raise TooManyAttempts - await trio.sleep(backoff) + await backend.sleep(backoff) + timeout = self._compute_timeout(start, lifetime) try: - with trio.fail_after(self.timeout): - if dns.inet.is_address(nameserver): - if tcp: - response = await \ - _stream(request, nameserver, - port=port, - source=source, - source_port=source_port) - else: - response = await \ - _udp(request, - nameserver, - port=port, - source=source, - source_port=source_port, - raise_on_truncation=True) + if dns.inet.is_address(nameserver): + if tcp: + response = await _tcp(request, nameserver, + timeout, port, + source, source_port, + backend=backend) else: - # We don't do DoH yet. - raise NotImplementedError + response = await _udp(request, nameserver, + timeout, port, + source, source_port, + raise_on_truncation=True, + backend=backend) + else: + # We don't do DoH yet. + raise NotImplementedError except Exception as ex: (_, done) = resolution.query_result(None, ex) continue @@ -191,7 +195,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, This is a convenience function that uses the default resolver object to make the query. - See ``dns.trio.resolver.Resolver.resolve`` for more information on the + See ``dns.asyncresolver.Resolver.resolve`` for more information on the parameters. """ @@ -203,7 +207,7 @@ async def resolve(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, async def resolve_address(ipaddr, *args, **kwargs): """Use a resolver to run a reverse query for PTR records. - See ``dns.trio.resolver.Resolver.resolve_address`` for more + See ``dns.asyncresolver.Resolver.resolve_address`` for more information on the parameters. """ @@ -220,7 +224,7 @@ async def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, *tcp*, a ``bool``. If ``True``, use TCP to make the query. - *resolver*, a ``dns.trio.resolver.Resolver`` or ``None``, the + *resolver*, a ``dns.asyncresolver.Resolver`` or ``None``, the resolver to use. If ``None``, the default resolver is used. Raises ``dns.resolver.NoRootSOA`` if there is no SOA RR at the DNS diff --git a/dns/trio/__init__.py b/dns/trio/__init__.py deleted file mode 100644 index 744f880..0000000 --- a/dns/trio/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -"""trio async I/O library helpers""" - -__all__ = [ - 'query', - 'resolver', -] diff --git a/dns/trio/query.py b/dns/trio/query.py deleted file mode 100644 index a3a28fe..0000000 --- a/dns/trio/query.py +++ /dev/null @@ -1,374 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -"""trio async I/O library query support""" - -import contextlib -import socket -import struct -import time -import trio -import trio.socket # type: ignore - -import dns.exception -import dns.inet -import dns.name -import dns.message -import dns.query -import dns.rcode -import dns.rdataclass -import dns.rdatatype - -# import query symbols for compatibility and brevity -from dns.query import ssl, UnexpectedSource, BadResponse - -# Function used to create a socket. Can be overridden if needed in special -# situations. -socket_factory = trio.socket.socket - -async def send_udp(sock, what, destination): - """Asynchronously send a DNS message to the specified UDP socket. - - *sock*, a ``trio.socket.socket``. - - *what*, a ``bytes`` or ``dns.message.Message``, the message to send. - - *destination*, a destination tuple appropriate for the address family - of the socket, specifying where to send the query. - - Returns an ``(int, float)`` tuple of bytes sent and the sent time. - """ - - if isinstance(what, dns.message.Message): - what = what.to_wire() - sent_time = time.time() - n = await sock.sendto(what, destination) - return (n, sent_time) - - -async def receive_udp(sock, destination, ignore_unexpected=False, - one_rr_per_rrset=False, keyring=None, request_mac=b'', - ignore_trailing=False, raise_on_truncation=False): - """Asynchronously read a DNS message from a UDP socket. - - *sock*, a ``trio.socket.socket``. - - *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. - - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *keyring*, a ``dict``, the keyring to use for TSIG. - - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. - - Raises if the message is malformed, if network errors occur, of if - there is a timeout. - - Returns a ``dns.message.Message`` object. - """ - - wire = b'' - while True: - (wire, from_address) = await sock.recvfrom(65535) - if dns.query._addresses_equal(sock.family, from_address, - destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): - break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) - received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing, - raise_on_truncation=raise_on_truncation) - return (r, received_time) - -async def udp(q, where, port=53, source=None, source_port=0, - ignore_unexpected=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False, - sock=None): - """Asynchronously return the response obtained after sending a query - via UDP. - - *q*, a ``dns.message.Message``, the query to send - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *port*, an ``int``, the port send the message to. The default is 53. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. - - *sock*, a ``trio.socket.socket``, or ``None``, the socket to use - for the query. If ``None``, the default, a socket is created. if - a socket is provided, the *source* and *source_port* are ignored. - - Returns a ``dns.message.Message``. - - """ - - wire = q.to_wire() - (af, destination, source) = \ - dns.query._destination_and_source(None, where, port, source, - source_port) - # We can use an ExitStack here as exiting a trio.socket.socket does - # not await. - with contextlib.ExitStack() as stack: - if sock: - s = sock - else: - s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0)) - if source is not None: - await s.bind(source) - (_, sent_time) = await send_udp(s, wire, destination) - (r, received_time) = await receive_udp(s, destination, - ignore_unexpected, - one_rr_per_rrset, q.keyring, - q.mac, ignore_trailing, - raise_on_truncation) - if not q.is_response(r): - raise BadResponse - r.time = received_time - sent_time - return r - -async def udp_with_fallback(q, where, timeout=None, port=53, source=None, - source_port=0, ignore_unexpected=False, - one_rr_per_rrset=False, ignore_trailing=False): - """Return the response to the query, trying UDP first and falling back - to TCP if UDP results in a truncated response. - - *q*, a ``dns.message.Message``, the query to send - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *port*, an ``int``, the port send the message to. The default is 53. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *ignore_unexpected*, a ``bool``. If ``True``, ignore responses from - unexpected sources. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - Returns a (``dns.message.Message``, tcp) tuple where tcp is ``True`` - if and only if TCP was used. - """ - try: - response = await udp(q, where, port, source, source_port, - ignore_unexpected, one_rr_per_rrset, - ignore_trailing, True) - return (response, False) - except dns.message.Truncated: - response = await stream(q, where, False, port, source, source_port, - one_rr_per_rrset, ignore_trailing) - - return (response, True) - -# pylint: disable=redefined-outer-name - -async def send_stream(stream, what): - """Asynchronously send a DNS message to the specified stream. - - *stream*, a ``trio.abc.Stream``. - - *what*, a ``bytes`` or ``dns.message.Message``, the message to send. - - Returns an ``(int, float)`` tuple of bytes sent and the sent time. - """ - - if isinstance(what, dns.message.Message): - what = what.to_wire() - l = len(what) - # copying the wire into tcpmsg is inefficient, but lets us - # avoid writev() or doing a short write that would get pushed - # onto the net - stream_message = struct.pack("!H", l) + what - sent_time = time.time() - await stream.send_all(stream_message) - return (len(stream_message), sent_time) - -async def read_exactly(stream, count): - """Read the specified number of bytes from stream. Keep trying until we - either get the desired amount, or we hit EOF. - """ - s = b'' - while count > 0: - n = await stream.receive_some(count) - if n == b'': - raise EOFError - count = count - len(n) - s = s + n - return s - -async def receive_stream(stream, one_rr_per_rrset=False, keyring=None, - request_mac=b'', ignore_trailing=False): - """Read a DNS message from a stream. - - *stream*, a ``trio.abc.Stream``. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *keyring*, a ``dict``, the keyring to use for TSIG. - - *request_mac*, a ``bytes``, the MAC of the request (for TSIG). - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - Raises if the message is malformed, if network errors occur, of if - there is a timeout. - - Returns a ``dns.message.Message`` object. - """ - - ldata = await read_exactly(stream, 2) - (l,) = struct.unpack("!H", ldata) - wire = await read_exactly(stream, l) - received_time = time.time() - r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, - one_rr_per_rrset=one_rr_per_rrset, - ignore_trailing=ignore_trailing) - return (r, received_time) - -async def stream(q, where, tls=False, port=None, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False, - stream=None, ssl_context=None, server_hostname=None): - """Return the response obtained after sending a query using TCP or TLS. - - *q*, a ``dns.message.Message``, the query to send. - - *where*, a ``str`` containing an IPv4 or IPv6 address, where - to send the message. - - *tls*, a ``bool``. If ``False``, the default, the query will be - sent using TCP and *port* will default to 53. If ``True``, the - query is sent using TLS, and *port* will default to 853. - - *port*, an ``int``, the port send the message to. The default is as - specified in the description for *tls*. - - *source*, a ``str`` containing an IPv4 or IPv6 address, specifying - the source address. The default is the wildcard address. - - *source_port*, an ``int``, the port from which to send the message. - The default is 0. - - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own - RRset. - - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the received message. - - *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for - the query. If ``None``, the default, a stream is created. if a - socket is provided, it must be connected, and the *where*, *port*, - *tls*, *source*, *source_port*, *ssl_context*, and - *server_hostname* parameters are ignored. - - *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing - a TLS connection. If ``None``, the default, creates one with the default - configuration. If this value is not ``None``, then the *tls* parameter - is treated as if it were ``True`` regardless of its value. - - *server_hostname*, a ``str`` containing the server's hostname. The - default is ``None``, which means that no hostname is known, and if an - SSL context is created, hostname checking will be disabled. - - Returns a ``dns.message.Message``. - - """ - if ssl_context is not None: - tls = True - if port is None: - if tls: - port = 853 - else: - port = 53 - wire = q.to_wire() - # We'd like to be able to use an AsyncExitStack here, because - # unlike closing a socket, closing a stream requires an await, but - # that's a 3.7 feature, so we are forced to try ... finally. - sock = None - s = None - begin_time = time.time() - try: - if stream: - # - # Verify that the socket is connected, as if it's not connected, - # it's not writable, and the polling in send_tcp() will time out or - # hang forever. - if isinstance(stream, trio.SSLStream): - tsock = stream.transport_stream.socket - else: - tsock = stream.socket - tsock.getpeername() - s = stream - else: - (af, destination, source) = \ - dns.query._destination_and_source(None, where, port, source, - source_port) - sock = socket_factory(af, socket.SOCK_STREAM, 0) - if source is not None: - await sock.bind(source) - await sock.connect(destination) - s = trio.SocketStream(sock) - sock = None - if tls and ssl_context is None: - ssl_context = ssl.create_default_context() - if server_hostname is None: - ssl_context.check_hostname = False - if ssl_context: - s = trio.SSLStream(s, ssl_context, - server_hostname=server_hostname) - await send_stream(s, wire) - (r, received_time) = await receive_stream(s, one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing) - if not q.is_response(r): - raise BadResponse - r.time = received_time - begin_time - return r - finally: - if sock: - sock.close() - if s and s != stream: - await s.aclose() diff --git a/dns/trio/query.pyi b/dns/trio/query.pyi deleted file mode 100644 index 0a5ab92..0000000 --- a/dns/trio/query.pyi +++ /dev/null @@ -1,33 +0,0 @@ -from typing import Optional, Dict, Any -from . import rdatatype, rdataclass, name, message - -# If the ssl import works, then -# -# error: Name 'ssl' already defined (by an import) -# -# is expected and can be ignored. -try: - import ssl -except ImportError: - class ssl: # type: ignore - SSLContext : Dict = {} - -import trio - -def udp(q : message.Message, where : str, port=53, - source : Optional[str] = None, source_port : Optional[int] = 0, - ignore_unexpected : Optional[bool] = False, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - sock : Optional[trio.socket.socket] = None) -> message.Message: - ... - -def stream(q : message.Message, where : str, tls : Optional[bool] = False, - port=53, source : Optional[str] = None, - source_port : Optional[int] = 0, - one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False, - stream : Optional[trio.abc.Stream] = None, - ssl_context: Optional[ssl.SSLContext] = None, - server_hostname: Optional[str] = None) -> message.Message: - ... diff --git a/dns/trio/resolver.pyi b/dns/trio/resolver.pyi deleted file mode 100644 index d84419b..0000000 --- a/dns/trio/resolver.pyi +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Union, Optional, List, Any, Dict -from .. import exception, rdataclass, name, rdatatype - -def resolve(qname : str, rdtype : Union[int,str] = 0, - rdclass : Union[int,str] = 0, - tcp=False, source=None, raise_on_no_answer=True, - source_port=0, search : Optional[bool]=None): - ... - -def resolve_address(self, ipaddr: str, *args: Any, **kwargs: Optional[Dict]): - ... - -def zone_for_name(name, rdclass : int = rdataclass.IN, tcp=False, - resolver : Optional[Resolver] = None): - ... - -class Resolver: - def __init__(self, filename : Optional[str] = '/etc/resolv.conf', - configure : Optional[bool] = True): - self.nameservers : List[str] - def resolve(self, qname : str, rdtype : Union[int,str] = rdatatype.A, - rdclass : Union[int,str] = rdataclass.IN, - tcp : bool = False, source : Optional[str] = None, - raise_on_no_answer=True, source_port : int = 0, - search : Optional[bool]=None): - ... diff --git a/pyproject.toml b/pyproject.toml index 44a1dd7..b33fe00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ requests-toolbelt = {version="^0.9.1", optional=true} requests = {version="^2.23.0", optional=true} idna = {version="^2.1", optional=true} cryptography = {version="^2.6", optional=true} -trio = {version="^0.14.0", optional=true} +trio = {version="^0.14", optional=true} [tool.poetry.dev-dependencies] mypy = "^0.770" @@ -28,6 +28,7 @@ doh = ['requests', 'requests-toolbelt'] idna = ['idna'] dnssec = ['cryptography'] trio = ['trio'] +curio = ['curio', 'sniffio'] [build-system] requires = ["poetry>=0.12"] @@ -50,7 +50,7 @@ direct manipulation of DNS zones, messages, names, and records.""", 'license' : 'ISC', 'url' : 'http://www.dnspython.org', 'packages' : ['dns', 'dns.rdtypes', 'dns.rdtypes.IN', 'dns.rdtypes.ANY', - 'dns.rdtypes.CH', 'dns.trio'], + 'dns.rdtypes.CH'], 'package_data' : {'dns': ['py.typed']}, 'download_url' : \ 'http://www.dnspython.org/kits/{}/dnspython-{}.tar.gz'.format(version, version), diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000..c09941b --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,219 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import asyncio +import socket +import unittest + +import dns.asyncbackend +import dns.asyncquery +import dns.asyncresolver +import dns.message +import dns.name +import dns.rdataclass +import dns.rdatatype +import dns.resolver + +# Some tests require the internet to be available to run, so let's +# skip those if it's not there. +_network_available = True +try: + socket.gethostbyname('dnspython.org') +except socket.gaierror: + _network_available = False + +@unittest.skipIf(not _network_available, "Internet not reachable") +class AsyncTests(unittest.TestCase): + + def setUp(self): + self.backend = dns.asyncbackend.set_default_backend('asyncio') + + def async_run(self, afunc): + return asyncio.run(afunc()) + + def testResolve(self): + async def run(): + answer = await dns.asyncresolver.resolve('dns.google.', 'A') + return set([rdata.address for rdata in answer]) + seen = self.async_run(run) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testResolveAddress(self): + async def run(): + return await dns.asyncresolver.resolve_address('8.8.8.8') + answer = self.async_run(run) + dnsgoogle = dns.name.from_text('dns.google.') + self.assertEqual(answer[0].target, dnsgoogle) + + def testZoneForName1(self): + async def run(): + name = dns.name.from_text('www.dnspython.org.') + return await dns.asyncresolver.zone_for_name(name) + ezname = dns.name.from_text('dnspython.org.') + zname = self.async_run(run) + self.assertEqual(zname, ezname) + + def testZoneForName2(self): + async def run(): + name = dns.name.from_text('a.b.www.dnspython.org.') + return await dns.asyncresolver.zone_for_name(name) + ezname = dns.name.from_text('dnspython.org.') + zname = self.async_run(run) + self.assertEqual(zname, ezname) + + def testZoneForName3(self): + async def run(): + name = dns.name.from_text('dnspython.org.') + return await dns.asyncresolver.zone_for_name(name) + ezname = dns.name.from_text('dnspython.org.') + zname = self.async_run(run) + self.assertEqual(zname, ezname) + + def testZoneForName4(self): + def bad(): + name = dns.name.from_text('dnspython.org', None) + async def run(): + return await dns.asyncresolver.zone_for_name(name) + self.async_run(run) + self.assertRaises(dns.resolver.NotAbsolute, bad) + + def testQueryUDP(self): + qname = dns.name.from_text('dns.google.') + async def run(): + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.asyncquery.udp(q, '8.8.8.8') + response = self.async_run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testQueryUDPWithSocket(self): + qname = dns.name.from_text('dns.google.') + async def run(): + async with await self.backend.make_socket(socket.AF_INET, + socket.SOCK_DGRAM) as s: + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.asyncquery.udp(q, '8.8.8.8', sock=s) + response = self.async_run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testQueryTCP(self): + qname = dns.name.from_text('dns.google.') + async def run(): + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.asyncquery.tcp(q, '8.8.8.8') + response = self.async_run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + def testQueryTCPWithSocket(self): + qname = dns.name.from_text('dns.google.') + async def run(): + async with await self.backend.make_socket(socket.AF_INET, + socket.SOCK_STREAM, 0, + None, + ('8.8.8.8', 53)) as s: + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.asyncquery.tcp(q, '8.8.8.8', sock=s) + response = self.async_run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + + # def testQueryTLS(self): + # qname = dns.name.from_text('dns.google.') + # async def run(): + # q = dns.message.make_query(qname, dns.rdatatype.A) + # return await dns.asyncquery.stream(q, '8.8.8.8', True) + # response = self.async_run(run) + # rrs = response.get_rrset(response.answer, qname, + # dns.rdataclass.IN, dns.rdatatype.A) + # self.assertTrue(rrs is not None) + # seen = set([rdata.address for rdata in rrs]) + # self.assertTrue('8.8.8.8' in seen) + # self.assertTrue('8.8.4.4' in seen) + + # def testQueryTLSWithSocket(self): + # qname = dns.name.from_text('dns.google.') + # async def run(): + # async with await trio.open_ssl_over_tcp_stream('8.8.8.8', + # 853) as s: + # q = dns.message.make_query(qname, dns.rdatatype.A) + # return await dns.asyncquery.stream(q, '8.8.8.8', stream=s) + # response = self.async_run(run) + # rrs = response.get_rrset(response.answer, qname, + # dns.rdataclass.IN, dns.rdatatype.A) + # self.assertTrue(rrs is not None) + # seen = set([rdata.address for rdata in rrs]) + # self.assertTrue('8.8.8.8' in seen) + # self.assertTrue('8.8.4.4' in seen) + + def testQueryUDPFallback(self): + qname = dns.name.from_text('.') + async def run(): + q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) + return await dns.asyncquery.udp_with_fallback(q, '8.8.8.8') + (_, tcp) = self.async_run(run) + self.assertTrue(tcp) + + def testQueryUDPFallbackNoFallback(self): + qname = dns.name.from_text('dns.google.') + async def run(): + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.asyncquery.udp_with_fallback(q, '8.8.8.8') + (_, tcp) = self.async_run(run) + self.assertFalse(tcp) + +try: + import trio + + class TrioAsyncTests(AsyncTests): + def setUp(self): + self.backend = dns.asyncbackend.set_default_backend('trio') + + def async_run(self, afunc): + return trio.run(afunc) +except ImportError: + pass + +try: + import curio + + class CurioAsyncTests(AsyncTests): + def setUp(self): + self.backend = dns.asyncbackend.set_default_backend('curio') + + def async_run(self, afunc): + return curio.run(afunc) +except ImportError: + pass diff --git a/tests/test_trio.py b/tests/test_trio.py deleted file mode 100644 index 8304a1f..0000000 --- a/tests/test_trio.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license - -# Copyright (C) 2003-2017 Nominum, Inc. -# -# Permission to use, copy, modify, and distribute this software and its -# documentation for any purpose with or without fee is hereby granted, -# provided that the above copyright notice and this permission notice -# appear in all copies. -# -# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES -# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF -# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR -# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES -# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN -# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT -# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - -import socket -import unittest - -try: - import trio - import trio.socket - - import dns.message - import dns.name - import dns.rdataclass - import dns.rdatatype - import dns.trio.query - import dns.trio.resolver - - # Some tests require the internet to be available to run, so let's - # skip those if it's not there. - _network_available = True - try: - socket.gethostbyname('dnspython.org') - except socket.gaierror: - _network_available = False - - @unittest.skipIf(not _network_available, "Internet not reachable") - class TrioTests(unittest.TestCase): - - def testResolve(self): - async def run(): - answer = await dns.trio.resolver.resolve('dns.google.', 'A') - return set([rdata.address for rdata in answer]) - seen = trio.run(run) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testResolveAddress(self): - async def run(): - return await dns.trio.resolver.resolve_address('8.8.8.8') - answer = trio.run(run) - dnsgoogle = dns.name.from_text('dns.google.') - self.assertEqual(answer[0].target, dnsgoogle) - - def testZoneForName1(self): - async def run(): - name = dns.name.from_text('www.dnspython.org.') - return await dns.trio.resolver.zone_for_name(name) - ezname = dns.name.from_text('dnspython.org.') - zname = trio.run(run) - self.assertEqual(zname, ezname) - - def testZoneForName2(self): - async def run(): - name = dns.name.from_text('a.b.www.dnspython.org.') - return await dns.trio.resolver.zone_for_name(name) - ezname = dns.name.from_text('dnspython.org.') - zname = trio.run(run) - self.assertEqual(zname, ezname) - - def testZoneForName3(self): - async def run(): - name = dns.name.from_text('dnspython.org.') - return await dns.trio.resolver.zone_for_name(name) - ezname = dns.name.from_text('dnspython.org.') - zname = trio.run(run) - self.assertEqual(zname, ezname) - - def testZoneForName4(self): - def bad(): - name = dns.name.from_text('dnspython.org', None) - async def run(): - return await dns.trio.resolver.zone_for_name(name) - trio.run(run) - self.assertRaises(dns.resolver.NotAbsolute, bad) - - def testQueryUDP(self): - qname = dns.name.from_text('dns.google.') - async def run(): - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.udp(q, '8.8.8.8') - response = trio.run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) - self.assertTrue(rrs is not None) - seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testQueryUDPWithSocket(self): - qname = dns.name.from_text('dns.google.') - async def run(): - with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.udp(q, '8.8.8.8', sock=s) - response = trio.run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) - self.assertTrue(rrs is not None) - seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testQueryTCP(self): - qname = dns.name.from_text('dns.google.') - async def run(): - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.stream(q, '8.8.8.8') - response = trio.run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) - self.assertTrue(rrs is not None) - seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testQueryTCPWithSocket(self): - qname = dns.name.from_text('dns.google.') - async def run(): - async with await trio.open_tcp_stream('8.8.8.8', 53) as s: - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.stream(q, '8.8.8.8', stream=s) - response = trio.run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) - self.assertTrue(rrs is not None) - seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testQueryTLS(self): - qname = dns.name.from_text('dns.google.') - async def run(): - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.stream(q, '8.8.8.8', True) - response = trio.run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) - self.assertTrue(rrs is not None) - seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testQueryTLSWithSocket(self): - qname = dns.name.from_text('dns.google.') - async def run(): - async with await trio.open_ssl_over_tcp_stream('8.8.8.8', - 853) as s: - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.stream(q, '8.8.8.8', stream=s) - response = trio.run(run) - rrs = response.get_rrset(response.answer, qname, - dns.rdataclass.IN, dns.rdatatype.A) - self.assertTrue(rrs is not None) - seen = set([rdata.address for rdata in rrs]) - self.assertTrue('8.8.8.8' in seen) - self.assertTrue('8.8.4.4' in seen) - - def testQueryUDPFallback(self): - qname = dns.name.from_text('.') - async def run(): - q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) - return await dns.trio.query.udp_with_fallback(q, '8.8.8.8') - (_, tcp) = trio.run(run) - self.assertTrue(tcp) - - def testQueryUDPFallbackNoFallback(self): - qname = dns.name.from_text('dns.google.') - async def run(): - q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.trio.query.udp_with_fallback(q, '8.8.8.8') - (_, tcp) = trio.run(run) - self.assertFalse(tcp) - -except ModuleNotFoundError: - pass |