diff options
author | Brian Wellington <bwelling@xbill.org> | 2020-01-07 13:03:13 -0800 |
---|---|---|
committer | Brian Wellington <bwelling@xbill.org> | 2020-01-07 13:03:13 -0800 |
commit | 7ec39e21ab0a6761a34ec405a4a59dc4ebe54924 (patch) | |
tree | 61a5e0db5d3e733e95c87d35ed2c04a73e99f3ba /dns/query.py | |
parent | 0ed99480529ecf3217738fe671a31dddd3360e48 (diff) | |
download | dnspython-7ec39e21ab0a6761a34ec405a4a59dc4ebe54924.tar.gz |
DoH cleanup.
Diffstat (limited to 'dns/query.py')
-rw-r--r-- | dns/query.py | 104 |
1 files changed, 68 insertions, 36 deletions
diff --git a/dns/query.py b/dns/query.py index c36248e..5876623 100644 --- a/dns/query.py +++ b/dns/query.py @@ -189,14 +189,16 @@ def _addresses_equal(af, a1, a2): return n1 == n2 and a1[1:] == a2[1:] -def _destination_and_source(af, where, port, source, source_port): +def _destination_and_source(af, where, port, source, source_port, + default_to_inet=True): # Apply defaults and compute destination and source tuples # suitable for use in connect(), sendto(), or bind(). if af is None: try: af = dns.inet.af_for_address(where) except Exception: - af = dns.inet.AF_INET + if default_to_inet: + af = dns.inet.AF_INET if af == dns.inet.AF_INET: destination = (where, port) if source is not None or source_port != 0: @@ -209,6 +211,9 @@ def _destination_and_source(af, where, port, source, source_port): if source is None: source = '::' source = (source, source_port, 0, 0) + else: + source = None + destination = None return (af, destination, source) def send_https(session, what, lifetime=None): @@ -225,9 +230,10 @@ def send_https(session, what, lifetime=None): what = what.prepare() return session.send(what, timeout=lifetime) -def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True, - bootstrap_address=None, verify=True, source=None, source_port=0, - one_rr_per_rrset=False, ignore_trailing=False): +def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0, + one_rr_per_rrset=False, ignore_trailing=False, + session=None, path='/dns-query', post=True, + bootstrap_address=None, verify=True): """Return the response obtained after sending a query via DNS-over-HTTPS. *q*, a ``dns.message.Message``, the query to send. @@ -236,21 +242,15 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru address is given, the URL will be constructed using the following schema: https://<IP-address>:<port>/<path>. - *session*, a ``requests.session.Session``, the session to use to send the - queries. This argument is required to allow for connection reuse. - *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query times out. If ``None``, the default, wait forever. - *port*, a ``int``, the port to send the query to. Default is 443. - - *path*, a ``str``. If *where* is an IP address, then *path* will be used to - construct the URL to send the DNS query to. - - *post*, a ``bool``. If ``True``, the default, POST method will be used. + *port*, a ``int``, the port to send the query to. The default is 443. - *bootstrap_address*, a ``str``, the IP address to use to bypass the system's - DNS resolver. + *af*, an ``int``, the address family to use. The default is ``None``, + which causes the address family to use to be inferred from the form of + *where*, or uses the system default. Setting this to AF_INET or + AF_INET6 currently has no effect. *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source address. The default is the wildcard address. @@ -264,13 +264,27 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. + *session*, a ``requests.session.Session``. If provided, the session to use + to send the queries. + + *path*, a ``str``. If *where* is an IP address, then *path* will be used to + construct the URL to send the DNS query to. + + *post*, a ``bool``. If ``True``, the default, POST method will be used. + + *bootstrap_address*, a ``str``, the IP address to use to bypass the + system's DNS resolver. + + *verify*, a ``str`, containing a path to a certificate file or directory. + Returns a ``dns.message.Message``. """ wire = q.to_wire() - af = None (af, destination, source) = _destination_and_source(af, where, port, - source, source_port) + source, source_port, + False) + transport_adapter = None headers = { "accept": "application/dns-message" } @@ -282,31 +296,49 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru split_url = urllib.parse.urlsplit(where) headers['Host'] = split_url.hostname url = where.replace(split_url.hostname, bootstrap_address) - session.mount(url, HostHeaderSSLAdapter()) + transport_adapter = HostHeaderSSLAdapter() else: url = where if source is not None: # set source port and source address - session.mount(url, SourceAddressAdapter(source)) - - # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples - if post: - headers.update({ - "content-type": "application/dns-message", - "content-length": str(len(wire)) - }) - response = session.post(url, headers=headers, data=wire, stream=True, - timeout=timeout, verify=verify) + transport_adapter = SourceAddressAdapter(source) + + if session: + close_session = False else: - wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=") - url += "?dns={}".format(wire) - response = session.get(url, headers=headers, stream=True, - timeout=timeout, verify=verify) + session = requests.sessions.Session() + close_session = True + + try: + if transport_adapter: + session.mount(url, transport_adapter) + + # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH + # GET and POST examples + if post: + headers.update({ + "content-type": "application/dns-message", + "content-length": str(len(wire)) + }) + response = session.post(url, headers=headers, data=wire, + stream=True, timeout=timeout, + verify=verify) + else: + wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=") + url += "?dns={}".format(wire) + response = session.get(url, headers=headers, stream=True, + timeout=timeout, verify=verify) + finally: + if close_session: + session.close() - # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes + # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH + # status codes if response.status_code < 200 or response.status_code > 299: - raise ValueError('{} responded with status code {}\nResponse body: {}'.format( - where, response.status_code, response.content)) + raise ValueError('{} responded with status code {}' + '\nResponse body: {}'.format(where, + response.status_code, + response.content)) r = dns.message.from_wire(response.content, keyring=q.keyring, request_mac=q.request_mac, |