From 3baf45dcfd00914b2a0a93963d39bded50b9c526 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Fri, 9 Oct 2020 09:35:14 -0700 Subject: First pass at adding network timeouts to tests. This is for when it looks like we have a network but it's not connected to the Internet. --- tests/test_async.py | 25 +++++++++++++++---------- tests/test_doh.py | 18 +++++++++++------- tests/test_query.py | 26 +++++++++++++++----------- 3 files changed, 41 insertions(+), 28 deletions(-) diff --git a/tests/test_async.py b/tests/test_async.py index 690a1eb..e9a26bb 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -248,7 +248,7 @@ class AsyncTests(unittest.TestCase): qname = dns.name.from_text('dns.google.') async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.udp(q, address) + return await dns.asyncquery.udp(q, address, timeout=2) response = self.async_run(run) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) @@ -265,7 +265,8 @@ class AsyncTests(unittest.TestCase): dns.inet.af_for_address(address), socket.SOCK_DGRAM) as s: q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.udp(q, address, sock=s) + return await dns.asyncquery.udp(q, address, sock=s, + timeout=2) response = self.async_run(run) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) @@ -279,7 +280,7 @@ class AsyncTests(unittest.TestCase): qname = dns.name.from_text('dns.google.') async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.tcp(q, address) + return await dns.asyncquery.tcp(q, address, timeout=2) response = self.async_run(run) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) @@ -296,11 +297,12 @@ class AsyncTests(unittest.TestCase): dns.inet.af_for_address(address), socket.SOCK_STREAM, 0, None, - (address, 53)) as s: + (address, 53), 2) as s: # for basic coverage await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.tcp(q, address, sock=s) + return await dns.asyncquery.tcp(q, address, sock=s, + timeout=2) response = self.async_run(run) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) @@ -315,7 +317,7 @@ class AsyncTests(unittest.TestCase): qname = dns.name.from_text('dns.google.') async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.tls(q, address) + return await dns.asyncquery.tls(q, address, timeout=2) response = self.async_run(run) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) @@ -335,12 +337,13 @@ class AsyncTests(unittest.TestCase): dns.inet.af_for_address(address), socket.SOCK_STREAM, 0, None, - (address, 853), None, + (address, 853), 2, ssl_context, None) as s: # for basic coverage await s.getsockname() q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.tls(q, '8.8.8.8', sock=s) + return await dns.asyncquery.tls(q, '8.8.8.8', sock=s, + timeout=2) response = self.async_run(run) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) @@ -354,7 +357,8 @@ class AsyncTests(unittest.TestCase): qname = dns.name.from_text('.') async def run(): q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) - return await dns.asyncquery.udp_with_fallback(q, address) + return await dns.asyncquery.udp_with_fallback(q, address, + timeout=2) (_, tcp) = self.async_run(run) self.assertTrue(tcp) @@ -363,7 +367,8 @@ class AsyncTests(unittest.TestCase): qname = dns.name.from_text('dns.google.') async def run(): q = dns.message.make_query(qname, dns.rdatatype.A) - return await dns.asyncquery.udp_with_fallback(q, address) + return await dns.asyncquery.udp_with_fallback(q, address, + timeout=2) (_, tcp) = self.async_run(run) self.assertFalse(tcp) diff --git a/tests/test_doh.py b/tests/test_doh.py index c5c0569..793a500 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -32,6 +32,7 @@ resolver_v4_addresses = [] resolver_v6_addresses = [] try: with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: + s.settimeout(4) s.connect(('8.8.8.8', 53)) resolver_v4_addresses = [ '1.1.1.1', @@ -77,13 +78,15 @@ class DNSOverHTTPSTestCase(unittest.TestCase): def test_get_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=False) + r = dns.query.https(q, nameserver_url, session=self.session, post=False, + timeout=4) self.assertTrue(q.is_response(r)) def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url, session=self.session, post=True) + r = dns.query.https(q, nameserver_url, session=self.session, post=True, + timeout=4) self.assertTrue(q.is_response(r)) def test_build_url_from_ip(self): @@ -95,14 +98,14 @@ class DNSOverHTTPSTestCase(unittest.TestCase): # https://8.8.8.8/dns-query # So we're just going to do GET requests here r = dns.query.https(q, nameserver_ip, session=self.session, - post=False) + post=False, timeout=4) self.assertTrue(q.is_response(r)) if resolver_v6_addresses: nameserver_ip = random.choice(resolver_v6_addresses) q = dns.message.make_query('example.com.', dns.rdatatype.A) r = dns.query.https(q, nameserver_ip, session=self.session, - post=False) + post=False, timeout=4) self.assertTrue(q.is_response(r)) def test_bootstrap_address(self): @@ -115,16 +118,17 @@ class DNSOverHTTPSTestCase(unittest.TestCase): # make sure CleanBrowsing's IP address will fail TLS certificate # check with self.assertRaises(SSLError): - dns.query.https(q, invalid_tls_url, session=self.session) + dns.query.https(q, invalid_tls_url, session=self.session, + timeout=4) # use host header r = dns.query.https(q, valid_tls_url, session=self.session, - bootstrap_address=ip) + bootstrap_address=ip, timeout=4) self.assertTrue(q.is_response(r)) def test_new_session(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) - r = dns.query.https(q, nameserver_url) + r = dns.query.https(q, nameserver_url, timeout=4) self.assertTrue(q.is_response(r)) def test_resolver(self): diff --git a/tests/test_query.py b/tests/test_query.py index 7a1ec71..8f2b65f 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -68,7 +68,7 @@ for (af, address) in ((socket.AF_INET, '8.8.8.8'), except Exception: pass -keyring = dns.tsigkeyring.from_text({'name' : 'tDz6cfXXGtNivRpQ98hr6A=='}) +keyring = dns.tsigkeyring.from_text({'name': 'tDz6cfXXGtNivRpQ98hr6A=='}) @unittest.skipIf(not _network_available, "Internet not reachable") class QueryTests(unittest.TestCase): @@ -77,7 +77,7 @@ class QueryTests(unittest.TestCase): for address in query_addresses: qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - response = dns.query.udp(q, address) + response = dns.query.udp(q, address, timeout=2) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) self.assertTrue(rrs is not None) @@ -92,7 +92,7 @@ class QueryTests(unittest.TestCase): s.setblocking(0) qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - response = dns.query.udp(q, address, sock=s) + response = dns.query.udp(q, address, sock=s, timeout=2) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) self.assertTrue(rrs is not None) @@ -104,7 +104,7 @@ class QueryTests(unittest.TestCase): for address in query_addresses: qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - response = dns.query.tcp(q, address) + response = dns.query.tcp(q, address, timeout=2) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) self.assertTrue(rrs is not None) @@ -117,11 +117,12 @@ class QueryTests(unittest.TestCase): with socket.socket(dns.inet.af_for_address(address), socket.SOCK_STREAM) as s: ll = dns.inet.low_level_address_tuple((address, 53)) + s.settimeout(2) s.connect(ll) s.setblocking(0) qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - response = dns.query.tcp(q, None, sock=s) + response = dns.query.tcp(q, None, sock=s, timeout=2) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) self.assertTrue(rrs is not None) @@ -133,7 +134,7 @@ class QueryTests(unittest.TestCase): for address in query_addresses: qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - response = dns.query.tls(q, address) + response = dns.query.tls(q, address, timeout=2) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) self.assertTrue(rrs is not None) @@ -147,13 +148,14 @@ class QueryTests(unittest.TestCase): with socket.socket(dns.inet.af_for_address(address), socket.SOCK_STREAM) as base_s: ll = dns.inet.low_level_address_tuple((address, 853)) + base_s.settimeout(2) base_s.connect(ll) ctx = ssl.create_default_context() with ctx.wrap_socket(base_s, server_hostname='dns.google') as s: s.setblocking(0) qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - response = dns.query.tls(q, None, sock=s) + response = dns.query.tls(q, None, sock=s, timeout=2) rrs = response.get_rrset(response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A) self.assertTrue(rrs is not None) @@ -165,7 +167,7 @@ class QueryTests(unittest.TestCase): for address in query_addresses: qname = dns.name.from_text('.') q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) - (_, tcp) = dns.query.udp_with_fallback(q, address) + (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2) self.assertTrue(tcp) def testQueryUDPFallbackWithSocket(self): @@ -175,20 +177,22 @@ class QueryTests(unittest.TestCase): udp_s.setblocking(0) with socket.socket(af, socket.SOCK_STREAM) as tcp_s: ll = dns.inet.low_level_address_tuple((address, 53)) + tcp_s.settimeout(2) tcp_s.connect(ll) tcp_s.setblocking(0) qname = dns.name.from_text('.') q = dns.message.make_query(qname, dns.rdatatype.DNSKEY) (_, tcp) = dns.query.udp_with_fallback(q, address, - udp_sock=udp_s, - tcp_sock=tcp_s) + udp_sock=udp_s, + tcp_sock=tcp_s, + timeout=2) self.assertTrue(tcp) def testQueryUDPFallbackNoFallback(self): for address in query_addresses: qname = dns.name.from_text('dns.google.') q = dns.message.make_query(qname, dns.rdatatype.A) - (_, tcp) = dns.query.udp_with_fallback(q, address) + (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2) self.assertFalse(tcp) def testUDPReceiveQuery(self): -- cgit v1.2.1