diff options
author | Bob Halley <halley@dnspython.org> | 2020-06-06 15:43:06 -0700 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2020-06-06 15:43:06 -0700 |
commit | 30c2562cb781a4977bb9f3ed0c2a3e573e74c1db (patch) | |
tree | d4af4ded9bcf2ea326400abca55159d1aada6fed | |
parent | 8dbbd2e3df3b7dc846bf19ce1c2aa8872b83b51e (diff) | |
download | dnspython-30c2562cb781a4977bb9f3ed0c2a3e573e74c1db.tar.gz |
Allow a socket to be passed to udp(), and a stream to stream().
-rw-r--r-- | dns/trio/query.py | 109 | ||||
-rw-r--r-- | dns/trio/query.pyi | 6 | ||||
-rw-r--r-- | tests/test_trio.py | 44 |
3 files changed, 123 insertions, 36 deletions
diff --git a/dns/trio/query.py b/dns/trio/query.py index 11af174..53b8fe5 100644 --- a/dns/trio/query.py +++ b/dns/trio/query.py @@ -2,6 +2,7 @@ """trio async I/O library query support""" +import contextlib import socket import struct import time @@ -27,7 +28,7 @@ socket_factory = trio.socket.socket async def send_udp(sock, what, destination): """Asynchronously send a DNS message to the specified UDP socket. - *sock*, a ``trio.socket``. + *sock*, a ``trio.socket.socket``. *what*, a ``bytes`` or ``dns.message.Message``, the message to send. @@ -49,7 +50,7 @@ async def receive_udp(sock, destination, ignore_unexpected=False, ignore_trailing=False, raise_on_truncation=False): """Asynchronously read a DNS message from a UDP socket. - *sock*, a ``trio.socket``. + *sock*, a ``trio.socket.socket``. *destination*, a destination tuple appropriate for the address family of the socket, specifying where the associated query was sent. @@ -97,7 +98,8 @@ async def receive_udp(sock, destination, ignore_unexpected=False, async def udp(q, where, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False): + ignore_trailing=False, raise_on_truncation=False, + sock=None): """Asynchronously return the response obtained after sending a query via UDP. @@ -126,18 +128,27 @@ async def udp(q, where, port=53, source=None, source_port=0, *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the TC bit is set. + *sock*, a ``trio.socket.socket``, or ``None``, the socket to use + for the query. If ``None``, the default, a socket is created. if + a socket is provided, the *source* and *source_port* are ignored. + Returns a ``dns.message.Message``. + """ wire = q.to_wire() (af, destination, source) = \ dns.query._destination_and_source(None, where, port, source, source_port) - with socket_factory(af, socket.SOCK_DGRAM, 0) as s: - received_time = None - sent_time = None - if source is not None: - await s.bind(source) + # We can use an ExitStack here as exiting a trio.socket.socket does + # not await. + with contextlib.ExitStack() as stack: + if sock: + s = sock + else: + s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0)) + if source is not None: + await s.bind(source) (_, sent_time) = await send_udp(s, wire, destination) (r, received_time) = await receive_udp(s, destination, ignore_unexpected, @@ -260,7 +271,7 @@ async def receive_stream(stream, one_rr_per_rrset=False, keyring=None, async def stream(q, where, tls=False, port=None, source=None, source_port=0, one_rr_per_rrset=False, ignore_trailing=False, - ssl_context=None, server_hostname=None): + stream=None, ssl_context=None, server_hostname=None): """Return the response obtained after sending a query using TCP or TLS. *q*, a ``dns.message.Message``, the query to send. @@ -287,6 +298,12 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for + the query. If ``None``, the default, a stream is created. if a + socket is provided, it must be connected, and the *where*, *port*, + *tls*, *source*, *source_port*, *ssl_context*, and + *server_hostname* parameters are ignored. + *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing a TLS connection. If ``None``, the default, creates one with the default configuration. If this value is not ``None``, then the *tls* parameter @@ -297,8 +314,8 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0, SSL context is created, hostname checking will be disabled. Returns a ``dns.message.Message``. - """ + """ if ssl_context is not None: tls = True if port is None: @@ -307,28 +324,50 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0, else: port = 53 wire = q.to_wire() - (af, destination, source) = \ - dns.query._destination_and_source(None, where, port, source, - source_port) - with socket_factory(af, socket.SOCK_STREAM, 0) as s: - begin_time = time.time() - if source is not None: - await s.bind(source) - await s.connect(destination) - stream = trio.SocketStream(s) - if tls and ssl_context is None: - ssl_context = ssl.create_default_context() - if server_hostname is None: - ssl_context.check_hostname = False - if ssl_context: - stream = trio.SSLStream(stream, ssl_context, - server_hostname=server_hostname) - async with stream: - await send_stream(stream, wire) - (r, received_time) = await receive_stream(stream, one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing) - if not q.is_response(r): - raise BadResponse - r.time = received_time - begin_time - return r + # We'd like to be able to use an AsyncExitStack here, but that's a 3.7 + # feature, so we are forced to try ... finally. + sock = None + s = None + begin_time = time.time() + try: + if stream: + # + # Verify that the socket is connected, as if it's not connected, + # it's not writable, and the polling in send_tcp() will time out or + # hang forever. + if isinstance(stream, trio.SSLStream): + tsock = stream.transport_stream.socket + else: + tsock = stream.socket + tsock.getpeername() + s = stream + else: + (af, destination, source) = \ + dns.query._destination_and_source(None, where, port, source, + source_port) + sock = socket_factory(af, socket.SOCK_STREAM, 0) + if source is not None: + await sock.bind(source) + await sock.connect(destination) + s = trio.SocketStream(sock) + sock = None + if tls and ssl_context is None: + ssl_context = ssl.create_default_context() + if server_hostname is None: + ssl_context.check_hostname = False + if ssl_context: + s = trio.SSLStream(s, ssl_context, + server_hostname=server_hostname) + await send_stream(s, wire) + (r, received_time) = await receive_stream(s, one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing) + if not q.is_response(r): + raise BadResponse + r.time = received_time - begin_time + return r + finally: + if sock: + sock.close() + if s and s != stream: + await s.aclose() diff --git a/dns/trio/query.pyi b/dns/trio/query.pyi index c51f000..0a5ab92 100644 --- a/dns/trio/query.pyi +++ b/dns/trio/query.pyi @@ -12,11 +12,14 @@ except ImportError: class ssl: # type: ignore SSLContext : Dict = {} +import trio + def udp(q : message.Message, where : str, port=53, source : Optional[str] = None, source_port : Optional[int] = 0, ignore_unexpected : Optional[bool] = False, one_rr_per_rrset : Optional[bool] = False, - ignore_trailing : Optional[bool] = False) -> message.Message: + ignore_trailing : Optional[bool] = False, + sock : Optional[trio.socket.socket] = None) -> message.Message: ... def stream(q : message.Message, where : str, tls : Optional[bool] = False, @@ -24,6 +27,7 @@ def stream(q : message.Message, where : str, tls : Optional[bool] = False, source_port : Optional[int] = 0, one_rr_per_rrset : Optional[bool] = False, ignore_trailing : Optional[bool] = False, + stream : Optional[trio.abc.Stream] = None, ssl_context: Optional[ssl.SSLContext] = None, server_hostname: Optional[str] = None) -> message.Message: ... diff --git a/tests/test_trio.py b/tests/test_trio.py index d519844..8304a1f 100644 --- a/tests/test_trio.py +++ b/tests/test_trio.py @@ -20,6 +20,7 @@ import unittest try: import trio + import trio.socket import dns.message import dns.name @@ -99,6 +100,20 @@ try: self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + def testQueryUDPWithSocket(self): + qname = dns.name.from_text('dns.google.') + async def run(): + with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.trio.query.udp(q, '8.8.8.8', sock=s) + response = trio.run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + def testQueryTCP(self): qname = dns.name.from_text('dns.google.') async def run(): @@ -112,6 +127,20 @@ try: self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + def testQueryTCPWithSocket(self): + qname = dns.name.from_text('dns.google.') + async def run(): + async with await trio.open_tcp_stream('8.8.8.8', 53) as s: + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.trio.query.stream(q, '8.8.8.8', stream=s) + response = trio.run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + def testQueryTLS(self): qname = dns.name.from_text('dns.google.') async def run(): @@ -125,6 +154,21 @@ try: self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) + def testQueryTLSWithSocket(self): + qname = dns.name.from_text('dns.google.') + async def run(): + async with await trio.open_ssl_over_tcp_stream('8.8.8.8', + 853) as s: + q = dns.message.make_query(qname, dns.rdatatype.A) + return await dns.trio.query.stream(q, '8.8.8.8', stream=s) + response = trio.run(run) + rrs = response.get_rrset(response.answer, qname, + dns.rdataclass.IN, dns.rdatatype.A) + self.assertTrue(rrs is not None) + seen = set([rdata.address for rdata in rrs]) + self.assertTrue('8.8.8.8' in seen) + self.assertTrue('8.8.4.4' in seen) + def testQueryUDPFallback(self): qname = dns.name.from_text('.') async def run(): |