diff options
| author | Brian Wellington <bwelling@xbill.org> | 2020-07-08 15:11:19 -0700 |
|---|---|---|
| committer | Brian Wellington <bwelling@xbill.org> | 2020-07-08 15:15:31 -0700 |
| commit | 7a5e59707b395454db2cb650371bbc2e800e7be4 (patch) | |
| tree | bc49dbae99179d3afee50e6e25f9ecb80800adc7 /dns | |
| parent | cce293110167a2e8e20fdf6cdf2d15b0b9ca6679 (diff) | |
| download | dnspython-7a5e59707b395454db2cb650371bbc2e800e7be4.tar.gz | |
Add support for receiving UDP queries.
The existing receive_udp() methods are only usable for receiving
responses, as they require an expected destination and check that the
message is from that destination.
This change makes the expected destination (and hence the check)
optional, and returns the address that the message was received from (in
the sync case, this is only done if no destination is provided, for
backwards compatibility).
New tests are added, which required adding generic getsockname() support
to the async backends.
Diffstat (limited to 'dns')
| -rw-r--r-- | dns/_asyncio_backend.py | 6 | ||||
| -rw-r--r-- | dns/_curio_backend.py | 6 | ||||
| -rw-r--r-- | dns/_trio_backend.py | 9 | ||||
| -rw-r--r-- | dns/asyncquery.py | 36 | ||||
| -rw-r--r-- | dns/query.py | 44 |
5 files changed, 69 insertions, 32 deletions
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index ba7c2e7..3af34ff 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -75,6 +75,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.transport.get_extra_info('peername') + async def getsockname(self): + return self.transport.get_extra_info('sockname') + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, af, reader, writer): @@ -102,6 +105,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.writer.get_extra_info('peername') + async def getsockname(self): + return self.writer.get_extra_info('sockname') + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_curio_backend.py b/dns/_curio_backend.py index dca966d..300e1b8 100644 --- a/dns/_curio_backend.py +++ b/dns/_curio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, socket): @@ -65,6 +68,9 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/_trio_backend.py b/dns/_trio_backend.py index 0f1378f..92ea879 100644 --- a/dns/_trio_backend.py +++ b/dns/_trio_backend.py @@ -43,6 +43,9 @@ class DatagramSocket(dns._asyncbackend.DatagramSocket): async def getpeername(self): return self.socket.getpeername() + async def getsockname(self): + return self.socket.getsockname() + class StreamSocket(dns._asyncbackend.DatagramSocket): def __init__(self, family, stream, tls=False): @@ -69,6 +72,12 @@ class StreamSocket(dns._asyncbackend.DatagramSocket): else: return self.stream.socket.getpeername() + async def getsockname(self): + if self.tls: + return self.stream.transport_stream.socket.getsockname() + else: + return self.stream.socket.getsockname() + class Backend(dns._asyncbackend.Backend): def name(self): diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 4afe7bc..b792648 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -30,8 +30,7 @@ import dns.rcode import dns.rdataclass import dns.rdatatype -from dns.query import _addresses_equal, _compute_times, UnexpectedSource, \ - BadResponse, ssl +from dns.query import _compute_times, _matches_destination, BadResponse, ssl # for brevity @@ -87,7 +86,7 @@ async def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -async def receive_udp(sock, destination, expiration=None, +async def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -96,7 +95,9 @@ async def receive_udp(sock, destination, expiration=None, *sock*, a ``dns.asyncbackend.DatagramSocket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -121,27 +122,22 @@ async def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + Returns a ``(dns.message.Message, float, tuple)`` tuple of the received + message, the received time, and the address where the message arrived from. """ wire = b'' while 1: (wire, from_address) = await sock.recvfrom(65535, _timeout(expiration)) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + return (r, received_time, from_address) async def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, @@ -202,12 +198,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, stuple = _source_tuple(af, source, source_port) s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) await send_udp(s, wire, destination, expiration) - (r, received_time) = await receive_udp(s, destination, expiration, - ignore_unexpected, - one_rr_per_rrset, - q.keyring, q.mac, - ignore_trailing, - raise_on_truncation) + (r, received_time, _) = await receive_udp(s, destination, expiration, + ignore_unexpected, + one_rr_per_rrset, + q.keyring, q.mac, + ignore_trailing, + raise_on_truncation) r.time = received_time - begin_time if not q.is_response(r): raise BadResponse diff --git a/dns/query.py b/dns/query.py index 13c8246..7df565d 100644 --- a/dns/query.py +++ b/dns/query.py @@ -201,6 +201,21 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] +def _matches_destination(af, from_address, destination, ignore_unexpected): + # Check that from_address is appropriate for a response to a query + # sent to destination. + if not destination: + return True + if _addresses_equal(af, from_address, destination) or \ + (dns.inet.is_multicast(destination[0]) and + from_address[1:] == destination[1:]): + return True + elif ignore_unexpected: + return False + raise UnexpectedSource(f'got a response from {from_address} instead of ' + f'{destination}') + + def _destination_and_source(where, port, source, source_port, where_must_be_address=True): # Apply defaults and compute destination and source tuples @@ -397,7 +412,7 @@ def send_udp(sock, what, destination, expiration=None): return (n, sent_time) -def receive_udp(sock, destination, expiration=None, +def receive_udp(sock, destination=None, expiration=None, ignore_unexpected=False, one_rr_per_rrset=False, keyring=None, request_mac=b'', ignore_trailing=False, raise_on_truncation=False): @@ -406,7 +421,9 @@ def receive_udp(sock, destination, expiration=None, *sock*, a ``socket``. *destination*, a destination tuple appropriate for the address family - of the socket, specifying where the associated query was sent. + of the socket, specifying where the message is expected to arrive from. + When receiving a response, this would be where the associated query was + sent. *expiration*, a ``float`` or ``None``, the absolute time at which a timeout exception should be raised. If ``None``, no timeout will @@ -431,28 +448,31 @@ def receive_udp(sock, destination, expiration=None, Raises if the message is malformed, if network errors occur, of if there is a timeout. - Returns a ``(dns.message.Message, float)`` tuple of the received message - and the received time. + If *destination* is not ``None``, returns a ``(dns.message.Message, float)`` + tuple of the received message and the received time. + + If *destination* is ``None``, returns a + ``(dns.message.Message, float, tuple)`` + tuple of the received message, the received time, and the address where + the message arrived from. """ wire = b'' while 1: _wait_for_readable(sock, expiration) (wire, from_address) = sock.recvfrom(65535) - if _addresses_equal(sock.family, from_address, destination) or \ - (dns.inet.is_multicast(destination[0]) and - from_address[1:] == destination[1:]): + if _matches_destination(sock.family, from_address, destination, + ignore_unexpected): break - if not ignore_unexpected: - raise UnexpectedSource('got a response from ' - '%s instead of %s' % (from_address, - destination)) received_time = time.time() r = dns.message.from_wire(wire, keyring=keyring, request_mac=request_mac, one_rr_per_rrset=one_rr_per_rrset, ignore_trailing=ignore_trailing, raise_on_truncation=raise_on_truncation) - return (r, received_time) + if destination: + return (r, received_time) + else: + return (r, received_time, from_address) def udp(q, where, timeout=None, port=53, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False, ignore_trailing=False, |
