summaryrefslogtreecommitdiff
path: root/dns/query.py
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-07-08 15:11:19 -0700
committerBrian Wellington <bwelling@xbill.org>2020-07-08 15:15:31 -0700
commit7a5e59707b395454db2cb650371bbc2e800e7be4 (patch)
treebc49dbae99179d3afee50e6e25f9ecb80800adc7 /dns/query.py
parentcce293110167a2e8e20fdf6cdf2d15b0b9ca6679 (diff)
downloaddnspython-7a5e59707b395454db2cb650371bbc2e800e7be4.tar.gz
Add support for receiving UDP queries.
The existing receive_udp() methods are only usable for receiving responses, as they require an expected destination and check that the message is from that destination. This change makes the expected destination (and hence the check) optional, and returns the address that the message was received from (in the sync case, this is only done if no destination is provided, for backwards compatibility). New tests are added, which required adding generic getsockname() support to the async backends.
Diffstat (limited to 'dns/query.py')
-rw-r--r--dns/query.py44
1 files changed, 32 insertions, 12 deletions
diff --git a/dns/query.py b/dns/query.py
index 13c8246..7df565d 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2):
return n1 == n2 and a1[1:] == a2[1:]
+def _matches_destination(af, from_address, destination, ignore_unexpected):
+ # Check that from_address is appropriate for a response to a query
+ # sent to destination.
+ if not destination:
+ return True
+ if _addresses_equal(af, from_address, destination) or \
+ (dns.inet.is_multicast(destination[0]) and
+ from_address[1:] == destination[1:]):
+ return True
+ elif ignore_unexpected:
+ return False
+ raise UnexpectedSource(f'got a response from {from_address} instead of '
+ f'{destination}')
+
+
def _destination_and_source(where, port, source, source_port,
where_must_be_address=True):
# Apply defaults and compute destination and source tuples
@@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None):
return (n, sent_time)
-def receive_udp(sock, destination, expiration=None,
+def receive_udp(sock, destination=None, expiration=None,
ignore_unexpected=False, one_rr_per_rrset=False,
keyring=None, request_mac=b'', ignore_trailing=False,
raise_on_truncation=False):
@@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None,
*sock*, a ``socket``.
*destination*, a destination tuple appropriate for the address family
- of the socket, specifying where the associated query was sent.
+ of the socket, specifying where the message is expected to arrive from.
+ When receiving a response, this would be where the associated query was
+ sent.
*expiration*, a ``float`` or ``None``, the absolute time at which
a timeout exception should be raised. If ``None``, no timeout will
@@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None,
Raises if the message is malformed, if network errors occur, of if
there is a timeout.
- Returns a ``(dns.message.Message, float)`` tuple of the received message
- and the received time.
+ If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
+ tuple of the received message and the received time.
+
+ If *destination* is ``None``, returns a
+ ``(dns.message.Message, float, tuple)``
+ tuple of the received message, the received time, and the address where
+ the message arrived from.
"""
wire = b''
while 1:
_wait_for_readable(sock, expiration)
(wire, from_address) = sock.recvfrom(65535)
- if _addresses_equal(sock.family, from_address, destination) or \
- (dns.inet.is_multicast(destination[0]) and
- from_address[1:] == destination[1:]):
+ if _matches_destination(sock.family, from_address, destination,
+ ignore_unexpected):
break
- if not ignore_unexpected:
- raise UnexpectedSource('got a response from '
- '%s instead of %s' % (from_address,
- destination))
received_time = time.time()
r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac,
one_rr_per_rrset=one_rr_per_rrset,
ignore_trailing=ignore_trailing,
raise_on_truncation=raise_on_truncation)
- return (r, received_time)
+ if destination:
+ return (r, received_time)
+ else:
+ return (r, received_time, from_address)
def udp(q, where, timeout=None, port=53, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False,