From 10f0f35c892f9c636eab8e97064d4fa941d7252d Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Mon, 18 May 2020 08:08:05 -0700 Subject: refactor resolver, extracting all business logic --- dns/resolver.py | 363 ++++++++++++++++++++++++++++----------------------- dns/trio/resolver.py | 204 +++++------------------------ 2 files changed, 232 insertions(+), 335 deletions(-) diff --git a/dns/resolver.py b/dns/resolver.py index 3af35f4..474219a 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -490,6 +490,169 @@ class LRUCache(object): node = next self.data = {} +class _Resolution(object): + """Helper class for dns.resolver.Resolver.resolve(). + + All of the "business logic" of resolution is encapsulated in this + class, allowing us to have multiple resolve() implementations + using different I/O schemes without copying all of the + complicated logic. + + This class is a "friend" to dns.resolver.Resolver and manipulates + resolver data structures directly. + """ + + def __init__(self, resolver, qname, rdtype, rdclass, tcp, + raise_on_no_answer, search): + if isinstance(qname, str): + qname = dns.name.from_text(qname, None) + if isinstance(rdtype, str): + rdtype = dns.rdatatype.from_text(rdtype) + if dns.rdatatype.is_metatype(rdtype): + raise NoMetaqueries + if isinstance(rdclass, str): + rdclass = dns.rdataclass.from_text(rdclass) + if dns.rdataclass.is_metaclass(rdclass): + raise NoMetaqueries + self.resolver = resolver + self.qnames_to_try = resolver._get_qnames_to_try(qname, search) + self.qnames = self.qnames_to_try[:] + self.rdtype = rdtype + self.rdclass = rdclass + self.tcp = tcp + self.raise_on_no_answer = raise_on_no_answer + self.nxdomain_responses = {} + + def next_request(self): + """Get the next request to send, and check the cache. + + Returns a (request, answer) tuple. At most one of request or + answer will not be None. + """ + + # We return a tuple instead of Union[Message,Answer] as it lets + # the caller avoid isinstance. + + if len(self.qnames) == 0: + # + # We've tried everything and only gotten NXDOMAINs. (We know + # it's only NXDOMAINs as anything else would have returned + # before now.) + # + raise NXDOMAIN(qnames=self.qnames_to_try, + responses=self.nxdomain_responses) + + self.qname = self.qnames.pop() + + # Do we know the answer? + if self.resolver.cache: + answer = self.resolver.cache.get((self.qname, self.rdtype, + self.rdclass)) + if answer is not None: + if answer.rrset is None and self.raise_on_no_answer: + raise NoAnswer(response=answer.response) + else: + return (None, answer) + + # Build the request + request = dns.message.make_query(self.qname, self.rdtype, self.rdclass) + if self.resolver.keyname is not None: + request.use_tsig(self.resolver.keyring, self.resolver.keyname, + algorithm=self.resolver.keyalgorithm) + request.use_edns(self.resolver.edns, self.resolver.ednsflags, + self.resolver.payload) + if self.resolver.flags is not None: + request.flags = self.resolver.flags + + self.nameservers = self.resolver.nameservers[:] + if self.resolver.rotate: + random.shuffle(self.nameservers) + self.current_nameservers = self.nameservers[:] + self.errors = [] + self.nameserver = None + self.tcp_attempt = False + self.retry_with_tcp = False + self.request = request + self.backoff = 0.10 + + return (request, None) + + def next_nameserver(self): + if self.retry_with_tcp: + assert self.nameserver is not None + self.tcp_attempt = True + self.retry_with_tcp = False + return (self.nameserver, self.port, True) + + backoff = 0 + if not self.current_nameservers: + if len(self.nameservers) == 0: + # Out of things to try! + raise NoNameservers(request=self.request, errors=self.errors) + self.current_nameservers = self.nameservers[:] + backoff = self.backoff + self.backoff = min(self.backoff * 2, 2) + + self.nameserver = self.current_nameservers.pop() + self.port = self.resolver.nameserver_ports.get(self.nameserver, + self.resolver.port) + self.tcp_attempt = self.tcp + return (self.nameserver, self.port, self.tcp_attempt, backoff) + + def query_result(self, response, ex): + # + # returns an (answer: Answer, end_loop: bool) tuple. + # + if ex: + # Exception during I/O or from_wire() + assert response is None + self.errors.append((self.nameserver, self.tcp_attempt, self.port, + ex, response)) + if isinstance(ex, dns.exception.FormError) or \ + isinstance(ex, EOFError) or \ + isinstance(ex, NotImplementedError): + # This nameserver is no good, take it out of the mix. + self.nameservers.remove(self.nameserver) + elif isinstance(ex, dns.message.Truncated): + if self.tcp_attempt: + # Truncation with TCP is no good! + self.nameservers.remove(self.nameserver) + else: + self.retry_with_tcp = True + return (None, False) + # We got an answer! + assert response is not None + rcode = response.rcode() + if rcode == dns.rcode.NOERROR: + answer = Answer(self.qname, self.rdtype, self.rdclass, response, + self.raise_on_no_answer, self.nameserver, + self.port) + if self.resolver.cache: + self.resolver.cache.put((self.qname, self.rdtype, + self.rdclass), answer) + return (answer, True) + elif rcode == dns.rcode.NXDOMAIN: + self.nxdomain_responses[self.qname] = response + # Make next_nameserver() return None, so caller breaks its + # inner loop and calls next_request(). + return (None, True) + elif rcode == dns.rcode.YXDOMAIN: + yex = YXDOMAIN() + self.errors.append((self.nameserver, self.tcp_attempt, + self.port, yex, response)) + raise yex + else: + # + # We got a response, but we're not happy with the + # rcode in it. Remove the server from the mix if + # the rcode isn't SERVFAIL. + # + if rcode != dns.rcode.SERVFAIL or not self.retry_servfail: + self.nameservers.remove(self.nameserver) + self.errors.append((self.nameserver, self.tcp_attempt, self.port, + dns.rcode.to_text(rcode), response)) + return (None, False) + class Resolver(object): """DNS stub resolver.""" @@ -862,179 +1025,47 @@ class Resolver(object): """ - if isinstance(qname, str): - qname = dns.name.from_text(qname, None) - if isinstance(rdtype, str): - rdtype = dns.rdatatype.from_text(rdtype) - if dns.rdatatype.is_metatype(rdtype): - raise NoMetaqueries - if isinstance(rdclass, str): - rdclass = dns.rdataclass.from_text(rdclass) - if dns.rdataclass.is_metaclass(rdclass): - raise NoMetaqueries - qnames_to_try = self._get_qnames_to_try(qname, search) - all_nxdomain = True - nxdomain_responses = {} + resolution = _Resolution(self, qname, rdtype, rdclass, tcp, + raise_on_no_answer, search) start = time.time() - _qname = None # make pylint happy - for _qname in qnames_to_try: - if self.cache: - answer = self.cache.get((_qname, rdtype, rdclass)) - if answer is not None: - if answer.rrset is None and raise_on_no_answer: - raise NoAnswer(response=answer.response) + while True: + (request, answer) = resolution.next_request() + if answer: + # cache hit! + return answer + done = False + while not done: + (nameserver, port, tcp, backoff) = resolution.next_nameserver() + if backoff: + time.sleep(backoff) + timeout = self._compute_timeout(start, lifetime) + try: + if dns.inet.is_address(nameserver): + if tcp: + response = dns.query.tcp(request, nameserver, + timeout=timeout, + port=port, + source=source, + source_port=source_port) + else: + response = dns.query.udp(request, + nameserver, + timeout=timeout, + port=port, + source=source, + source_port=source_port) else: - return answer - request = dns.message.make_query(_qname, rdtype, rdclass) - if self.keyname is not None: - request.use_tsig(self.keyring, self.keyname, - algorithm=self.keyalgorithm) - request.use_edns(self.edns, self.ednsflags, self.payload) - if self.flags is not None: - request.flags = self.flags - response = None - # - # make a copy of the servers list so we can alter it later. - # - nameservers = self.nameservers[:] - errors = [] - if self.rotate: - random.shuffle(nameservers) - backoff = 0.10 - # keep track of nameserver and port - # to include them in Answer - nameserver_answered = None - port_answered = None - while response is None: - if len(nameservers) == 0: - raise NoNameservers(request=request, errors=errors) - for nameserver in nameservers[:]: - timeout = self._compute_timeout(start, lifetime) - port = self.nameserver_ports.get(nameserver, self.port) - protocol = urlparse(nameserver).scheme - try: + protocol = urlparse(nameserver).scheme if protocol == 'https': - tcp_attempt = True response = dns.query.https(request, nameserver, timeout=timeout) elif protocol: continue - else: - tcp_attempt = tcp - if tcp: - response = \ - dns.query.tcp(request, nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port) - else: - try: - response = \ - dns.query.udp(request, - nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port) - except dns.message.Truncated: - # Response truncated; retry with TCP. - tcp_attempt = True - timeout = self._compute_timeout(start, - lifetime) - response = \ - dns.query.tcp(request, nameserver, - timeout=timeout, - port=port, - source=source, - source_port=source_port) - except (socket.error, dns.exception.Timeout) as ex: - # - # Communication failure or timeout. Go to the - # next server - # - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except dns.query.UnexpectedSource as ex: - # - # Who knows? Keep going. - # - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except dns.exception.FormError as ex: - # - # We don't understand what this server is - # saying. Take it out of the mix and - # continue. - # - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except EOFError as ex: - # - # We're using TCP and they hung up on us. - # Probably they don't support TCP (though - # they're supposed to!). Take it out of the - # mix and continue. - # - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - nameserver_answered = nameserver - port_answered = port - rcode = response.rcode() - if rcode == dns.rcode.YXDOMAIN: - yex = YXDOMAIN() - errors.append((nameserver, tcp_attempt, port, yex, - response)) - raise yex - if rcode == dns.rcode.NOERROR or \ - rcode == dns.rcode.NXDOMAIN: - break - # - # We got a response, but we're not happy with the - # rcode in it. Remove the server from the mix if - # the rcode isn't SERVFAIL. - # - if rcode != dns.rcode.SERVFAIL or not self.retry_servfail: - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, - dns.rcode.to_text(rcode), response)) - response = None - if response is not None: - break - # - # All nameservers failed! - # - if len(nameservers) > 0: - # - # But we still have servers to try. Sleep a bit - # so we don't pound them! - # - timeout = self._compute_timeout(start, lifetime) - sleep_time = min(timeout, backoff) - backoff *= 2 - time.sleep(sleep_time) - if response.rcode() == dns.rcode.NXDOMAIN: - nxdomain_responses[_qname] = response - continue - all_nxdomain = False - break - if all_nxdomain: - raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses) - answer = Answer(_qname, rdtype, rdclass, response, - raise_on_no_answer, nameserver_answered, port_answered) - if self.cache: - self.cache.put((_qname, rdtype, rdclass), answer) - return answer + (answer, done) = resolution.query_result(response, None) + if answer: + return answer + except Exception as ex: + (_, done) = resolution.query_result(None, ex) def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, tcp=False, source=None, raise_on_no_answer=True, source_port=0, diff --git a/dns/trio/resolver.py b/dns/trio/resolver.py index 2f96a52..785fde6 100644 --- a/dns/trio/resolver.py +++ b/dns/trio/resolver.py @@ -17,19 +17,15 @@ """trio async I/O library DNS stub resolver.""" -import random -import socket import trio -from urllib.parse import urlparse import dns.exception import dns.query import dns.resolver import dns.trio.query -# import resolver symbols for compatibility and brevity -from dns.resolver import NXDOMAIN, YXDOMAIN, NoAnswer, NoNameservers, \ - NotAbsolute, NoRootSOA, NoMetaqueries, Answer +# import some resolver symbols for brevity +from dns.resolver import NXDOMAIN, NoAnswer, NotAbsolute, NoRootSOA # we do this for indentation reasons below _udp = dns.trio.query.udp @@ -87,62 +83,25 @@ class Resolver(dns.resolver.Resolver): """ - if isinstance(qname, str): - qname = dns.name.from_text(qname, None) - if isinstance(rdtype, str): - rdtype = dns.rdatatype.from_text(rdtype) - if dns.rdatatype.is_metatype(rdtype): - raise NoMetaqueries - if isinstance(rdclass, str): - rdclass = dns.rdataclass.from_text(rdclass) - if dns.rdataclass.is_metaclass(rdclass): - raise NoMetaqueries - qnames_to_try = self._get_qnames_to_try(qname, search) - all_nxdomain = True - nxdomain_responses = {} - _qname = None # make pylint happy - for _qname in qnames_to_try: - if self.cache: - answer = self.cache.get((_qname, rdtype, rdclass)) - if answer is not None: - if answer.rrset is None and raise_on_no_answer: - raise NoAnswer(response=answer.response) - else: - return answer - request = dns.message.make_query(_qname, rdtype, rdclass) - if self.keyname is not None: - request.use_tsig(self.keyring, self.keyname, - algorithm=self.keyalgorithm) - request.use_edns(self.edns, self.ednsflags, self.payload) - if self.flags is not None: - request.flags = self.flags - response = None - # - # make a copy of the servers list so we can alter it later. - # - nameservers = self.nameservers[:] - errors = [] - if self.rotate: - random.shuffle(nameservers) - backoff = 0.10 - # keep track of nameserver and port - # to include them in Answer - nameserver_answered = None - port_answered = None - loops = 0 - while response is None: - if len(nameservers) == 0: - raise NoNameservers(request=request, errors=errors) - for nameserver in nameservers[:]: - port = self.nameserver_ports.get(nameserver, self.port) - protocol = urlparse(nameserver).scheme - try: - with trio.fail_after(self.timeout): - if protocol == 'https': - raise NotImplementedError - elif protocol: - continue - tcp_attempt = tcp + resolution = dns.resolver._Resolution(self, qname, rdtype, rdclass, tcp, + raise_on_no_answer, search) + while True: + (request, answer) = resolution.next_request() + if answer: + # cache hit! + return answer + loops = 1 + done = False + while not done: + (nameserver, port, tcp, backoff) = resolution.next_nameserver() + if backoff: + loops += 1 + if loops >= 5: + raise TooManyAttempts + await trio.sleep(backoff) + try: + with trio.fail_after(self.timeout): + if dns.inet.is_address(nameserver): if tcp: response = await \ _stream(request, nameserver, @@ -150,113 +109,20 @@ class Resolver(dns.resolver.Resolver): source=source, source_port=source_port) else: - try: - response = await \ - _udp(request, - nameserver, - port=port, - source=source, - source_port=source_port) - except dns.message.Truncated: - # Response truncated; retry with TCP. - tcp_attempt = True - response = await \ - _stream(request, nameserver, - port=port, - source=source, - source_port=source_port) - except (socket.error, trio.TooSlowError) as ex: - # - # Communication failure or timeout. Go to the - # next server - # - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except dns.query.UnexpectedSource as ex: - # - # Who knows? Keep going. - # - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except dns.exception.FormError as ex: - # - # We don't understand what this server is - # saying. Take it out of the mix and - # continue. - # - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - except EOFError as ex: - # - # We're using TCP and they hung up on us. - # Probably they don't support TCP (though - # they're supposed to!). Take it out of the - # mix and continue. - # - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, ex, - response)) - response = None - continue - nameserver_answered = nameserver - port_answered = port - rcode = response.rcode() - if rcode == dns.rcode.YXDOMAIN: - yex = YXDOMAIN() - errors.append((nameserver, tcp_attempt, port, yex, - response)) - raise yex - if rcode == dns.rcode.NOERROR or \ - rcode == dns.rcode.NXDOMAIN: - break - # - # We got a response, but we're not happy with the - # rcode in it. Remove the server from the mix if - # the rcode isn't SERVFAIL. - # - if rcode != dns.rcode.SERVFAIL or not self.retry_servfail: - nameservers.remove(nameserver) - errors.append((nameserver, tcp_attempt, port, - dns.rcode.to_text(rcode), response)) - response = None - if response is not None: - break - # - # All nameservers failed! - # - # Do not loop forever if caller hasn't used a timeout - # scope. - loops += 1 - if loops >= 5: - raise TooManyAttempts - if len(nameservers) > 0: - # - # But we still have servers to try. Sleep a bit - # so we don't pound them! - # - await trio.sleep(backoff) - backoff *= 2 - if backoff > 2: - backoff = 2 - if response.rcode() == dns.rcode.NXDOMAIN: - nxdomain_responses[_qname] = response - continue - all_nxdomain = False - break - if all_nxdomain: - raise NXDOMAIN(qnames=qnames_to_try, responses=nxdomain_responses) - answer = Answer(_qname, rdtype, rdclass, response, raise_on_no_answer, - nameserver_answered, port_answered) - if self.cache: - self.cache.put((_qname, rdtype, rdclass), answer) - return answer + response = await \ + _udp(request, + nameserver, + port=port, + source=source, + source_port=source_port) + else: + # We don't do DoH yet. + raise NotImplementedError + (answer, done) = resolution.query_result(response, None) + if answer: + return answer + except Exception as ex: + (_, done) = resolution.query_result(None, ex) async def query(self, *args, **kwargs): # We have to define something here as we don't want to inherit the -- cgit v1.2.1