summaryrefslogtreecommitdiff
path: root/dns
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-07-17 16:37:53 -0700
committerBrian Wellington <bwelling@xbill.org>2020-07-17 16:39:40 -0700
commit4c0fe5541e36e06fccf1a85028bc289d3070374e (patch)
treed101ba84db385c86e16d2e9c26d0454e8bb69911 /dns
parentac3f05d0d5b61ba46444cb9e66deb7357ff8522e (diff)
downloaddnspython-4c0fe5541e36e06fccf1a85028bc289d3070374e.tar.gz
Changes to blocking model.
Before this change, the synchronous code would check sockets for readability or writability before doing nonblocking read or write. This changes them to attempt the read or write first, and then block if the operation could not complete. This also removes the no-longer-needed getpeername() call in tcp(), which was needed to deal with the case where an unconnected socket was passed in; waiting for writability would block rather than immediately return an error. By attempting the write first, we get the error immediately.
Diffstat (limited to 'dns')
-rw-r--r--dns/query.py67
1 files changed, 40 insertions, 27 deletions
diff --git a/dns/query.py b/dns/query.py
index eb82771..dbf9f77 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -342,6 +342,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
raise BadResponse
return r
+def _udp_recv(sock, max_size, expiration):
+ """Reads a datagram from the socket.
+ A Timeout exception will be raised if the operation is not completed
+ by the expiration time.
+ """
+ while True:
+ try:
+ return sock.recvfrom(max_size)
+ except BlockingIOError:
+ _wait_for_readable(sock, expiration)
+
+
+def _udp_send(sock, data, destination, expiration):
+ """Sends the specified datagram to destination over the socket.
+ A Timeout exception will be raised if the operation is not completed
+ by the expiration time.
+ """
+ while True:
+ try:
+ if destination:
+ return sock.sendto(data, destination)
+ else:
+ return sock.send(data)
+ except BlockingIOError:
+ _wait_for_writable(sock, expiration)
+
+
def send_udp(sock, what, destination, expiration=None):
"""Send a DNS message to the specified UDP socket.
@@ -361,9 +388,8 @@ def send_udp(sock, what, destination, expiration=None):
if isinstance(what, dns.message.Message):
what = what.to_wire()
- _wait_for_writable(sock, expiration)
sent_time = time.time()
- n = sock.sendto(what, destination)
+ n = _udp_send(sock, what, destination, expiration)
return (n, sent_time)
@@ -413,9 +439,8 @@ def receive_udp(sock, destination=None, expiration=None,
"""
wire = b''
- while 1:
- _wait_for_readable(sock, expiration)
- (wire, from_address) = sock.recvfrom(65535)
+ while True:
+ (wire, from_address) = _udp_recv(sock, 65535, expiration)
if _matches_destination(sock.family, from_address, destination,
ignore_unexpected):
break
@@ -553,18 +578,16 @@ def _net_read(sock, count, expiration):
"""
s = b''
while count > 0:
- _wait_for_readable(sock, expiration)
try:
n = sock.recv(count)
- except ssl.SSLWantReadError: # pragma: no cover
- continue
+ if n == b'':
+ raise EOFError
+ count -= len(n)
+ s += n
+ except (BlockingIOError, ssl.SSLWantReadError):
+ _wait_for_readable(sock, expiration)
except ssl.SSLWantWriteError: # pragma: no cover
_wait_for_writable(sock, expiration)
- continue
- if n == b'':
- raise EOFError
- count = count - len(n)
- s = s + n
return s
@@ -576,14 +599,12 @@ def _net_write(sock, data, expiration):
current = 0
l = len(data)
while current < l:
- _wait_for_writable(sock, expiration)
try:
current += sock.send(data[current:])
+ except (BlockingIOError, ssl.SSLWantWriteError):
+ _wait_for_writable(sock, expiration)
except ssl.SSLWantReadError: # pragma: no cover
_wait_for_readable(sock, expiration)
- continue
- except ssl.SSLWantWriteError: # pragma: no cover
- continue
def send_tcp(sock, what, expiration=None):
@@ -607,7 +628,6 @@ def send_tcp(sock, what, expiration=None):
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + what
- _wait_for_writable(sock, expiration)
sent_time = time.time()
_net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time)
@@ -697,11 +717,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
(begin_time, expiration) = _compute_times(timeout)
with contextlib.ExitStack() as stack:
if sock:
- #
- # Verify that the socket is connected, as if it's not connected,
- # it's not writable, and the polling in send_tcp() will time out or
- # hang forever.
- sock.getpeername()
s = sock
else:
(af, destination, source) = _destination_and_source(where, port,
@@ -881,8 +896,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
_connect(s, destination, expiration)
l = len(wire)
if use_udp:
- _wait_for_writable(s, expiration)
- s.send(wire)
+ _udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration)
@@ -903,8 +917,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
(expiration is not None and mexpiration > expiration):
mexpiration = expiration
if use_udp:
- _wait_for_readable(s, expiration)
- (wire, from_address) = s.recvfrom(65535)
+ (wire, from_address) = _udp_recv(s, 65535, expiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)