diff options
-rw-r--r-- | dns/resolver.py | 103 | ||||
-rw-r--r-- | tests/test_resolution.py | 46 |
2 files changed, 104 insertions, 45 deletions
diff --git a/dns/resolver.py b/dns/resolver.py index e50eab8..cc1f78b 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -536,51 +536,61 @@ class _Resolution(object): """ # We return a tuple instead of Union[Message,Answer] as it lets - # the caller avoid isinstance. + # the caller avoid isinstance(). - if len(self.qnames) == 0: - # - # We've tried everything and only gotten NXDOMAINs. (We know - # it's only NXDOMAINs as anything else would have returned - # before now.) - # - raise NXDOMAIN(qnames=self.qnames_to_try, - responses=self.nxdomain_responses) - - self.qname = self.qnames.pop(0) - - # Do we know the answer? - if self.resolver.cache: - answer = self.resolver.cache.get((self.qname, self.rdtype, - self.rdclass)) - if answer is not None: - if answer.rrset is None and self.raise_on_no_answer: - raise NoAnswer(response=answer.response) - else: - return (None, answer) - - # Build the request - request = dns.message.make_query(self.qname, self.rdtype, self.rdclass) - if self.resolver.keyname is not None: - request.use_tsig(self.resolver.keyring, self.resolver.keyname, - algorithm=self.resolver.keyalgorithm) - request.use_edns(self.resolver.edns, self.resolver.ednsflags, - self.resolver.payload) - if self.resolver.flags is not None: - request.flags = self.resolver.flags - - self.nameservers = self.resolver.nameservers[:] - if self.resolver.rotate: - random.shuffle(self.nameservers) - self.current_nameservers = self.nameservers[:] - self.errors = [] - self.nameserver = None - self.tcp_attempt = False - self.retry_with_tcp = False - self.request = request - self.backoff = 0.10 + while len(self.qnames) > 0: + self.qname = self.qnames.pop(0) + + # Do we know the answer? + if self.resolver.cache: + answer = self.resolver.cache.get((self.qname, self.rdtype, + self.rdclass)) + if answer is not None: + if answer.rrset is None and self.raise_on_no_answer: + raise NoAnswer(response=answer.response) + else: + return (None, answer) + answer = self.resolver.cache.get((self.qname, + dns.rdatatype.ANY, + self.rdclass)) + if answer is not None and \ + answer.response.rcode() == dns.rcode.NXDOMAIN: + # cached NXDOMAIN; record it and continue to next + # name. + self.nxdomain_responses[self.qname] = answer.response + continue + + # Build the request + request = dns.message.make_query(self.qname, self.rdtype, + self.rdclass) + if self.resolver.keyname is not None: + request.use_tsig(self.resolver.keyring, self.resolver.keyname, + algorithm=self.resolver.keyalgorithm) + request.use_edns(self.resolver.edns, self.resolver.ednsflags, + self.resolver.payload) + if self.resolver.flags is not None: + request.flags = self.resolver.flags + + self.nameservers = self.resolver.nameservers[:] + if self.resolver.rotate: + random.shuffle(self.nameservers) + self.current_nameservers = self.nameservers[:] + self.errors = [] + self.nameserver = None + self.tcp_attempt = False + self.retry_with_tcp = False + self.request = request + self.backoff = 0.10 + + return (request, None) - return (request, None) + # + # We've tried everything and only gotten NXDOMAINs. (We know + # it's only NXDOMAINs as anything else would have returned + # before now.) + # + raise NXDOMAIN(qnames=self.qnames_to_try, + responses=self.nxdomain_responses) def next_nameserver(self): if self.retry_with_tcp: @@ -641,6 +651,13 @@ class _Resolution(object): self.nxdomain_responses[self.qname] = response # Make next_nameserver() return None, so caller breaks its # inner loop and calls next_request(). + if self.resolver.cache: + answer = Answer(self.qname, dns.rdatatype.ANY, + dns.rdataclass.IN, response) + self.resolver.cache.put((self.qname, + dns.rdatatype.ANY, + self.rdclass), answer) + return (None, True) elif rcode == dns.rcode.YXDOMAIN: yex = YXDOMAIN() diff --git a/tests/test_resolution.py b/tests/test_resolution.py index 95dd9ae..bb1c4b1 100644 --- a/tests/test_resolution.py +++ b/tests/test_resolution.py @@ -56,7 +56,7 @@ class ResolutionTestCase(unittest.TestCase): def make_negative_response(self, q, nxdomain=False): r = dns.message.make_response(q) - rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN, + rrs = r.get_rrset(r.authority, q.question[0].name, dns.rdataclass.IN, dns.rdatatype.SOA, create=True) rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA, '. . 1 2 3 4 300'), 300) @@ -76,7 +76,7 @@ class ResolutionTestCase(unittest.TestCase): self.assertTrue(request is None) self.assertTrue(answer is cache_answer) - def test_next_request_no_answer(self): + def test_next_request_cached_no_answer(self): # In default mode, we should raise on a no-answer hit self.resolver.cache = dns.resolver.Cache() q = dns.message.make_query(self.qname, dns.rdatatype.A) @@ -98,6 +98,35 @@ class ResolutionTestCase(unittest.TestCase): self.assertTrue(request is None) self.assertTrue(answer is cache_answer) + def test_next_request_cached_nxdomain(self): + # use a relative qname so we have two qnames to try + qname = dns.name.from_text('www.dnspython.org', None) + self.resn = dns.resolver._Resolution(self.resolver, qname, + 'A', 'IN', + False, True, False) + qname1 = dns.name.from_text('www.dnspython.org.example.') + qname2 = dns.name.from_text('www.dnspython.org.') + # Arrange to get NXDOMAIN hits on both of those qnames. + self.resolver.cache = dns.resolver.Cache() + q1 = dns.message.make_query(qname1, dns.rdatatype.A) + r1 = self.make_negative_response(q1, True) + cache_answer = dns.resolver.Answer(qname1, dns.rdatatype.ANY, + dns.rdataclass.IN, r1) + self.resolver.cache.put((qname1, dns.rdatatype.ANY, + dns.rdataclass.IN), cache_answer) + q2 = dns.message.make_query(qname2, dns.rdatatype.A) + r2 = self.make_negative_response(q2, True) + cache_answer = dns.resolver.Answer(qname2, dns.rdatatype.ANY, + dns.rdataclass.IN, r2) + self.resolver.cache.put((qname2, dns.rdatatype.ANY, + dns.rdataclass.IN), cache_answer) + try: + (request, answer) = self.resn.next_request() + self.assertTrue(False) # should not happen! + except dns.resolver.NXDOMAIN as nx: + self.assertTrue(nx.response(qname1) is r1) + self.assertTrue(nx.response(qname2) is r2) + def test_next_nameserver_udp(self): (request, answer) = self.resn.next_request() (nameserver1, port, tcp, backoff) = self.resn.next_nameserver() @@ -241,6 +270,19 @@ class ResolutionTestCase(unittest.TestCase): self.assertTrue(answer is None) self.assertTrue(done) + def test_query_result_nxdomain_cached(self): + self.resolver.cache = dns.resolver.Cache() + q = dns.message.make_query(self.qname, dns.rdatatype.A) + r = self.make_negative_response(q, True) + (_, _) = self.resn.next_request() + (_, _, _, _) = self.resn.next_nameserver() + (answer, done) = self.resn.query_result(r, None) + self.assertTrue(answer is None) + self.assertTrue(done) + cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.ANY, + dns.rdataclass.IN)) + self.assertTrue(cache_answer.response is r) + def test_query_result_yxdomain(self): q = dns.message.make_query(self.qname, dns.rdatatype.A) r = self.make_address_response(q) |