summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-05-18 08:08:05 -0700
committerBob Halley <halley@dnspython.org>2020-05-18 08:08:05 -0700
commit10f0f35c892f9c636eab8e97064d4fa941d7252d (patch)
tree875d3497d945b79fdd1b014ab41681c88dfe5b8d
parent037dfe7f71f21e37fade0e3a203e38b675a1aa43 (diff)
downloaddnspython-10f0f35c892f9c636eab8e97064d4fa941d7252d.tar.gz
refactor resolver, extracting all business logic
-rw-r--r--dns/resolver.py363
-rw-r--r--dns/trio/resolver.py204
2 files changed, 232 insertions, 335 deletions
diff --git a/dns/resolver.py b/dns/resolver.py
index 3af35f4..474219a 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -490,6 +490,169 @@ class LRUCache(object):
node = next
self.data = {}
+class _Resolution(object):
+ """Helper class for dns.resolver.Resolver.resolve().
+
+ All of the "business logic" of resolution is encapsulated in this
+ class, allowing us to have multiple resolve() implementations
+ using different I/O schemes without copying all of the
+ complicated logic.
+
+ This class is a "friend" to dns.resolver.Resolver and manipulates
+ resolver data structures directly.
+ """
+
+ def __init__(self, resolver, qname, rdtype, rdclass, tcp,
+ raise_on_no_answer, search):
+ if isinstance(qname, str):
+ qname = dns.name.from_text(qname, None)
+ if isinstance(rdtype, str):
+ rdtype = dns.rdatatype.from_text(rdtype)
+ if dns.rdatatype.is_metatype(rdtype):
+ raise NoMetaqueries
+ if isinstance(rdclass, str):
+ rdclass = dns.rdataclass.from_text(rdclass)
+ if dns.rdataclass.is_metaclass(rdclass):
+ raise NoMetaqueries
+ self.resolver = resolver
+ self.qnames_to_try = resolver._get_qnames_to_try(qname, search)
+ self.qnames = self.qnames_to_try[:]
+ self.rdtype = rdtype
+ self.rdclass = rdclass
+ self.tcp = tcp
+ self.raise_on_no_answer = raise_on_no_answer
+ self.nxdomain_responses = {}
+
+ def next_request(self):
+ """Get the next request to send, and check the cache.
+
+ Returns a (request, answer) tuple. At most one of request or
+ answer will not be None.
+ """
+
+ # We return a tuple instead of Union[Message,Answer] as it lets
+ # the caller avoid isinstance.
+
+ if len(self.qnames) == 0:
+ #
+ # We've tried everything and only gotten NXDOMAINs. (We know
+ # it's only NXDOMAINs as anything else would have returned
+ # before now.)
+ #
+ raise NXDOMAIN(qnames=self.qnames_to_try,
+ responses=self.nxdomain_responses)
+
+ self.qname = self.qnames.pop()
+
+ # Do we know the answer?
+ if self.resolver.cache:
+ answer = self.resolver.cache.get((self.qname, self.rdtype,
+ self.rdclass))
+ if answer is not None:
+ if answer.rrset is None and self.raise_on_no_answer:
+ raise NoAnswer(response=answer.response)
+ else:
+ return (None, answer)
+
+ # Build the request
+ request = dns.message.make_query(self.qname, self.rdtype, self.rdclass)
+ if self.resolver.keyname is not None:
+ request.use_tsig(self.resolver.keyring, self.resolver.keyname,
+ algorithm=self.resolver.keyalgorithm)
+ request.use_edns(self.resolver.edns, self.resolver.ednsflags,
+ self.resolver.payload)
+ if self.resolver.flags is not None:
+ request.flags = self.resolver.flags
+
+ self.nameservers = self.resolver.nameservers[:]
+ if self.resolver.rotate:
+ random.shuffle(self.nameservers)
+ self.current_nameservers = self.nameservers[:]
+ self.errors = []
+ self.nameserver = None
+ self.tcp_attempt = False
+ self.retry_with_tcp = False
+ self.request = request
+ self.backoff = 0.10
+
+ return (request, None)
+
+ def next_nameserver(self):
+ if self.retry_with_tcp:
+ assert self.nameserver is not None
+ self.tcp_attempt = True
+ self.retry_with_tcp = False
+ return (self.nameserver, self.port, True)
+
+ backoff = 0
+ if not self.current_nameservers:
+ if len(self.nameservers) == 0:
+ # Out of things to try!
+ raise NoNameservers(request=self.request, errors=self.errors)
+ self.current_nameservers = self.nameservers[:]
+ backoff = self.backoff
+ self.backoff = min(self.backoff * 2, 2)
+
+ self.nameserver = self.current_nameservers.pop()
+ self.port = self.resolver.nameserver_ports.get(self.nameserver,
+ self.resolver.port)
+ self.tcp_attempt = self.tcp
+ return (self.nameserver, self.port, self.tcp_attempt, backoff)
+
+ def query_result(self, response, ex):
+ #
+ # returns an (answer: Answer, end_loop: bool) tuple.
+ #
+ if ex:
+ # Exception during I/O or from_wire()
+ assert response is None
+ self.errors.append((self.nameserver, self.tcp_attempt, self.port,
+ ex, response))
+ if isinstance(ex, dns.exception.FormError) or \
+ isinstance(ex, EOFError) or \
+ isinstance(ex, NotImplementedError):
+ # This nameserver is no good, take it out of the mix.
+ self.nameservers.remove(self.nameserver)
+ elif isinstance(ex, dns.message.Truncated):
+ if self.tcp_attempt:
+ # Truncation with TCP is no good!
+ self.nameservers.remove(self.nameserver)
+ else:
+ self.retry_with_tcp = True
+ return (None, False)
+ # We got an answer!
+ assert response is not None
+ rcode = response.rcode()
+ if rcode == dns.rcode.NOERROR:
+ answer = Answer(self.qname, self.rdtype, self.rdclass, response,
+ self.raise_on_no_answer, self.nameserver,
+ self.port)
+ if self.resolver.cache:
+ self.resolver.cache.put((self.qname, self.rdtype,
+ self.rdclass), answer)
+ return (answer, True)
+ elif rcode == dns.rcode.NXDOMAIN:
+ self.nxdomain_responses[self.qname] = response
+ # Make next_nameserver() return None, so caller breaks its
+ # inner loop and calls next_request().
+ return (None, True)
+ elif rcode == dns.rcode.YXDOMAIN:
+ yex = YXDOMAIN()
+ self.errors.append((self.nameserver, self.tcp_attempt,
+ self.port, yex, response))
+ raise yex
+ else:
+ #
+ # We got a response, but we're not happy with the
+ # rcode in it. Remove the server from the mix if
+ # the rcode isn't SERVFAIL.
+ #
+ if rcode != dns.rcode.SERVFAIL or not self.retry_servfail:
+ self.nameservers.remove(self.nameserver)
+ self.errors.append((self.nameserver, self.tcp_attempt, self.port,
+ dns.rcode.to_text(rcode), response))
+ return (None, False)
+
class Resolver(object):
"""DNS stub resolver."""
@@ -862,179 +1025,47 @@ class Resolver(object):
"""
- if isinstance(qname, str):
- qname = dns.name.from_text(qname, None)
- if isinstance(rdtype, str):
- rdtype = dns.rdatatype.from_text(rdtype)
- if dns.rdatatype.is_metatype(rdtype):
- raise NoMetaqueries
- if isinstance(rdclass, str):
- rdclass = dns.rdataclass.from_text(rdclass)
- if dns.rdataclass.is_metaclass(rdclass):
- raise NoMetaqueries
- qnames_to_try = self._get_qnames_to_try(qname, search)
- all_nxdomain = True
- nxdomain_responses = {}
+ resolution = _Resolution(self, qname, rdtype, rdclass, tcp,
+ raise_on_no_answer, search)
start = time.time()
- _qname = None # make pylint happy
- for _qname in qnames_to_try:
- if self.cache:
- answer = self.cache.get((_qname, rdtype, rdclass))
- if answer is not None:
- if answer.rrset is None and raise_on_no_answer:
- raise NoAnswer(response=answer.response)
+ while True:
+ (request, answer) = resolution.next_request()
+ if answer:
+ # cache hit!
+ return answer
+ done = False
+ while not done:
+ (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+ if backoff:
+ time.sleep(backoff)
+ timeout = self._compute_timeout(start, lifetime)
+ try:
+ if dns.inet.is_address(nameserver):
+ if tcp:
+ response = dns.query.tcp(request, nameserver,
+ timeout=timeout,
+ port=port,
+ source=source,
+ source_port=source_port)
+ else:
+ response = dns.query.udp(request,
+ nameserver,
+ timeout=timeout,
+ port=port,
+ source=source,
+ source_port=source_port)
else:
- return answer
- request = dns.message.make_query(_qname, rdtype, rdclass)
- if self.keyname is not None:
- request.use_tsig(self.keyring, self.keyname,
- algorithm=self.keyalgorithm)
- request.use_edns(self.edns, self.ednsflags, self.payload)
- if self.flags is not None:
- request.flags = self.flags
- response = None
- #
- # make a copy of the servers list so we can alter it later.
- #
- nameservers = self.nameservers[:]
- errors = []
- if self.rotate:
- random.shuffle(nameservers)
- backoff = 0.10
- # keep track of nameserver and port
- # to include them in Answer
- nameserver_answered = None
- port_answered = None
- while response is None:
- if len(nameservers) == 0:
- raise NoNameservers(request=request, errors=errors)
- for nameserver in nameservers[:]:
- timeout = self._compute_timeout(start, lifetime)
- port = self.nameserver_ports.get(nameserver, self.port)
- protocol = urlparse(nameserver).scheme
- try:
+ protocol = urlparse(nameserver).scheme
if protocol == 'https':
- tcp_attempt = True
response = dns.query.https(request, nameserver,
timeout=timeout)
elif protocol:
continue
- else:
- tcp_attempt = tcp
- if tcp:
- response = \
- dns.query.tcp(request, nameserver,
- timeout=timeout,
- port=port,
- source=source,
- source_port=source_port)
- else:
- try:
- response = \
- dns.query.udp(request,
- nameserver,
- timeout=timeout,
- port=port,
- source=source,
- source_port=source_port)
- except dns.message.Truncated:
- # Response truncated; retry with TCP.
- tcp_attempt = True
- timeout = self._compute_timeout(start,
- lifetime)
- response = \
- dns.query.tcp(request, nameserver,
- timeout=timeout,
- port=port,
- source=source,
- source_port=source_port)
- except (socket.error, dns.exception.Timeout) as ex:
- #
- # Communication failure or timeout. Go to the
- # next server
- #
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- except dns.query.UnexpectedSource as ex:
- #
- # Who knows? Keep going.
- #
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- except dns.exception.FormError as ex:
- #
- # We don't understand what this server is
- # saying. Take it out of the mix and
- # continue.
- #
- nameservers.remove(nameserver)
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- except EOFError as ex:
- #
- # We're using TCP and they hung up on us.
- # Probably they don't support TCP (though
- # they're supposed to!). Take it out of the
- # mix and continue.
- #
- nameservers.remove(nameserver)
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- nameserver_answered = nameserver
- port_answered = port
- rcode = response.rcode()
- if rcode == dns.rcode.YXDOMAIN:
- yex = YXDOMAIN()
- errors.append((nameserver, tcp_attempt, port, yex,
- response))
- raise yex
- if rcode == dns.rcode.NOERROR or \
- rcode == dns.rcode.NXDOMAIN:
- break
- #
- # We got a response, but we're not happy with the
- # rcode in it. Remove the server from the mix if
- # the rcode isn't SERVFAIL.
- #
- if rcode != dns.rcode.SERVFAIL or not self.retry_servfail:
- nameservers.remove(nameserver)
- errors.append((nameserver, tcp_attempt, port,
- dns.rcode.to_text(rcode), response))
- response = None
- if response is not None:
- break
- #
- # All nameservers failed!
- #
- if len(nameservers) > 0:
- #
- # But we still have servers to try. Sleep a bit
- # so we don't pound them!
- #
- timeout = self._compute_timeout(start, lifetime)
- sleep_time = min(timeout, backoff)
- backoff *= 2
- time.sleep(sleep_time)
- if response.rcode() == dns.rcode.NXDOMAIN:
- nxdomain_responses[_qname] = response
- continue
- all_nxdomain = False
- break
- if all_nxdomain:
- raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses)
- answer = Answer(_qname, rdtype, rdclass, response,
- raise_on_no_answer, nameserver_answered, port_answered)
- if self.cache:
- self.cache.put((_qname, rdtype, rdclass), answer)
- return answer
+ (answer, done) = resolution.query_result(response, None)
+ if answer:
+ return answer
+ except Exception as ex:
+ (_, done) = resolution.query_result(None, ex)
def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
tcp=False, source=None, raise_on_no_answer=True, source_port=0,
diff --git a/dns/trio/resolver.py b/dns/trio/resolver.py
index 2f96a52..785fde6 100644
--- a/dns/trio/resolver.py
+++ b/dns/trio/resolver.py
@@ -17,19 +17,15 @@
"""trio async I/O library DNS stub resolver."""
-import random
-import socket
import trio
-from urllib.parse import urlparse
import dns.exception
import dns.query
import dns.resolver
import dns.trio.query
-# import resolver symbols for compatibility and brevity
-from dns.resolver import NXDOMAIN, YXDOMAIN, NoAnswer, NoNameservers, \
- NotAbsolute, NoRootSOA, NoMetaqueries, Answer
+# import some resolver symbols for brevity
+from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA
# we do this for indentation reasons below
_udp = dns.trio.query.udp
@@ -87,62 +83,25 @@ class Resolver(dns.resolver.Resolver):
"""
- if isinstance(qname, str):
- qname = dns.name.from_text(qname, None)
- if isinstance(rdtype, str):
- rdtype = dns.rdatatype.from_text(rdtype)
- if dns.rdatatype.is_metatype(rdtype):
- raise NoMetaqueries
- if isinstance(rdclass, str):
- rdclass = dns.rdataclass.from_text(rdclass)
- if dns.rdataclass.is_metaclass(rdclass):
- raise NoMetaqueries
- qnames_to_try = self._get_qnames_to_try(qname, search)
- all_nxdomain = True
- nxdomain_responses = {}
- _qname = None # make pylint happy
- for _qname in qnames_to_try:
- if self.cache:
- answer = self.cache.get((_qname, rdtype, rdclass))
- if answer is not None:
- if answer.rrset is None and raise_on_no_answer:
- raise NoAnswer(response=answer.response)
- else:
- return answer
- request = dns.message.make_query(_qname, rdtype, rdclass)
- if self.keyname is not None:
- request.use_tsig(self.keyring, self.keyname,
- algorithm=self.keyalgorithm)
- request.use_edns(self.edns, self.ednsflags, self.payload)
- if self.flags is not None:
- request.flags = self.flags
- response = None
- #
- # make a copy of the servers list so we can alter it later.
- #
- nameservers = self.nameservers[:]
- errors = []
- if self.rotate:
- random.shuffle(nameservers)
- backoff = 0.10
- # keep track of nameserver and port
- # to include them in Answer
- nameserver_answered = None
- port_answered = None
- loops = 0
- while response is None:
- if len(nameservers) == 0:
- raise NoNameservers(request=request, errors=errors)
- for nameserver in nameservers[:]:
- port = self.nameserver_ports.get(nameserver, self.port)
- protocol = urlparse(nameserver).scheme
- try:
- with trio.fail_after(self.timeout):
- if protocol == 'https':
- raise NotImplementedError
- elif protocol:
- continue
- tcp_attempt = tcp
+ resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp,
+ raise_on_no_answer, search)
+ while True:
+ (request, answer) = resolution.next_request()
+ if answer:
+ # cache hit!
+ return answer
+ loops = 1
+ done = False
+ while not done:
+ (nameserver, port, tcp, backoff) = resolution.next_nameserver()
+ if backoff:
+ loops += 1
+ if loops >= 5:
+ raise TooManyAttempts
+ await trio.sleep(backoff)
+ try:
+ with trio.fail_after(self.timeout):
+ if dns.inet.is_address(nameserver):
if tcp:
response = await \
_stream(request, nameserver,
@@ -150,113 +109,20 @@ class Resolver(dns.resolver.Resolver):
source=source,
source_port=source_port)
else:
- try:
- response = await \
- _udp(request,
- nameserver,
- port=port,
- source=source,
- source_port=source_port)
- except dns.message.Truncated:
- # Response truncated; retry with TCP.
- tcp_attempt = True
- response = await \
- _stream(request, nameserver,
- port=port,
- source=source,
- source_port=source_port)
- except (socket.error, trio.TooSlowError) as ex:
- #
- # Communication failure or timeout. Go to the
- # next server
- #
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- except dns.query.UnexpectedSource as ex:
- #
- # Who knows? Keep going.
- #
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- except dns.exception.FormError as ex:
- #
- # We don't understand what this server is
- # saying. Take it out of the mix and
- # continue.
- #
- nameservers.remove(nameserver)
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- except EOFError as ex:
- #
- # We're using TCP and they hung up on us.
- # Probably they don't support TCP (though
- # they're supposed to!). Take it out of the
- # mix and continue.
- #
- nameservers.remove(nameserver)
- errors.append((nameserver, tcp_attempt, port, ex,
- response))
- response = None
- continue
- nameserver_answered = nameserver
- port_answered = port
- rcode = response.rcode()
- if rcode == dns.rcode.YXDOMAIN:
- yex = YXDOMAIN()
- errors.append((nameserver, tcp_attempt, port, yex,
- response))
- raise yex
- if rcode == dns.rcode.NOERROR or \
- rcode == dns.rcode.NXDOMAIN:
- break
- #
- # We got a response, but we're not happy with the
- # rcode in it. Remove the server from the mix if
- # the rcode isn't SERVFAIL.
- #
- if rcode != dns.rcode.SERVFAIL or not self.retry_servfail:
- nameservers.remove(nameserver)
- errors.append((nameserver, tcp_attempt, port,
- dns.rcode.to_text(rcode), response))
- response = None
- if response is not None:
- break
- #
- # All nameservers failed!
- #
- # Do not loop forever if caller hasn't used a timeout
- # scope.
- loops += 1
- if loops >= 5:
- raise TooManyAttempts
- if len(nameservers) > 0:
- #
- # But we still have servers to try. Sleep a bit
- # so we don't pound them!
- #
- await trio.sleep(backoff)
- backoff *= 2
- if backoff > 2:
- backoff = 2
- if response.rcode() == dns.rcode.NXDOMAIN:
- nxdomain_responses[_qname] = response
- continue
- all_nxdomain = False
- break
- if all_nxdomain:
- raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses)
- answer = Answer(_qname, rdtype, rdclass, response, raise_on_no_answer,
- nameserver_answered, port_answered)
- if self.cache:
- self.cache.put((_qname, rdtype, rdclass), answer)
- return answer
+ response = await \
+ _udp(request,
+ nameserver,
+ port=port,
+ source=source,
+ source_port=source_port)
+ else:
+ # We don't do DoH yet.
+ raise NotImplementedError
+ (answer, done) = resolution.query_result(response, None)
+ if answer:
+ return answer
+ except Exception as ex:
+ (_, done) = resolution.query_result(None, ex)
async def query(self, *args, **kwargs):
# We have to define something here as we don't want to inherit the