summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-06-06 15:43:06 -0700
committerBob Halley <halley@dnspython.org>2020-06-06 15:43:06 -0700
commit30c2562cb781a4977bb9f3ed0c2a3e573e74c1db (patch)
treed4af4ded9bcf2ea326400abca55159d1aada6fed
parent8dbbd2e3df3b7dc846bf19ce1c2aa8872b83b51e (diff)
downloaddnspython-30c2562cb781a4977bb9f3ed0c2a3e573e74c1db.tar.gz
Allow a socket to be passed to udp(), and a stream to stream().
-rw-r--r--dns/trio/query.py109
-rw-r--r--dns/trio/query.pyi6
-rw-r--r--tests/test_trio.py44
3 files changed, 123 insertions, 36 deletions
diff --git a/dns/trio/query.py b/dns/trio/query.py
index 11af174..53b8fe5 100644
--- a/dns/trio/query.py
+++ b/dns/trio/query.py
@@ -2,6 +2,7 @@
"""trio async I/O library query support"""
+import contextlib
import socket
import struct
import time
@@ -27,7 +28,7 @@ socket_factory = trio.socket.socket
async def send_udp(sock, what, destination):
"""Asynchronously send a DNS message to the specified UDP socket.
- *sock*, a ``trio.socket``.
+ *sock*, a ``trio.socket.socket``.
*what*, a ``bytes`` or ``dns.message.Message``, the message to send.
@@ -49,7 +50,7 @@ async def receive_udp(sock, destination, ignore_unexpected=False,
ignore_trailing=False, raise_on_truncation=False):
"""Asynchronously read a DNS message from a UDP socket.
- *sock*, a ``trio.socket``.
+ *sock*, a ``trio.socket.socket``.
*destination*, a destination tuple appropriate for the address family
of the socket, specifying where the associated query was sent.
@@ -97,7 +98,8 @@ async def receive_udp(sock, destination, ignore_unexpected=False,
async def udp(q, where, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False,
- ignore_trailing=False, raise_on_truncation=False):
+ ignore_trailing=False, raise_on_truncation=False,
+ sock=None):
"""Asynchronously return the response obtained after sending a query
via UDP.
@@ -126,18 +128,27 @@ async def udp(q, where, port=53, source=None, source_port=0,
*raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
the TC bit is set.
+ *sock*, a ``trio.socket.socket``, or ``None``, the socket to use
+ for the query. If ``None``, the default, a socket is created. if
+ a socket is provided, the *source* and *source_port* are ignored.
+
Returns a ``dns.message.Message``.
+
"""
wire = q.to_wire()
(af, destination, source) = \
dns.query._destination_and_source(None, where, port, source,
source_port)
- with socket_factory(af, socket.SOCK_DGRAM, 0) as s:
- received_time = None
- sent_time = None
- if source is not None:
- await s.bind(source)
+ # We can use an ExitStack here as exiting a trio.socket.socket does
+ # not await.
+ with contextlib.ExitStack() as stack:
+ if sock:
+ s = sock
+ else:
+ s = stack.enter_context(socket_factory(af, socket.SOCK_DGRAM, 0))
+ if source is not None:
+ await s.bind(source)
(_, sent_time) = await send_udp(s, wire, destination)
(r, received_time) = await receive_udp(s, destination,
ignore_unexpected,
@@ -260,7 +271,7 @@ async def receive_stream(stream, one_rr_per_rrset=False, keyring=None,
async def stream(q, where, tls=False, port=None, source=None, source_port=0,
one_rr_per_rrset=False, ignore_trailing=False,
- ssl_context=None, server_hostname=None):
+ stream=None, ssl_context=None, server_hostname=None):
"""Return the response obtained after sending a query using TCP or TLS.
*q*, a ``dns.message.Message``, the query to send.
@@ -287,6 +298,12 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0,
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *stream*, a ``trio.abc.Stream``, or ``None``, the stream to use for
+ the query. If ``None``, the default, a stream is created. if a
+ socket is provided, it must be connected, and the *where*, *port*,
+ *tls*, *source*, *source_port*, *ssl_context*, and
+ *server_hostname* parameters are ignored.
+
*ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
a TLS connection. If ``None``, the default, creates one with the default
configuration. If this value is not ``None``, then the *tls* parameter
@@ -297,8 +314,8 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0,
SSL context is created, hostname checking will be disabled.
Returns a ``dns.message.Message``.
- """
+ """
if ssl_context is not None:
tls = True
if port is None:
@@ -307,28 +324,50 @@ async def stream(q, where, tls=False, port=None, source=None, source_port=0,
else:
port = 53
wire = q.to_wire()
- (af, destination, source) = \
- dns.query._destination_and_source(None, where, port, source,
- source_port)
- with socket_factory(af, socket.SOCK_STREAM, 0) as s:
- begin_time = time.time()
- if source is not None:
- await s.bind(source)
- await s.connect(destination)
- stream = trio.SocketStream(s)
- if tls and ssl_context is None:
- ssl_context = ssl.create_default_context()
- if server_hostname is None:
- ssl_context.check_hostname = False
- if ssl_context:
- stream = trio.SSLStream(stream, ssl_context,
- server_hostname=server_hostname)
- async with stream:
- await send_stream(stream, wire)
- (r, received_time) = await receive_stream(stream, one_rr_per_rrset,
- q.keyring, q.mac,
- ignore_trailing)
- if not q.is_response(r):
- raise BadResponse
- r.time = received_time - begin_time
- return r
+ # We'd like to be able to use an AsyncExitStack here, but that's a 3.7
+ # feature, so we are forced to try ... finally.
+ sock = None
+ s = None
+ begin_time = time.time()
+ try:
+ if stream:
+ #
+ # Verify that the socket is connected, as if it's not connected,
+ # it's not writable, and the polling in send_tcp() will time out or
+ # hang forever.
+ if isinstance(stream, trio.SSLStream):
+ tsock = stream.transport_stream.socket
+ else:
+ tsock = stream.socket
+ tsock.getpeername()
+ s = stream
+ else:
+ (af, destination, source) = \
+ dns.query._destination_and_source(None, where, port, source,
+ source_port)
+ sock = socket_factory(af, socket.SOCK_STREAM, 0)
+ if source is not None:
+ await sock.bind(source)
+ await sock.connect(destination)
+ s = trio.SocketStream(sock)
+ sock = None
+ if tls and ssl_context is None:
+ ssl_context = ssl.create_default_context()
+ if server_hostname is None:
+ ssl_context.check_hostname = False
+ if ssl_context:
+ s = trio.SSLStream(s, ssl_context,
+ server_hostname=server_hostname)
+ await send_stream(s, wire)
+ (r, received_time) = await receive_stream(s, one_rr_per_rrset,
+ q.keyring, q.mac,
+ ignore_trailing)
+ if not q.is_response(r):
+ raise BadResponse
+ r.time = received_time - begin_time
+ return r
+ finally:
+ if sock:
+ sock.close()
+ if s and s != stream:
+ await s.aclose()
diff --git a/dns/trio/query.pyi b/dns/trio/query.pyi
index c51f000..0a5ab92 100644
--- a/dns/trio/query.pyi
+++ b/dns/trio/query.pyi
@@ -12,11 +12,14 @@ except ImportError:
class ssl: # type: ignore
SSLContext : Dict = {}
+import trio
+
def udp(q : message.Message, where : str, port=53,
source : Optional[str] = None, source_port : Optional[int] = 0,
ignore_unexpected : Optional[bool] = False,
one_rr_per_rrset : Optional[bool] = False,
- ignore_trailing : Optional[bool] = False) -> message.Message:
+ ignore_trailing : Optional[bool] = False,
+ sock : Optional[trio.socket.socket] = None) -> message.Message:
...
def stream(q : message.Message, where : str, tls : Optional[bool] = False,
@@ -24,6 +27,7 @@ def stream(q : message.Message, where : str, tls : Optional[bool] = False,
source_port : Optional[int] = 0,
one_rr_per_rrset : Optional[bool] = False,
ignore_trailing : Optional[bool] = False,
+ stream : Optional[trio.abc.Stream] = None,
ssl_context: Optional[ssl.SSLContext] = None,
server_hostname: Optional[str] = None) -> message.Message:
...
diff --git a/tests/test_trio.py b/tests/test_trio.py
index d519844..8304a1f 100644
--- a/tests/test_trio.py
+++ b/tests/test_trio.py
@@ -20,6 +20,7 @@ import unittest
try:
import trio
+ import trio.socket
import dns.message
import dns.name
@@ -99,6 +100,20 @@ try:
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryUDPWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ with trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.udp(q, '8.8.8.8', sock=s)
+ response = trio.run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTCP(self):
qname = dns.name.from_text('dns.google.')
async def run():
@@ -112,6 +127,20 @@ try:
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryTCPWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ async with await trio.open_tcp_stream('8.8.8.8', 53) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
+ response = trio.run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTLS(self):
qname = dns.name.from_text('dns.google.')
async def run():
@@ -125,6 +154,21 @@ try:
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryTLSWithSocket(self):
+ qname = dns.name.from_text('dns.google.')
+ async def run():
+ async with await trio.open_ssl_over_tcp_stream('8.8.8.8',
+ 853) as s:
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ return await dns.trio.query.stream(q, '8.8.8.8', stream=s)
+ response = trio.run(run)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryUDPFallback(self):
qname = dns.name.from_text('.')
async def run():