summaryrefslogtreecommitdiff
path: root/tests/test_query.py
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-06-01 10:09:37 -0700
committerBrian Wellington <bwelling@xbill.org>2020-06-01 10:09:37 -0700
commit52f46fc0d23d760f6b4b1e5d8d46cda031ed86e9 (patch)
tree0e81246dcc53c2c23ee33fc0f00e369bf837d0e5 /tests/test_query.py
parentc19e6716a86593528ae3d4904bf1cefcfcd477b9 (diff)
downloaddnspython-52f46fc0d23d760f6b4b1e5d8d46cda031ed86e9.tar.gz
Adds sock parameters to query methods.
Allow passing a socket into dns.query.{udp,tcp,tls,udp_with_fallback}, and add tests for this.
Diffstat (limited to 'tests/test_query.py')
-rw-r--r--tests/test_query.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/tests/test_query.py b/tests/test_query.py
index 9c63217..e031cfd 100644
--- a/tests/test_query.py
+++ b/tests/test_query.py
@@ -18,6 +18,12 @@
import socket
import unittest
+try:
+ import ssl
+ have_ssl = True
+except Exception:
+ have_ssl = False
+
import dns.message
import dns.name
import dns.rdataclass
@@ -46,6 +52,19 @@ class QueryTests(unittest.TestCase):
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryUDPWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
+ s.setblocking(0)
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.udp(q, '8.8.8.8', sock=s)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTCP(self):
qname = dns.name.from_text('dns.google.')
q = dns.message.make_query(qname, dns.rdatatype.A)
@@ -57,6 +76,20 @@ class QueryTests(unittest.TestCase):
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ def testQueryTCPWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.connect(('8.8.8.8', 53))
+ s.setblocking(0)
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tcp(q, None, sock=s)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryTLS(self):
qname = dns.name.from_text('dns.google.')
q = dns.message.make_query(qname, dns.rdatatype.A)
@@ -68,12 +101,42 @@ class QueryTests(unittest.TestCase):
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
+ @unittest.skipUnless(have_ssl, "No SSL support")
+ def testQueryTLSWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+ s.connect(('8.8.8.8', 853))
+ ctx = ssl.create_default_context()
+ s = ctx.wrap_socket(s, server_hostname='dns.google')
+ s.setblocking(0)
+ qname = dns.name.from_text('dns.google.')
+ q = dns.message.make_query(qname, dns.rdatatype.A)
+ response = dns.query.tls(q, None, sock=s)
+ rrs = response.get_rrset(response.answer, qname,
+ dns.rdataclass.IN, dns.rdatatype.A)
+ self.assertTrue(rrs is not None)
+ seen = set([rdata.address for rdata in rrs])
+ self.assertTrue('8.8.8.8' in seen)
+ self.assertTrue('8.8.4.4' in seen)
+
def testQueryUDPFallback(self):
qname = dns.name.from_text('.')
q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
(_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8')
self.assertTrue(tcp)
+ def testQueryUDPFallbackWithSocket(self):
+ with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp_s:
+ udp_s.setblocking(0)
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_s:
+ tcp_s.connect(('8.8.8.8', 53))
+ tcp_s.setblocking(0)
+ qname = dns.name.from_text('.')
+ q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
+ (_, tcp) = dns.query.udp_with_fallback(q, '8.8.8.8',
+ udp_sock=udp_s,
+ tcp_sock=tcp_s)
+ self.assertTrue(tcp)
+
def testQueryUDPFallbackNoFallback(self):
qname = dns.name.from_text('dns.google.')
q = dns.message.make_query(qname, dns.rdatatype.A)