summaryrefslogtreecommitdiff
path: root/tests/test_async.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_async.py')
-rw-r--r--tests/test_async.py331
1 files changed, 203 insertions, 128 deletions
diff --git a/tests/test_async.py b/tests/test_async.py
index ce0caa1..3c9a7e6 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -46,7 +46,7 @@ except Exception:
# skip those if it's not there.
_network_available = True
try:
- socket.gethostbyname('dnspython.org')
+ socket.gethostbyname("dnspython.org")
except socket.gaierror:
_network_available = False
@@ -57,15 +57,17 @@ except socket.gaierror:
_systemd_resolved_present = False
try:
_resolver = dns.resolver.Resolver()
- if _resolver.nameservers == ['127.0.0.53']:
+ if _resolver.nameservers == ["127.0.0.53"]:
_systemd_resolved_present = True
except Exception:
pass
# Probe for IPv4 and IPv6
query_addresses = []
-for (af, address) in ((socket.AF_INET, '8.8.8.8'),
- (socket.AF_INET6, '2001:4860:4860::8888')):
+for (af, address) in (
+ (socket.AF_INET, "8.8.8.8"),
+ (socket.AF_INET6, "2001:4860:4860::8888"),
+):
try:
with socket.socket(af, socket.SOCK_DGRAM) as s:
# Connecting a UDP socket is supposed to return ENETUNREACH if
@@ -75,31 +77,37 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'),
except Exception:
pass
-KNOWN_ANYCAST_DOH_RESOLVER_URLS = ['https://cloudflare-dns.com/dns-query',
- 'https://dns.google/dns-query',
- # 'https://dns11.quad9.net/dns-query',
- ]
+KNOWN_ANYCAST_DOH_RESOLVER_URLS = [
+ "https://cloudflare-dns.com/dns-query",
+ "https://dns.google/dns-query",
+ # 'https://dns11.quad9.net/dns-query',
+]
class AsyncDetectionTests(unittest.TestCase):
- sniff_result = 'asyncio'
+ sniff_result = "asyncio"
def async_run(self, afunc):
return asyncio.run(afunc())
def test_sniff(self):
dns.asyncbackend._default_backend = None
+
async def run():
self.assertEqual(dns.asyncbackend.sniff(), self.sniff_result)
+
self.async_run(run)
def test_get_default_backend(self):
dns.asyncbackend._default_backend = None
+
async def run():
backend = dns.asyncbackend.get_default_backend()
self.assertEqual(backend.name(), self.sniff_result)
+
self.async_run(run)
+
class NoSniffioAsyncDetectionTests(AsyncDetectionTests):
expect_raise = False
@@ -112,10 +120,13 @@ class NoSniffioAsyncDetectionTests(AsyncDetectionTests):
def test_sniff(self):
dns.asyncbackend._default_backend = None
if self.expect_raise:
+
async def abad():
dns.asyncbackend.sniff()
+
def bad():
self.async_run(abad)
+
self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad)
else:
super().test_sniff()
@@ -123,10 +134,13 @@ class NoSniffioAsyncDetectionTests(AsyncDetectionTests):
def test_get_default_backend(self):
dns.asyncbackend._default_backend = None
if self.expect_raise:
+
async def abad():
dns.asyncbackend.get_default_backend()
+
def bad():
self.async_run(abad)
+
self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad)
else:
super().test_get_default_backend()
@@ -135,13 +149,16 @@ class NoSniffioAsyncDetectionTests(AsyncDetectionTests):
class MiscBackend(unittest.TestCase):
def test_sniff_without_run_loop(self):
dns.asyncbackend._default_backend = None
+
def bad():
dns.asyncbackend.sniff()
+
self.assertRaises(dns.asyncbackend.AsyncLibraryNotFoundError, bad)
def test_bogus_backend(self):
def bad():
- dns.asyncbackend.get_backend('bogus')
+ dns.asyncbackend.get_backend("bogus")
+
self.assertRaises(NotImplementedError, bad)
@@ -151,256 +168,297 @@ class MiscQuery(unittest.TestCase):
self.assertEqual(t, None)
t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 0)
self.assertEqual(t, None)
- t = dns.asyncquery._source_tuple(socket.AF_INET, '1.2.3.4', 53)
- self.assertEqual(t, ('1.2.3.4', 53))
- t = dns.asyncquery._source_tuple(socket.AF_INET6, '1::2', 53)
- self.assertEqual(t, ('1::2', 53))
+ t = dns.asyncquery._source_tuple(socket.AF_INET, "1.2.3.4", 53)
+ self.assertEqual(t, ("1.2.3.4", 53))
+ t = dns.asyncquery._source_tuple(socket.AF_INET6, "1::2", 53)
+ self.assertEqual(t, ("1::2", 53))
t = dns.asyncquery._source_tuple(socket.AF_INET, None, 53)
- self.assertEqual(t, ('0.0.0.0', 53))
+ self.assertEqual(t, ("0.0.0.0", 53))
t = dns.asyncquery._source_tuple(socket.AF_INET6, None, 53)
- self.assertEqual(t, ('::', 53))
+ self.assertEqual(t, ("::", 53))
@unittest.skipIf(not _network_available, "Internet not reachable")
class AsyncTests(unittest.TestCase):
- connect_udp = sys.platform == 'win32'
+ connect_udp = sys.platform == "win32"
def setUp(self):
- self.backend = dns.asyncbackend.set_default_backend('asyncio')
+ self.backend = dns.asyncbackend.set_default_backend("asyncio")
def async_run(self, afunc):
return asyncio.run(afunc())
def testResolve(self):
async def run():
- answer = await dns.asyncresolver.resolve('dns.google.', 'A')
+ answer = await dns.asyncresolver.resolve("dns.google.", "A")
return set([rdata.address for rdata in answer])
+
seen = self.async_run(run)
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
def testResolveAddress(self):
async def run():
- return await dns.asyncresolver.resolve_address('8.8.8.8')
+ return await dns.asyncresolver.resolve_address("8.8.8.8")
+
answer = self.async_run(run)
- dnsgoogle = dns.name.from_text('dns.google.')
+ dnsgoogle = dns.name.from_text("dns.google.")
self.assertEqual(answer[0].target, dnsgoogle)
def testCanonicalNameNoCNAME(self):
- cname = dns.name.from_text('www.google.com')
+ cname = dns.name.from_text("www.google.com")
+
async def run():
- return await dns.asyncresolver.canonical_name('www.google.com')
+ return await dns.asyncresolver.canonical_name("www.google.com")
+
self.assertEqual(self.async_run(run), cname)
def testCanonicalNameCNAME(self):
- name = dns.name.from_text('www.dnspython.org')
- cname = dns.name.from_text('dmfrjf4ips8xa.cloudfront.net')
+ name = dns.name.from_text("www.dnspython.org")
+ cname = dns.name.from_text("dmfrjf4ips8xa.cloudfront.net")
+
async def run():
return await dns.asyncresolver.canonical_name(name)
+
self.assertEqual(self.async_run(run), cname)
@unittest.skipIf(_systemd_resolved_present, "systemd-resolved in use")
def testCanonicalNameDangling(self):
- name = dns.name.from_text('dangling-cname.dnspython.org')
- cname = dns.name.from_text('dangling-target.dnspython.org')
+ name = dns.name.from_text("dangling-cname.dnspython.org")
+ cname = dns.name.from_text("dangling-target.dnspython.org")
+
async def run():
return await dns.asyncresolver.canonical_name(name)
- self.assertEqual(self.async_run(run), cname)
+ self.assertEqual(self.async_run(run), cname)
def testZoneForName1(self):
async def run():
- name = dns.name.from_text('www.dnspython.org.')
+ name = dns.name.from_text("www.dnspython.org.")
return await dns.asyncresolver.zone_for_name(name)
- ezname = dns.name.from_text('dnspython.org.')
+
+ ezname = dns.name.from_text("dnspython.org.")
zname = self.async_run(run)
self.assertEqual(zname, ezname)
def testZoneForName2(self):
async def run():
- name = dns.name.from_text('a.b.www.dnspython.org.')
+ name = dns.name.from_text("a.b.www.dnspython.org.")
return await dns.asyncresolver.zone_for_name(name)
- ezname = dns.name.from_text('dnspython.org.')
+
+ ezname = dns.name.from_text("dnspython.org.")
zname = self.async_run(run)
self.assertEqual(zname, ezname)
def testZoneForName3(self):
async def run():
- name = dns.name.from_text('dnspython.org.')
+ name = dns.name.from_text("dnspython.org.")
return await dns.asyncresolver.zone_for_name(name)
- ezname = dns.name.from_text('dnspython.org.')
+
+ ezname = dns.name.from_text("dnspython.org.")
zname = self.async_run(run)
self.assertEqual(zname, ezname)
def testZoneForName4(self):
def bad():
- name = dns.name.from_text('dnspython.org', None)
+ name = dns.name.from_text("dnspython.org", None)
+
async def run():
return await dns.asyncresolver.zone_for_name(name)
+
self.async_run(run)
+
self.assertRaises(dns.resolver.NotAbsolute, bad)
def testQueryUDP(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.udp(q, address, timeout=2)
+
response = self.async_run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
self.assertTrue(rrs is not None)
seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
def testQueryUDPWithSocket(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
if self.connect_udp:
- dtuple=(address, 53)
+ dtuple = (address, 53)
else:
- dtuple=None
+ dtuple = None
async with await self.backend.make_socket(
- dns.inet.af_for_address(address),
- socket.SOCK_DGRAM, 0, None, dtuple) as s:
+ dns.inet.af_for_address(address), socket.SOCK_DGRAM, 0, None, dtuple
+ ) as s:
q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.asyncquery.udp(q, address, sock=s,
- timeout=2)
+ return await dns.asyncquery.udp(q, address, sock=s, timeout=2)
+
response = self.async_run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
self.assertTrue(rrs is not None)
seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
def testQueryTCP(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.tcp(q, address, timeout=2)
+
response = self.async_run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
self.assertTrue(rrs is not None)
seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
def testQueryTCPWithSocket(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
async with await self.backend.make_socket(
- dns.inet.af_for_address(address),
- socket.SOCK_STREAM, 0,
- None,
- (address, 53), 2) as s:
+ dns.inet.af_for_address(address),
+ socket.SOCK_STREAM,
+ 0,
+ None,
+ (address, 53),
+ 2,
+ ) as s:
# for basic coverage
await s.getsockname()
q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.asyncquery.tcp(q, address, sock=s,
- timeout=2)
+ return await dns.asyncquery.tcp(q, address, sock=s, timeout=2)
+
response = self.async_run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
self.assertTrue(rrs is not None)
seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
@unittest.skipIf(not _ssl_available, "SSL not available")
def testQueryTLS(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.tls(q, address, timeout=2)
+
response = self.async_run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
self.assertTrue(rrs is not None)
seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
@unittest.skipIf(not _ssl_available, "SSL not available")
def testQueryTLSWithSocket(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
async with await self.backend.make_socket(
- dns.inet.af_for_address(address),
- socket.SOCK_STREAM, 0,
- None,
- (address, 853), 2,
- ssl_context, None) as s:
+ dns.inet.af_for_address(address),
+ socket.SOCK_STREAM,
+ 0,
+ None,
+ (address, 853),
+ 2,
+ ssl_context,
+ None,
+ ) as s:
# for basic coverage
await s.getsockname()
q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.asyncquery.tls(q, '8.8.8.8', sock=s,
- timeout=2)
+ return await dns.asyncquery.tls(q, "8.8.8.8", sock=s, timeout=2)
+
response = self.async_run(run)
- rrs = response.get_rrset(response.answer, qname,
- dns.rdataclass.IN, dns.rdatatype.A)
+ rrs = response.get_rrset(
+ response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
+ )
self.assertTrue(rrs is not None)
seen = set([rdata.address for rdata in rrs])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
def testQueryUDPFallback(self):
for address in query_addresses:
- qname = dns.name.from_text('.')
+ qname = dns.name.from_text(".")
+
async def run():
q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
- return await dns.asyncquery.udp_with_fallback(q, address,
- timeout=2)
+ return await dns.asyncquery.udp_with_fallback(q, address, timeout=2)
+
(_, tcp) = self.async_run(run)
self.assertTrue(tcp)
def testQueryUDPFallbackNoFallback(self):
for address in query_addresses:
- qname = dns.name.from_text('dns.google.')
+ qname = dns.name.from_text("dns.google.")
+
async def run():
q = dns.message.make_query(qname, dns.rdatatype.A)
- return await dns.asyncquery.udp_with_fallback(q, address,
- timeout=2)
+ return await dns.asyncquery.udp_with_fallback(q, address, timeout=2)
+
(_, tcp) = self.async_run(run)
self.assertFalse(tcp)
def testUDPReceiveQuery(self):
if self.connect_udp:
- self.skipTest('test needs connectionless sockets')
+ self.skipTest("test needs connectionless sockets")
+
async def run():
async with await self.backend.make_socket(
- socket.AF_INET, socket.SOCK_DGRAM,
- source=('127.0.0.1', 0)) as listener:
+ socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0)
+ ) as listener:
listener_address = await listener.getsockname()
async with await self.backend.make_socket(
- socket.AF_INET, socket.SOCK_DGRAM,
- source=('127.0.0.1', 0)) as sender:
+ socket.AF_INET, socket.SOCK_DGRAM, source=("127.0.0.1", 0)
+ ) as sender:
sender_address = await sender.getsockname()
- q = dns.message.make_query('dns.google', dns.rdatatype.A)
+ q = dns.message.make_query("dns.google", dns.rdatatype.A)
await dns.asyncquery.send_udp(sender, q, listener_address)
expiration = time.time() + 2
(_, _, recv_address) = await dns.asyncquery.receive_udp(
- listener, expiration=expiration)
+ listener, expiration=expiration
+ )
return (sender_address, recv_address)
+
(sender_address, recv_address) = self.async_run(run)
self.assertEqual(sender_address, recv_address)
def testUDPReceiveTimeout(self):
if self.connect_udp:
- self.skipTest('test needs connectionless sockets')
+ self.skipTest("test needs connectionless sockets")
+
async def arun():
- async with await self.backend.make_socket(socket.AF_INET,
- socket.SOCK_DGRAM, 0,
- ('127.0.0.1', 0)) as s:
+ async with await self.backend.make_socket(
+ socket.AF_INET, socket.SOCK_DGRAM, 0, ("127.0.0.1", 0)
+ ) as s:
try:
# for basic coverage
await s.getpeername()
@@ -408,62 +466,69 @@ class AsyncTests(unittest.TestCase):
# we expect failure as we haven't connected the socket
pass
await s.recvfrom(1000, 0.05)
+
def run():
self.async_run(arun)
+
self.assertRaises(dns.exception.Timeout, run)
@unittest.skipIf(not dns.query._have_httpx, "httpx not available")
def testDOHGetRequest(self):
- if self.backend.name() == 'curio':
- self.skipTest('anyio dropped curio support')
+ if self.backend.name() == "curio":
+ self.skipTest("anyio dropped curio support")
+
async def run():
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
- q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = await dns.asyncquery.https(q, nameserver_url, post=False,
- timeout=4)
+ q = dns.message.make_query("example.com.", dns.rdatatype.A)
+ r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4)
self.assertTrue(q.is_response(r))
+
self.async_run(run)
@unittest.skipIf(not dns.query._have_httpx, "httpx not available")
def testDOHGetRequestHttp1(self):
- if self.backend.name() == 'curio':
- self.skipTest('anyio dropped curio support')
+ if self.backend.name() == "curio":
+ self.skipTest("anyio dropped curio support")
+
async def run():
saved_have_http2 = dns.query._have_http2
try:
dns.query._have_http2 = False
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
- q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = await dns.asyncquery.https(q, nameserver_url, post=False,
- timeout=4)
+ q = dns.message.make_query("example.com.", dns.rdatatype.A)
+ r = await dns.asyncquery.https(q, nameserver_url, post=False, timeout=4)
self.assertTrue(q.is_response(r))
finally:
dns.query._have_http2 = saved_have_http2
+
self.async_run(run)
@unittest.skipIf(not dns.query._have_httpx, "httpx not available")
def testDOHPostRequest(self):
- if self.backend.name() == 'curio':
- self.skipTest('anyio dropped curio support')
+ if self.backend.name() == "curio":
+ self.skipTest("anyio dropped curio support")
+
async def run():
nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS)
- q = dns.message.make_query('example.com.', dns.rdatatype.A)
- r = await dns.asyncquery.https(q, nameserver_url, post=True,
- timeout=4)
+ q = dns.message.make_query("example.com.", dns.rdatatype.A)
+ r = await dns.asyncquery.https(q, nameserver_url, post=True, timeout=4)
self.assertTrue(q.is_response(r))
+
self.async_run(run)
@unittest.skipIf(not dns.query._have_httpx, "httpx not available")
def testResolverDOH(self):
- if self.backend.name() == 'curio':
- self.skipTest('anyio dropped curio support')
+ if self.backend.name() == "curio":
+ self.skipTest("anyio dropped curio support")
+
async def run():
res = dns.asyncresolver.Resolver(configure=False)
- res.nameservers = ['https://dns.google/dns-query']
- answer = await res.resolve('dns.google', 'A', backend=self.backend)
+ res.nameservers = ["https://dns.google/dns-query"]
+ answer = await res.resolve("dns.google", "A", backend=self.backend)
seen = set([rdata.address for rdata in answer])
- self.assertTrue('8.8.8.8' in seen)
- self.assertTrue('8.8.4.4' in seen)
+ self.assertTrue("8.8.8.8" in seen)
+ self.assertTrue("8.8.4.4" in seen)
+
self.async_run(run)
def testSleep(self):
@@ -472,29 +537,35 @@ class AsyncTests(unittest.TestCase):
await self.backend.sleep(0.1)
after = time.time()
self.assertTrue(after - before >= 0.1)
+
self.async_run(run)
+
try:
import trio
import sniffio
class TrioAsyncDetectionTests(AsyncDetectionTests):
- sniff_result = 'trio'
+ sniff_result = "trio"
+
def async_run(self, afunc):
return trio.run(afunc)
class TrioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests):
expect_raise = True
+
def async_run(self, afunc):
return trio.run(afunc)
class TrioAsyncTests(AsyncTests):
connect_udp = False
+
def setUp(self):
- self.backend = dns.asyncbackend.set_default_backend('trio')
+ self.backend = dns.asyncbackend.set_default_backend("trio")
def async_run(self, afunc):
return trio.run(afunc)
+
except ImportError:
pass
@@ -503,21 +574,25 @@ try:
import sniffio
class CurioAsyncDetectionTests(AsyncDetectionTests):
- sniff_result = 'curio'
+ sniff_result = "curio"
+
def async_run(self, afunc):
return curio.run(afunc)
class CurioNoSniffioAsyncDetectionTests(NoSniffioAsyncDetectionTests):
expect_raise = True
+
def async_run(self, afunc):
return curio.run(afunc)
class CurioAsyncTests(AsyncTests):
connect_udp = False
+
def setUp(self):
- self.backend = dns.asyncbackend.set_default_backend('curio')
+ self.backend = dns.asyncbackend.set_default_backend("curio")
def async_run(self, afunc):
return curio.run(afunc)
+
except ImportError:
pass