summaryrefslogtreecommitdiff
path: root/dns/query.py
diff options
context:
space:
mode:
Diffstat (limited to 'dns/query.py')
-rw-r--r--dns/query.py67
1 files changed, 40 insertions, 27 deletions
diff --git a/dns/query.py b/dns/query.py
index eb82771..dbf9f77 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -342,6 +342,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0,
raise BadResponse
return r
+def _udp_recv(sock, max_size, expiration):
+ """Reads a datagram from the socket.
+ A Timeout exception will be raised if the operation is not completed
+ by the expiration time.
+ """
+ while True:
+ try:
+ return sock.recvfrom(max_size)
+ except BlockingIOError:
+ _wait_for_readable(sock, expiration)
+
+
+def _udp_send(sock, data, destination, expiration):
+ """Sends the specified datagram to destination over the socket.
+ A Timeout exception will be raised if the operation is not completed
+ by the expiration time.
+ """
+ while True:
+ try:
+ if destination:
+ return sock.sendto(data, destination)
+ else:
+ return sock.send(data)
+ except BlockingIOError:
+ _wait_for_writable(sock, expiration)
+
+
def send_udp(sock, what, destination, expiration=None):
"""Send a DNS message to the specified UDP socket.
@@ -361,9 +388,8 @@ def send_udp(sock, what, destination, expiration=None):
if isinstance(what, dns.message.Message):
what = what.to_wire()
- _wait_for_writable(sock, expiration)
sent_time = time.time()
- n = sock.sendto(what, destination)
+ n = _udp_send(sock, what, destination, expiration)
return (n, sent_time)
@@ -413,9 +439,8 @@ def receive_udp(sock, destination=None, expiration=None,
"""
wire = b''
- while 1:
- _wait_for_readable(sock, expiration)
- (wire, from_address) = sock.recvfrom(65535)
+ while True:
+ (wire, from_address) = _udp_recv(sock, 65535, expiration)
if _matches_destination(sock.family, from_address, destination,
ignore_unexpected):
break
@@ -553,18 +578,16 @@ def _net_read(sock, count, expiration):
"""
s = b''
while count > 0:
- _wait_for_readable(sock, expiration)
try:
n = sock.recv(count)
- except ssl.SSLWantReadError: # pragma: no cover
- continue
+ if n == b'':
+ raise EOFError
+ count -= len(n)
+ s += n
+ except (BlockingIOError, ssl.SSLWantReadError):
+ _wait_for_readable(sock, expiration)
except ssl.SSLWantWriteError: # pragma: no cover
_wait_for_writable(sock, expiration)
- continue
- if n == b'':
- raise EOFError
- count = count - len(n)
- s = s + n
return s
@@ -576,14 +599,12 @@ def _net_write(sock, data, expiration):
current = 0
l = len(data)
while current < l:
- _wait_for_writable(sock, expiration)
try:
current += sock.send(data[current:])
+ except (BlockingIOError, ssl.SSLWantWriteError):
+ _wait_for_writable(sock, expiration)
except ssl.SSLWantReadError: # pragma: no cover
_wait_for_readable(sock, expiration)
- continue
- except ssl.SSLWantWriteError: # pragma: no cover
- continue
def send_tcp(sock, what, expiration=None):
@@ -607,7 +628,6 @@ def send_tcp(sock, what, expiration=None):
# avoid writev() or doing a short write that would get pushed
# onto the net
tcpmsg = struct.pack("!H", l) + what
- _wait_for_writable(sock, expiration)
sent_time = time.time()
_net_write(sock, tcpmsg, expiration)
return (len(tcpmsg), sent_time)
@@ -697,11 +717,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
(begin_time, expiration) = _compute_times(timeout)
with contextlib.ExitStack() as stack:
if sock:
- #
- # Verify that the socket is connected, as if it's not connected,
- # it's not writable, and the polling in send_tcp() will time out or
- # hang forever.
- sock.getpeername()
s = sock
else:
(af, destination, source) = _destination_and_source(where, port,
@@ -881,8 +896,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
_connect(s, destination, expiration)
l = len(wire)
if use_udp:
- _wait_for_writable(s, expiration)
- s.send(wire)
+ _udp_send(s, wire, None, expiration)
else:
tcpmsg = struct.pack("!H", l) + wire
_net_write(s, tcpmsg, expiration)
@@ -903,8 +917,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
(expiration is not None and mexpiration > expiration):
mexpiration = expiration
if use_udp:
- _wait_for_readable(s, expiration)
- (wire, from_address) = s.recvfrom(65535)
+ (wire, from_address) = _udp_recv(s, 65535, expiration)
else:
ldata = _net_read(s, 2, mexpiration)
(l,) = struct.unpack("!H", ldata)