diff options
-rw-r--r-- | dns/query.py | 36 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | tests/test_doh.py | 22 |
3 files changed, 40 insertions, 20 deletions
diff --git a/dns/query.py b/dns/query.py index 5bf471a..1111e08 100644 --- a/dns/query.py +++ b/dns/query.py @@ -46,11 +46,19 @@ try: except ImportError: # pragma: no cover _have_requests = False +_have_httpx = False +_have_http2 = False try: import httpx _have_httpx = True + try: + # See if http2 support is available. + with httpx.Client(http2=True): + _have_http2 = True + except Exception: + pass except ImportError: # pragma: no cover - _have_httpx = False + pass have_doh = _have_requests or _have_httpx @@ -283,9 +291,9 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, """ if not have_doh: - raise NoDOH # pragma: no cover + raise NoDOH('Neither httpx nor requests is available.') # pragma: no cover - _httpx_ok = True + _httpx_ok = _have_httpx wire = q.to_wire() (af, _, source) = _destination_and_source(where, port, source, source_port, @@ -319,28 +327,26 @@ def https(q, where, timeout=None, port=443, source=None, source_port=0, if _have_requests: transport_adapter = SourceAddressAdapter(source) + if session: + _is_httpx = isinstance(session, httpx.Client) + if _is_httpx and not _httpx_ok: + raise NoDOH('Session is httpx, but httpx cannot be used for ' + 'the requested operation.') + else: + _is_httpx = _httpx_ok + if not _httpx_ok and not _have_requests: raise NoDOH('Cannot use httpx for this operation, and ' 'requests is not available.') with contextlib.ExitStack() as stack: - if session: - if _have_httpx: - _is_httpx = isinstance(session, httpx.Client) - else: - _is_httpx = False - if _is_httpx and not _httpx_ok: - # we can't use this session - session = None if not session: - if _have_httpx and _httpx_ok: - _is_httpx = True + if _is_httpx: session = stack.enter_context(httpx.Client(http1=True, - http2=True, + http2=_have_http2, verify=verify, transport=transport)) else: - _is_httpx = False session = stack.enter_context(requests.sessions.Session()) if transport_adapter: diff --git a/pyproject.toml b/pyproject.toml index 51bfbae..4c60837 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ wheel = "^0.35.0" pylint = "^2.7.4" [tool.poetry.extras] -doh = ['requests', 'requests-toolbelt'] +doh = ['httpx[http2]', 'requests', 'requests-toolbelt'] idna = ['idna'] dnssec = ['cryptography'] trio = ['trio'] diff --git a/tests/test_doh.py b/tests/test_doh.py index b575054..9dc4cec 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -160,6 +160,18 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): timeout=4) self.assertTrue(q.is_response(r)) + def test_get_request_http1(self): + saved_have_http2 = dns.query._have_http2 + try: + dns.query._have_http2 = False + nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) + q = dns.message.make_query('example.com.', dns.rdatatype.A) + r = dns.query.https(q, nameserver_url, session=self.session, post=False, + timeout=4) + self.assertTrue(q.is_response(r)) + finally: + dns.query._have_http2 = saved_have_http2 + def test_post_request(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) q = dns.message.make_query('example.com.', dns.rdatatype.A) @@ -186,7 +198,7 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): post=False, timeout=4) self.assertTrue(q.is_response(r)) - def test_bootstrap_address(self): + def test_bootstrap_address_fails(self): # We test this to see if v4 is available if resolver_v4_addresses: ip = '185.228.168.168' @@ -198,10 +210,12 @@ class DNSOverHTTPSTestCaseHttpx(unittest.TestCase): with self.assertRaises(httpx.ConnectError): dns.query.https(q, invalid_tls_url, session=self.session, timeout=4) - # use host header - r = dns.query.https(q, valid_tls_url, session=self.session, + # We can't do the Host header and SNI magic with httpx, but + # we are demanding httpx be used by providing a session, so + # we should get a NoDOH exception. + with self.assertRaises(dns.query.NoDOH): + dns.query.https(q, valid_tls_url, session=self.session, bootstrap_address=ip, timeout=4) - self.assertTrue(q.is_response(r)) def test_new_session(self): nameserver_url = random.choice(KNOWN_ANYCAST_DOH_RESOLVER_URLS) |