diff options
Diffstat (limited to 'dns/query.py')
-rw-r--r-- | dns/query.py | 67 |
1 files changed, 40 insertions, 27 deletions
diff --git a/dns/query.py b/dns/query.py index eb82771..dbf9f77 100644 --- a/dns/query.py +++ b/dns/query.py @@ -342,6 +342,33 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, raise BadResponse return r +def _udp_recv(sock, max_size, expiration): + """Reads a datagram from the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + while True: + try: + return sock.recvfrom(max_size) + except BlockingIOError: + _wait_for_readable(sock, expiration) + + +def _udp_send(sock, data, destination, expiration): + """Sends the specified datagram to destination over the socket. + A Timeout exception will be raised if the operation is not completed + by the expiration time. + """ + while True: + try: + if destination: + return sock.sendto(data, destination) + else: + return sock.send(data) + except BlockingIOError: + _wait_for_writable(sock, expiration) + + def send_udp(sock, what, destination, expiration=None): """Send a DNS message to the specified UDP socket. @@ -361,9 +388,8 @@ def send_udp(sock, what, destination, expiration=None): if isinstance(what, dns.message.Message): what = what.to_wire() - _wait_for_writable(sock, expiration) sent_time = time.time() - n = sock.sendto(what, destination) + n = _udp_send(sock, what, destination, expiration) return (n, sent_time) @@ -413,9 +439,8 @@ def receive_udp(sock, destination=None, expiration=None, """ wire = b'' - while 1: - _wait_for_readable(sock, expiration) - (wire, from_address) = sock.recvfrom(65535) + while True: + (wire, from_address) = _udp_recv(sock, 65535, expiration) if _matches_destination(sock.family, from_address, destination, ignore_unexpected): break @@ -553,18 +578,16 @@ def _net_read(sock, count, expiration): """ s = b'' while count > 0: - _wait_for_readable(sock, expiration) try: n = sock.recv(count) - except ssl.SSLWantReadError: # pragma: no cover - continue + if n == b'': + raise EOFError + count -= len(n) + s += n + except (BlockingIOError, ssl.SSLWantReadError): + _wait_for_readable(sock, expiration) except ssl.SSLWantWriteError: # pragma: no cover _wait_for_writable(sock, expiration) - continue - if n == b'': - raise EOFError - count = count - len(n) - s = s + n return s @@ -576,14 +599,12 @@ def _net_write(sock, data, expiration): current = 0 l = len(data) while current < l: - _wait_for_writable(sock, expiration) try: current += sock.send(data[current:]) + except (BlockingIOError, ssl.SSLWantWriteError): + _wait_for_writable(sock, expiration) except ssl.SSLWantReadError: # pragma: no cover _wait_for_readable(sock, expiration) - continue - except ssl.SSLWantWriteError: # pragma: no cover - continue def send_tcp(sock, what, expiration=None): @@ -607,7 +628,6 @@ def send_tcp(sock, what, expiration=None): # avoid writev() or doing a short write that would get pushed # onto the net tcpmsg = struct.pack("!H", l) + what - _wait_for_writable(sock, expiration) sent_time = time.time() _net_write(sock, tcpmsg, expiration) return (len(tcpmsg), sent_time) @@ -697,11 +717,6 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, (begin_time, expiration) = _compute_times(timeout) with contextlib.ExitStack() as stack: if sock: - # - # Verify that the socket is connected, as if it's not connected, - # it's not writable, and the polling in send_tcp() will time out or - # hang forever. - sock.getpeername() s = sock else: (af, destination, source) = _destination_and_source(where, port, @@ -881,8 +896,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, _connect(s, destination, expiration) l = len(wire) if use_udp: - _wait_for_writable(s, expiration) - s.send(wire) + _udp_send(s, wire, None, expiration) else: tcpmsg = struct.pack("!H", l) + wire _net_write(s, tcpmsg, expiration) @@ -903,8 +917,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, (expiration is not None and mexpiration > expiration): mexpiration = expiration if use_udp: - _wait_for_readable(s, expiration) - (wire, from_address) = s.recvfrom(65535) + (wire, from_address) = _udp_recv(s, 65535, expiration) else: ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) |