summaryrefslogtreecommitdiff
path: root/dns/query.py
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-01-07 13:03:13 -0800
committerBrian Wellington <bwelling@xbill.org>2020-01-07 13:03:13 -0800
commit7ec39e21ab0a6761a34ec405a4a59dc4ebe54924 (patch)
tree61a5e0db5d3e733e95c87d35ed2c04a73e99f3ba /dns/query.py
parent0ed99480529ecf3217738fe671a31dddd3360e48 (diff)
downloaddnspython-7ec39e21ab0a6761a34ec405a4a59dc4ebe54924.tar.gz
DoH cleanup.
Diffstat (limited to 'dns/query.py')
-rw-r--r--dns/query.py104
1 files changed, 68 insertions, 36 deletions
diff --git a/dns/query.py b/dns/query.py
index c36248e..5876623 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -189,14 +189,16 @@ def _addresses_equal(af, a1, a2):
return n1 == n2 and a1[1:] == a2[1:]
-def _destination_and_source(af, where, port, source, source_port):
+def _destination_and_source(af, where, port, source, source_port,
+ default_to_inet=True):
# Apply defaults and compute destination and source tuples
# suitable for use in connect(), sendto(), or bind().
if af is None:
try:
af = dns.inet.af_for_address(where)
except Exception:
- af = dns.inet.AF_INET
+ if default_to_inet:
+ af = dns.inet.AF_INET
if af == dns.inet.AF_INET:
destination = (where, port)
if source is not None or source_port != 0:
@@ -209,6 +211,9 @@ def _destination_and_source(af, where, port, source, source_port):
if source is None:
source = '::'
source = (source, source_port, 0, 0)
+ else:
+ source = None
+ destination = None
return (af, destination, source)
def send_https(session, what, lifetime=None):
@@ -225,9 +230,10 @@ def send_https(session, what, lifetime=None):
what = what.prepare()
return session.send(what, timeout=lifetime)
-def https(q, where, session, timeout=None, port=443, path='/dns-query', post=True,
- bootstrap_address=None, verify=True, source=None, source_port=0,
- one_rr_per_rrset=False, ignore_trailing=False):
+def https(q, where, timeout=None, port=443, af=None, source=None, source_port=0,
+ one_rr_per_rrset=False, ignore_trailing=False,
+ session=None, path='/dns-query', post=True,
+ bootstrap_address=None, verify=True):
"""Return the response obtained after sending a query via DNS-over-HTTPS.
*q*, a ``dns.message.Message``, the query to send.
@@ -236,21 +242,15 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
address is given, the URL will be constructed using the following schema:
https://<IP-address>:<port>/<path>.
- *session*, a ``requests.session.Session``, the session to use to send the
- queries. This argument is required to allow for connection reuse.
-
*timeout*, a ``float`` or ``None``, the number of seconds to
wait before the query times out. If ``None``, the default, wait forever.
- *port*, a ``int``, the port to send the query to. Default is 443.
-
- *path*, a ``str``. If *where* is an IP address, then *path* will be used to
- construct the URL to send the DNS query to.
-
- *post*, a ``bool``. If ``True``, the default, POST method will be used.
+ *port*, a ``int``, the port to send the query to. The default is 443.
- *bootstrap_address*, a ``str``, the IP address to use to bypass the system's
- DNS resolver.
+ *af*, an ``int``, the address family to use. The default is ``None``,
+ which causes the address family to use to be inferred from the form of
+ *where*, or uses the system default. Setting this to AF_INET or
+ AF_INET6 currently has no effect.
*source*, a ``str`` containing an IPv4 or IPv6 address, specifying
the source address. The default is the wildcard address.
@@ -264,13 +264,27 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
+ *session*, a ``requests.session.Session``. If provided, the session to use
+ to send the queries.
+
+ *path*, a ``str``. If *where* is an IP address, then *path* will be used to
+ construct the URL to send the DNS query to.
+
+ *post*, a ``bool``. If ``True``, the default, POST method will be used.
+
+ *bootstrap_address*, a ``str``, the IP address to use to bypass the
+ system's DNS resolver.
+
+ *verify*, a ``str`, containing a path to a certificate file or directory.
+
Returns a ``dns.message.Message``.
"""
wire = q.to_wire()
- af = None
(af, destination, source) = _destination_and_source(af, where, port,
- source, source_port)
+ source, source_port,
+ False)
+ transport_adapter = None
headers = {
"accept": "application/dns-message"
}
@@ -282,31 +296,49 @@ def https(q, where, session, timeout=None, port=443, path='/dns-query', post=Tru
split_url = urllib.parse.urlsplit(where)
headers['Host'] = split_url.hostname
url = where.replace(split_url.hostname, bootstrap_address)
- session.mount(url, HostHeaderSSLAdapter())
+ transport_adapter = HostHeaderSSLAdapter()
else:
url = where
if source is not None:
# set source port and source address
- session.mount(url, SourceAddressAdapter(source))
-
- # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH GET and POST examples
- if post:
- headers.update({
- "content-type": "application/dns-message",
- "content-length": str(len(wire))
- })
- response = session.post(url, headers=headers, data=wire, stream=True,
- timeout=timeout, verify=verify)
+ transport_adapter = SourceAddressAdapter(source)
+
+ if session:
+ close_session = False
else:
- wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
- url += "?dns={}".format(wire)
- response = session.get(url, headers=headers, stream=True,
- timeout=timeout, verify=verify)
+ session = requests.sessions.Session()
+ close_session = True
+
+ try:
+ if transport_adapter:
+ session.mount(url, transport_adapter)
+
+ # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
+ # GET and POST examples
+ if post:
+ headers.update({
+ "content-type": "application/dns-message",
+ "content-length": str(len(wire))
+ })
+ response = session.post(url, headers=headers, data=wire,
+ stream=True, timeout=timeout,
+ verify=verify)
+ else:
+ wire = base64.urlsafe_b64encode(wire).decode('utf-8').strip("=")
+ url += "?dns={}".format(wire)
+ response = session.get(url, headers=headers, stream=True,
+ timeout=timeout, verify=verify)
+ finally:
+ if close_session:
+ session.close()
- # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH status codes
+ # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
+ # status codes
if response.status_code < 200 or response.status_code > 299:
- raise ValueError('{} responded with status code {}\nResponse body: {}'.format(
- where, response.status_code, response.content))
+ raise ValueError('{} responded with status code {}'
+ '\nResponse body: {}'.format(where,
+ response.status_code,
+ response.content))
r = dns.message.from_wire(response.content,
keyring=q.keyring,
request_mac=q.request_mac,