summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--dns/resolver.py103
-rw-r--r--tests/test_resolution.py46
2 files changed, 104 insertions, 45 deletions
diff --git a/dns/resolver.py b/dns/resolver.py
index e50eab8..cc1f78b 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -536,51 +536,61 @@ class _Resolution(object):
"""
# We return a tuple instead of Union[Message,Answer] as it lets
- # the caller avoid isinstance.
+ # the caller avoid isinstance().
- if len(self.qnames) == 0:
- #
- # We've tried everything and only gotten NXDOMAINs. (We know
- # it's only NXDOMAINs as anything else would have returned
- # before now.)
- #
- raise NXDOMAIN(qnames=self.qnames_to_try,
- responses=self.nxdomain_responses)
-
- self.qname = self.qnames.pop(0)
-
- # Do we know the answer?
- if self.resolver.cache:
- answer = self.resolver.cache.get((self.qname, self.rdtype,
- self.rdclass))
- if answer is not None:
- if answer.rrset is None and self.raise_on_no_answer:
- raise NoAnswer(response=answer.response)
- else:
- return (None, answer)
-
- # Build the request
- request = dns.message.make_query(self.qname, self.rdtype, self.rdclass)
- if self.resolver.keyname is not None:
- request.use_tsig(self.resolver.keyring, self.resolver.keyname,
- algorithm=self.resolver.keyalgorithm)
- request.use_edns(self.resolver.edns, self.resolver.ednsflags,
- self.resolver.payload)
- if self.resolver.flags is not None:
- request.flags = self.resolver.flags
-
- self.nameservers = self.resolver.nameservers[:]
- if self.resolver.rotate:
- random.shuffle(self.nameservers)
- self.current_nameservers = self.nameservers[:]
- self.errors = []
- self.nameserver = None
- self.tcp_attempt = False
- self.retry_with_tcp = False
- self.request = request
- self.backoff = 0.10
+ while len(self.qnames) > 0:
+ self.qname = self.qnames.pop(0)
+
+ # Do we know the answer?
+ if self.resolver.cache:
+ answer = self.resolver.cache.get((self.qname, self.rdtype,
+ self.rdclass))
+ if answer is not None:
+ if answer.rrset is None and self.raise_on_no_answer:
+ raise NoAnswer(response=answer.response)
+ else:
+ return (None, answer)
+ answer = self.resolver.cache.get((self.qname,
+ dns.rdatatype.ANY,
+ self.rdclass))
+ if answer is not None and \
+ answer.response.rcode() == dns.rcode.NXDOMAIN:
+ # cached NXDOMAIN; record it and continue to next
+ # name.
+ self.nxdomain_responses[self.qname] = answer.response
+ continue
+
+ # Build the request
+ request = dns.message.make_query(self.qname, self.rdtype,
+ self.rdclass)
+ if self.resolver.keyname is not None:
+ request.use_tsig(self.resolver.keyring, self.resolver.keyname,
+ algorithm=self.resolver.keyalgorithm)
+ request.use_edns(self.resolver.edns, self.resolver.ednsflags,
+ self.resolver.payload)
+ if self.resolver.flags is not None:
+ request.flags = self.resolver.flags
+
+ self.nameservers = self.resolver.nameservers[:]
+ if self.resolver.rotate:
+ random.shuffle(self.nameservers)
+ self.current_nameservers = self.nameservers[:]
+ self.errors = []
+ self.nameserver = None
+ self.tcp_attempt = False
+ self.retry_with_tcp = False
+ self.request = request
+ self.backoff = 0.10
+
+ return (request, None)
- return (request, None)
+ #
+ # We've tried everything and only gotten NXDOMAINs. (We know
+ # it's only NXDOMAINs as anything else would have returned
+ # before now.)
+ #
+ raise NXDOMAIN(qnames=self.qnames_to_try,
+ responses=self.nxdomain_responses)
def next_nameserver(self):
if self.retry_with_tcp:
@@ -641,6 +651,13 @@ class _Resolution(object):
self.nxdomain_responses[self.qname] = response
# Make next_nameserver() return None, so caller breaks its
# inner loop and calls next_request().
+ if self.resolver.cache:
+ answer = Answer(self.qname, dns.rdatatype.ANY,
+ dns.rdataclass.IN, response)
+ self.resolver.cache.put((self.qname,
+ dns.rdatatype.ANY,
+ self.rdclass), answer)
+
return (None, True)
elif rcode == dns.rcode.YXDOMAIN:
yex = YXDOMAIN()
diff --git a/tests/test_resolution.py b/tests/test_resolution.py
index 95dd9ae..bb1c4b1 100644
--- a/tests/test_resolution.py
+++ b/tests/test_resolution.py
@@ -56,7 +56,7 @@ class ResolutionTestCase(unittest.TestCase):
def make_negative_response(self, q, nxdomain=False):
r = dns.message.make_response(q)
- rrs = r.get_rrset(r.authority, self.qname, dns.rdataclass.IN,
+ rrs = r.get_rrset(r.authority, q.question[0].name, dns.rdataclass.IN,
dns.rdatatype.SOA, create=True)
rrs.add(dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.SOA,
'. . 1 2 3 4 300'), 300)
@@ -76,7 +76,7 @@ class ResolutionTestCase(unittest.TestCase):
self.assertTrue(request is None)
self.assertTrue(answer is cache_answer)
- def test_next_request_no_answer(self):
+ def test_next_request_cached_no_answer(self):
# In default mode, we should raise on a no-answer hit
self.resolver.cache = dns.resolver.Cache()
q = dns.message.make_query(self.qname, dns.rdatatype.A)
@@ -98,6 +98,35 @@ class ResolutionTestCase(unittest.TestCase):
self.assertTrue(request is None)
self.assertTrue(answer is cache_answer)
+ def test_next_request_cached_nxdomain(self):
+ # use a relative qname so we have two qnames to try
+ qname = dns.name.from_text('www.dnspython.org', None)
+ self.resn = dns.resolver._Resolution(self.resolver, qname,
+ 'A', 'IN',
+ False, True, False)
+ qname1 = dns.name.from_text('www.dnspython.org.example.')
+ qname2 = dns.name.from_text('www.dnspython.org.')
+ # Arrange to get NXDOMAIN hits on both of those qnames.
+ self.resolver.cache = dns.resolver.Cache()
+ q1 = dns.message.make_query(qname1, dns.rdatatype.A)
+ r1 = self.make_negative_response(q1, True)
+ cache_answer = dns.resolver.Answer(qname1, dns.rdatatype.ANY,
+ dns.rdataclass.IN, r1)
+ self.resolver.cache.put((qname1, dns.rdatatype.ANY,
+ dns.rdataclass.IN), cache_answer)
+ q2 = dns.message.make_query(qname2, dns.rdatatype.A)
+ r2 = self.make_negative_response(q2, True)
+ cache_answer = dns.resolver.Answer(qname2, dns.rdatatype.ANY,
+ dns.rdataclass.IN, r2)
+ self.resolver.cache.put((qname2, dns.rdatatype.ANY,
+ dns.rdataclass.IN), cache_answer)
+ try:
+ (request, answer) = self.resn.next_request()
+ self.assertTrue(False) # should not happen!
+ except dns.resolver.NXDOMAIN as nx:
+ self.assertTrue(nx.response(qname1) is r1)
+ self.assertTrue(nx.response(qname2) is r2)
+
def test_next_nameserver_udp(self):
(request, answer) = self.resn.next_request()
(nameserver1, port, tcp, backoff) = self.resn.next_nameserver()
@@ -241,6 +270,19 @@ class ResolutionTestCase(unittest.TestCase):
self.assertTrue(answer is None)
self.assertTrue(done)
+ def test_query_result_nxdomain_cached(self):
+ self.resolver.cache = dns.resolver.Cache()
+ q = dns.message.make_query(self.qname, dns.rdatatype.A)
+ r = self.make_negative_response(q, True)
+ (_, _) = self.resn.next_request()
+ (_, _, _, _) = self.resn.next_nameserver()
+ (answer, done) = self.resn.query_result(r, None)
+ self.assertTrue(answer is None)
+ self.assertTrue(done)
+ cache_answer = self.resolver.cache.get((self.qname, dns.rdatatype.ANY,
+ dns.rdataclass.IN))
+ self.assertTrue(cache_answer.response is r)
+
def test_query_result_yxdomain(self):
q = dns.message.make_query(self.qname, dns.rdatatype.A)
r = self.make_address_response(q)