summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile7
-rw-r--r--dns/dnssec.py8
-rw-r--r--dns/message.py238
-rw-r--r--dns/query.py4
-rw-r--r--dns/rdata.py74
-rw-r--r--dns/rdtypes/ANY/CDS.py5
-rw-r--r--dns/rdtypes/dsbase.py24
-rw-r--r--dns/resolver.py15
-rw-r--r--dns/zone.py9
-rw-r--r--doc/examples.rst10
-rw-r--r--doc/manual.rst1
-rwxr-xr-xexamples/edns.py52
-rw-r--r--pyproject.toml22
-rw-r--r--tests/test_async.py8
-rw-r--r--tests/test_dnssec.py23
-rw-r--r--tests/test_doh.py6
-rw-r--r--tests/test_message.py52
-rw-r--r--tests/test_rdata.py89
-rw-r--r--tests/test_resolver.py10
-rw-r--r--tests/test_zonedigest.py15
20 files changed, 522 insertions, 150 deletions
diff --git a/Makefile b/Makefile
index 2e120ae..76e7028 100644
--- a/Makefile
+++ b/Makefile
@@ -69,6 +69,11 @@ pocov:
poetry run coverage html --include 'dns*'
poetry run coverage report --include 'dns*'
-pokit:
+oldpokit:
po run python setup.py sdist --formats=zip bdist_wheel
+pokit:
+ po build
+
+findjunk:
+ find dns -type f | egrep -v '.*\.py' | egrep -v 'py\.typed'
diff --git a/dns/dnssec.py b/dns/dnssec.py
index f09ecd6..6e9946f 100644
--- a/dns/dnssec.py
+++ b/dns/dnssec.py
@@ -404,13 +404,13 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None):
rrnamebuf = rrname.to_digestable()
rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass,
rrsig.original_ttl)
- for rr in sorted(rdataset):
+ rdatas = [rdata.to_digestable(origin) for rdata in rdataset]
+ for rdata in sorted(rdatas):
data += rrnamebuf
data += rrfixed
- rrdata = rr.to_digestable(origin)
- rrlen = struct.pack('!H', len(rrdata))
+ rrlen = struct.pack('!H', len(rdata))
data += rrlen
- data += rrdata
+ data += rdata
chosen_hash = _make_hash(rrsig.algorithm)
diff --git a/dns/message.py b/dns/message.py
index 6fa90ca..1e67a17 100644
--- a/dns/message.py
+++ b/dns/message.py
@@ -108,6 +108,12 @@ class MessageSection(dns.enum.IntEnum):
return 3
+class MessageError:
+ def __init__(self, exception, offset):
+ self.exception = exception
+ self.offset = offset
+
+
DEFAULT_EDNS_PAYLOAD = 1232
MAX_CHAIN = 16
@@ -132,6 +138,7 @@ class Message:
self.origin = None
self.tsig_ctx = None
self.index = {}
+ self.errors = []
@property
def question(self):
@@ -873,11 +880,14 @@ class _WireReader:
ignore_trailing: Ignore trailing junk at end of request?
multi: Is this message part of a multi-message sequence?
DNS dynamic updates.
+ continue_on_error: try to extract as much information as possible from
+ the message, accumulating MessageErrors in the *errors* attribute instead of
+ raising them.
"""
def __init__(self, wire, initialize_message, question_only=False,
one_rr_per_rrset=False, ignore_trailing=False,
- keyring=None, multi=False):
+ keyring=None, multi=False, continue_on_error=False):
self.parser = dns.wire.Parser(wire)
self.message = None
self.initialize_message = initialize_message
@@ -886,6 +896,8 @@ class _WireReader:
self.ignore_trailing = ignore_trailing
self.keyring = keyring
self.multi = multi
+ self.continue_on_error = continue_on_error
+ self.errors = []
def _get_question(self, section_number, qcount):
"""Read the next *qcount* records from the wire data and add them to
@@ -902,11 +914,14 @@ class _WireReader:
self.message.find_rrset(section, qname, rdclass, rdtype,
create=True, force_unique=True)
+ def _add_error(self, e):
+ self.errors.append(MessageError(e, self.parser.current))
+
def _get_section(self, section_number, count):
"""Read the next I{count} records from the wire data and add them to
the specified section.
- section: the section of the message to which to add records
+ section_number: the section of the message to which to add records
count: the number of records to read
"""
@@ -929,55 +944,65 @@ class _WireReader:
(rdclass, rdtype, deleting, empty) = \
self.message._parse_rr_header(section_number,
name, rdclass, rdtype)
- if empty:
- if rdlen > 0:
- raise dns.exception.FormError
- rd = None
- covers = dns.rdatatype.NONE
- else:
- with self.parser.restrict_to(rdlen):
- rd = dns.rdata.from_wire_parser(rdclass, rdtype,
- self.parser,
- self.message.origin)
- covers = rd.covers()
- if self.message.xfr and rdtype == dns.rdatatype.SOA:
- force_unique = True
- if rdtype == dns.rdatatype.OPT:
- self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
- elif rdtype == dns.rdatatype.TSIG:
- if self.keyring is None:
- raise UnknownTSIGKey('got signed message without keyring')
- if isinstance(self.keyring, dict):
- key = self.keyring.get(absolute_name)
- if isinstance(key, bytes):
- key = dns.tsig.Key(absolute_name, key, rd.algorithm)
- elif callable(self.keyring):
- key = self.keyring(self.message, absolute_name)
+ try:
+ rdata_start = self.parser.current
+ if empty:
+ if rdlen > 0:
+ raise dns.exception.FormError
+ rd = None
+ covers = dns.rdatatype.NONE
else:
- key = self.keyring
- if key is None:
- raise UnknownTSIGKey("key '%s' unknown" % name)
- self.message.keyring = key
- self.message.tsig_ctx = \
- dns.tsig.validate(self.parser.wire,
- key,
- absolute_name,
- rd,
- int(time.time()),
- self.message.request_mac,
- rr_start,
- self.message.tsig_ctx,
- self.multi)
- self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
- else:
- rrset = self.message.find_rrset(section, name,
- rdclass, rdtype, covers,
- deleting, True,
- force_unique)
- if rd is not None:
- if ttl > 0x7fffffff:
- ttl = 0
- rrset.add(rd, ttl)
+ with self.parser.restrict_to(rdlen):
+ rd = dns.rdata.from_wire_parser(rdclass, rdtype,
+ self.parser,
+ self.message.origin)
+ covers = rd.covers()
+ if self.message.xfr and rdtype == dns.rdatatype.SOA:
+ force_unique = True
+ if rdtype == dns.rdatatype.OPT:
+ self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
+ elif rdtype == dns.rdatatype.TSIG:
+ if self.keyring is None:
+ raise UnknownTSIGKey('got signed message without '
+ 'keyring')
+ if isinstance(self.keyring, dict):
+ key = self.keyring.get(absolute_name)
+ if isinstance(key, bytes):
+ key = dns.tsig.Key(absolute_name, key, rd.algorithm)
+ elif callable(self.keyring):
+ key = self.keyring(self.message, absolute_name)
+ else:
+ key = self.keyring
+ if key is None:
+ raise UnknownTSIGKey("key '%s' unknown" % name)
+ self.message.keyring = key
+ self.message.tsig_ctx = \
+ dns.tsig.validate(self.parser.wire,
+ key,
+ absolute_name,
+ rd,
+ int(time.time()),
+ self.message.request_mac,
+ rr_start,
+ self.message.tsig_ctx,
+ self.multi)
+ self.message.tsig = dns.rrset.from_rdata(absolute_name, 0,
+ rd)
+ else:
+ rrset = self.message.find_rrset(section, name,
+ rdclass, rdtype, covers,
+ deleting, True,
+ force_unique)
+ if rd is not None:
+ if ttl > 0x7fffffff:
+ ttl = 0
+ rrset.add(rd, ttl)
+ except Exception as e:
+ if self.continue_on_error:
+ self._add_error(e)
+ self.parser.seek(rdata_start + rdlen)
+ else:
+ raise
def read(self):
"""Read a wire format DNS message and build a dns.message.Message
@@ -993,69 +1018,82 @@ class _WireReader:
self.initialize_message(self.message)
self.one_rr_per_rrset = \
self.message._get_one_rr_per_rrset(self.one_rr_per_rrset)
- self._get_question(MessageSection.QUESTION, qcount)
- if self.question_only:
- return self.message
- self._get_section(MessageSection.ANSWER, ancount)
- self._get_section(MessageSection.AUTHORITY, aucount)
- self._get_section(MessageSection.ADDITIONAL, adcount)
- if not self.ignore_trailing and self.parser.remaining() != 0:
- raise TrailingJunk
- if self.multi and self.message.tsig_ctx and not self.message.had_tsig:
- self.message.tsig_ctx.update(self.parser.wire)
+ try:
+ self._get_question(MessageSection.QUESTION, qcount)
+ if self.question_only:
+ return self.message
+ self._get_section(MessageSection.ANSWER, ancount)
+ self._get_section(MessageSection.AUTHORITY, aucount)
+ self._get_section(MessageSection.ADDITIONAL, adcount)
+ if not self.ignore_trailing and self.parser.remaining() != 0:
+ raise TrailingJunk
+ if self.multi and self.message.tsig_ctx and \
+ not self.message.had_tsig:
+ self.message.tsig_ctx.update(self.parser.wire)
+ except Exception as e:
+ if self.continue_on_error:
+ self._add_error(e)
+ else:
+ raise
return self.message
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
tsig_ctx=None, multi=False,
question_only=False, one_rr_per_rrset=False,
- ignore_trailing=False, raise_on_truncation=False):
- """Convert a DNS wire format message into a message
- object.
+ ignore_trailing=False, raise_on_truncation=False,
+ continue_on_error=False):
+ """Convert a DNS wire format message into a message object.
- *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use
- if the message is signed.
+ *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the
+ message is signed.
- *request_mac*, a ``bytes``. If the message is a response to a
- TSIG-signed request, *request_mac* should be set to the MAC of
- that request.
+ *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed
+ request, *request_mac* should be set to the MAC of that request.
- *xfr*, a ``bool``, should be set to ``True`` if this message is part of
- a zone transfer.
+ *xfr*, a ``bool``, should be set to ``True`` if this message is part of a
+ zone transfer.
- *origin*, a ``dns.name.Name`` or ``None``. If the message is part
- of a zone transfer, *origin* should be the origin name of the
- zone. If not ``None``, names will be relativized to the origin.
+ *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone
+ transfer, *origin* should be the origin name of the zone. If not ``None``,
+ names will be relativized to the origin.
*tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the
ongoing TSIG context, used when validating zone transfers.
- *multi*, a ``bool``, should be set to ``True`` if this message is
- part of a multiple message sequence.
+ *multi*, a ``bool``, should be set to ``True`` if this message is part of a
+ multiple message sequence.
- *question_only*, a ``bool``. If ``True``, read only up to
- the end of the question section.
+ *question_only*, a ``bool``. If ``True``, read only up to the end of the
+ question section.
- *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its
- own RRset.
+ *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own
+ RRset.
- *ignore_trailing*, a ``bool``. If ``True``, ignore trailing
- junk at end of the message.
+ *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of
+ the message.
- *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if
- the TC bit is set.
+ *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the
+ TC bit is set.
+
+ *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even
+ if errors occur. Erroneous rdata will be ignored. Errors will be
+ accumulated as a list of MessageError objects in the message's ``errors``
+ attribute. This option is recommended only for DNS analysis tools, or for
+ use in a server as part of an error handling path. The default is
+ ``False``.
Raises ``dns.message.ShortHeader`` if the message is less than 12 octets
long.
- Raises ``dns.message.TrailingJunk`` if there were octets in the message
- past the end of the proper DNS message, and *ignore_trailing* is ``False``.
+ Raises ``dns.message.TrailingJunk`` if there were octets in the message past
+ the end of the proper DNS message, and *ignore_trailing* is ``False``.
- Raises ``dns.message.BadEDNS`` if an OPT record was in the
- wrong section, or occurred more than once.
+ Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or
+ occurred more than once.
- Raises ``dns.message.BadTSIG`` if a TSIG record was not the last
- record of the additional data section.
+ Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of
+ the additional data section.
Raises ``dns.message.Truncated`` if the TC flag is set and
*raise_on_truncation* is ``True``.
@@ -1070,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
message.tsig_ctx = tsig_ctx
reader = _WireReader(wire, initialize_message, question_only,
- one_rr_per_rrset, ignore_trailing, keyring, multi)
+ one_rr_per_rrset, ignore_trailing, keyring, multi,
+ continue_on_error)
try:
m = reader.read()
except dns.exception.FormError:
@@ -1083,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
# have to do this check here too.
if m.flags & dns.flags.TC and raise_on_truncation:
raise Truncated(message=m)
+ if continue_on_error:
+ m.errors = reader.errors
return m
@@ -1383,7 +1424,8 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False):
def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
want_dnssec=False, ednsflags=None, payload=None,
- request_payload=None, options=None, idna_codec=None):
+ request_payload=None, options=None, idna_codec=None,
+ id=None, flags=dns.flags.RD):
"""Make a query message.
The query name, type, and class may all be specified either
@@ -1400,7 +1442,9 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
is class IN.
*use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the
- default is None (no EDNS).
+ default is ``None``. If ``None``, EDNS will be enabled only if other
+ parameters (*ednsflags*, *payload*, *request_payload*, or *options*) are
+ set.
See the description of dns.message.Message.use_edns() for the possible
values for use_edns and their meanings.
@@ -1423,6 +1467,12 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder
is used.
+ *id*, an ``int`` or ``None``, the desired query id. The default is
+ ``None``, which generates a random query id.
+
+ *flags*, an ``int``, the desired query flags. The default is
+ ``dns.flags.RD``.
+
Returns a ``dns.message.QueryMessage``
"""
@@ -1430,8 +1480,8 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
qname = dns.name.from_text(qname, idna_codec=idna_codec)
rdtype = dns.rdatatype.RdataType.make(rdtype)
rdclass = dns.rdataclass.RdataClass.make(rdclass)
- m = QueryMessage()
- m.flags |= dns.flags.RD
+ m = QueryMessage(id=id)
+ m.flags = dns.flags.Flag(flags)
m.find_rrset(m.question, qname, rdclass, rdtype, create=True,
force_unique=True)
# only pass keywords on to use_edns if they have been set to a
diff --git a/dns/query.py b/dns/query.py
index 934bf41..fee5d6a 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -548,7 +548,7 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None,
if a socket is provided, it must be a nonblocking datagram socket,
and the *source* and *source_port* are ignored for the UDP query.
- *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the
+ *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
TCP query. If ``None``, the default, a socket is created. Note that
if a socket is provided, it must be a nonblocking connected stream
socket, and *where*, *source* and *source_port* are ignored for the TCP
@@ -702,7 +702,7 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0,
*ignore_trailing*, a ``bool``. If ``True``, ignore trailing
junk at end of the received message.
- *sock*, a ``socket.socket``, or ``None``, the socket to use for the
+ *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
query. If ``None``, the default, a socket is created. Note that
if a socket is provided, it must be a nonblocking connected stream
socket, and *where*, *port*, *source* and *source_port* are ignored.
diff --git a/dns/rdata.py b/dns/rdata.py
index 0831c41..624063e 100644
--- a/dns/rdata.py
+++ b/dns/rdata.py
@@ -38,6 +38,22 @@ import dns.ttl
_chunksize = 32
+# We currently allow comparisons for rdata with relative names for backwards
+# compatibility, but in the future we will not, as these kinds of comparisons
+# can lead to subtle bugs if code is not carefully written.
+#
+# This switch allows the future behavior to be turned on so code can be
+# tested with it.
+_allow_relative_comparisons = True
+
+
+class NoRelativeRdataOrdering(dns.exception.DNSException):
+ """An attempt was made to do an ordered comparison of one or more
+ rdata with relative names. The only reliable way of sorting rdata
+ is to use non-relativized rdata.
+
+ """
+
def _wordbreak(data, chunksize=_chunksize, separator=b' '):
"""Break a binary string into chunks of chunksize characters separated by
@@ -232,12 +248,42 @@ class Rdata:
"""Compare an rdata with another rdata of the same rdtype and
rdclass.
- Return < 0 if self < other in the DNSSEC ordering, 0 if self
- == other, and > 0 if self > other.
-
+ For rdata with only absolute names:
+ Return < 0 if self < other in the DNSSEC ordering, 0 if self
+ == other, and > 0 if self > other.
+ For rdata with at least one relative names:
+ The rdata sorts before any rdata with only absolute names.
+ When compared with another relative rdata, all names are
+ made absolute as if they were relative to the root, as the
+ proper origin is not available. While this creates a stable
+ ordering, it is NOT guaranteed to be the DNSSEC ordering.
+ In the future, all ordering comparisons for rdata with
+ relative names will be disallowed.
"""
- our = self.to_digestable(dns.name.root)
- their = other.to_digestable(dns.name.root)
+ try:
+ our = self.to_digestable()
+ our_relative = False
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ if _allow_relative_comparisons:
+ our = self.to_digestable(dns.name.root)
+ our_relative = True
+ try:
+ their = other.to_digestable()
+ their_relative = False
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ if _allow_relative_comparisons:
+ their = other.to_digestable(dns.name.root)
+ their_relative = True
+ if _allow_relative_comparisons:
+ if our_relative != their_relative:
+ # For the purpose of comparison, all rdata with at least one
+ # relative name is less than an rdata with only absolute names.
+ if our_relative:
+ return -1
+ else:
+ return 1
+ elif our_relative or their_relative:
+ raise NoRelativeRdataOrdering
if our == their:
return 0
elif our > their:
@@ -250,14 +296,28 @@ class Rdata:
return False
if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return False
- return self._cmp(other) == 0
+ our_relative = False
+ their_relative = False
+ try:
+ our = self.to_digestable()
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ our = self.to_digestable(dns.name.root)
+ our_relative = True
+ try:
+ their = other.to_digestable()
+ except dns.name.NeedAbsoluteNameOrOrigin:
+ their = other.to_digestable(dns.name.root)
+ their_relative = True
+ if our_relative != their_relative:
+ return False
+ return our == their
def __ne__(self, other):
if not isinstance(other, Rdata):
return True
if self.rdclass != other.rdclass or self.rdtype != other.rdtype:
return True
- return self._cmp(other) != 0
+ return not self.__eq__(other)
def __lt__(self, other):
if not isinstance(other, Rdata) or \
diff --git a/dns/rdtypes/ANY/CDS.py b/dns/rdtypes/ANY/CDS.py
index 39e3556..094de12 100644
--- a/dns/rdtypes/ANY/CDS.py
+++ b/dns/rdtypes/ANY/CDS.py
@@ -23,3 +23,8 @@ import dns.immutable
class CDS(dns.rdtypes.dsbase.DSBase):
"""CDS record"""
+
+ _digest_length_by_type = {
+ **dns.rdtypes.dsbase.DSBase._digest_length_by_type,
+ 0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049)
+ }
diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py
index d125db2..403e937 100644
--- a/dns/rdtypes/dsbase.py
+++ b/dns/rdtypes/dsbase.py
@@ -24,15 +24,6 @@ import dns.rdata
import dns.rdatatype
-# Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml
-_digest_length_by_type = {
- 1: 20, # SHA-1, RFC 3658 Sec. 2.4
- 2: 32, # SHA-256, RFC 4509 Sec. 2.2
- 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1
- 4: 48, # SHA-384, RFC 6605 Sec. 2
-}
-
-
@dns.immutable.immutable
class DSBase(dns.rdata.Rdata):
@@ -40,6 +31,14 @@ class DSBase(dns.rdata.Rdata):
__slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest']
+ # Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml
+ _digest_length_by_type = {
+ 1: 20, # SHA-1, RFC 3658 Sec. 2.4
+ 2: 32, # SHA-256, RFC 4509 Sec. 2.2
+ 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1
+ 4: 48, # SHA-384, RFC 6605 Sec. 2
+ }
+
def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type,
digest):
super().__init__(rdclass, rdtype)
@@ -49,13 +48,12 @@ class DSBase(dns.rdata.Rdata):
self.digest = self._as_bytes(digest)
try:
+ if len(self.digest) != self._digest_length_by_type[self.digest_type]:
+ raise ValueError('digest length inconsistent with digest type')
+ except KeyError:
if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4
raise ValueError('digest type 0 is reserved')
- expected_length = _digest_length_by_type[self.digest_type]
- except KeyError:
raise ValueError('unknown digest type')
- if len(self.digest) != expected_length:
- raise ValueError('digest length inconsistent with digest type')
def to_text(self, origin=None, relativize=True, **kw):
kw = kw.copy()
diff --git a/dns/resolver.py b/dns/resolver.py
index 6a9974d..108dd52 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -807,7 +807,7 @@ class BaseResolver:
f = stack.enter_context(open(f))
except OSError:
# /etc/resolv.conf doesn't exist, can't be read, etc.
- raise NoResolverConfiguration
+ raise NoResolverConfiguration(f'cannot open {f}')
for l in f:
if len(l) == 0 or l[0] == '#' or l[0] == ';':
@@ -848,7 +848,7 @@ class BaseResolver:
except (ValueError, IndexError):
pass
if len(self.nameservers) == 0:
- raise NoResolverConfiguration
+ raise NoResolverConfiguration('no nameservers')
def _determine_split_char(self, entry):
#
@@ -1120,6 +1120,14 @@ class BaseResolver:
``list``.
"""
if isinstance(nameservers, list):
+ for nameserver in nameservers:
+ if not dns.inet.is_address(nameserver):
+ try:
+ if urlparse(nameserver).scheme != 'https':
+ raise NotImplementedError
+ except Exception:
+ raise ValueError(f'nameserver {nameserver} is not an '
+ 'IP address or valid https URL')
self._nameservers = nameservers
else:
raise ValueError('nameservers must be a list'
@@ -1219,9 +1227,6 @@ class Resolver(BaseResolver):
source_port=source_port,
raise_on_truncation=True)
else:
- protocol = urlparse(nameserver).scheme
- if protocol != 'https':
- raise NotImplementedError
response = dns.query.https(request, nameserver,
timeout=timeout)
except Exception as ex:
diff --git a/dns/zone.py b/dns/zone.py
index 9c3204b..d154928 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -733,10 +733,11 @@ class Zone(dns.transaction.TransactionManager):
continue
rrfixed = struct.pack('!HHI', rdataset.rdtype,
rdataset.rdclass, rdataset.ttl)
- for rr in sorted(rdataset):
- rrdata = rr.to_digestable(self.origin)
- rrlen = struct.pack('!H', len(rrdata))
- hasher.update(rrnamebuf + rrfixed + rrlen + rrdata)
+ rdatas = [rdata.to_digestable(self.origin)
+ for rdata in rdataset]
+ for rdata in sorted(rdatas):
+ rrlen = struct.pack('!H', len(rdata))
+ hasher.update(rrnamebuf + rrfixed + rrlen + rdata)
return hasher.digest()
def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE):
diff --git a/doc/examples.rst b/doc/examples.rst
new file mode 100644
index 0000000..4811b48
--- /dev/null
+++ b/doc/examples.rst
@@ -0,0 +1,10 @@
+.. examples:
+
+Examples
+--------
+
+The dnspython source comes with example programs that show how
+to use dnspython in practice. You can clone the dnspython source
+from GitHub:
+ git clone https://github.com/rthalley/dnspython.git
+The example prgrams are in the ``examples/`` directory.
diff --git a/doc/manual.rst b/doc/manual.rst
index 0ebbd03..d5ed014 100644
--- a/doc/manual.rst
+++ b/doc/manual.rst
@@ -17,3 +17,4 @@ Dnspython Manual
utilities
typing
threads
+ examples
diff --git a/examples/edns.py b/examples/edns.py
new file mode 100755
index 0000000..a130f85
--- /dev/null
+++ b/examples/edns.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+
+import dns.edns
+import dns.message
+import dns.query
+import dns.resolver
+
+n = '.'
+t = dns.rdatatype.SOA
+l = '199.7.83.42' # Address of l.root-servers.net
+i = '149.20.1.73' # Address of ns1.isc.org, for COOKIEs
+
+q_list = []
+
+# A query without EDNS0
+q_list.append((l, dns.message.make_query(n, t)))
+
+# The same query, but with EDNS0 turned on with no options
+q_list.append((l,dns.message.make_query(n, t, use_edns=0)))
+
+# Use use_edns() to specify EDNS0 options, such as buffer size
+this_q = dns.message.make_query(n, t)
+this_q.use_edns(0, payload=2000)
+q_list.append((l, this_q))
+
+# With an NSID option
+# use_edns=0 is not needed if options are specified)
+q_list.append((l, dns.message.make_query(n, t,\
+ options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')])))
+
+# With an NSID option, but with use_edns() to specify the options
+this_q = dns.message.make_query(n, t)
+this_q.use_edns(0, options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')])
+q_list.append((l, this_q))
+
+# With a COOKIE
+q_list.append((i, dns.message.make_query(n, t,\
+ options=[dns.edns.GenericOption(dns.edns.OptionType.COOKIE, b'0xfe11ac99bebe3322')])))
+
+# With an ECS option using dns.edns.ECSOption to form the option
+q_list.append((l, dns.message.make_query(n, t,\
+ options=[dns.edns.ECSOption('192.168.0.0', 20)])))
+
+for (addr, q) in q_list:
+ r = dns.query.udp(q, addr)
+ if not r.options:
+ print('No EDNS options returned')
+ else:
+ for o in r.options:
+ print(o.otype.value, o.data)
+ print()
+
diff --git a/pyproject.toml b/pyproject.toml
index ae662ed..51bfbae 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,6 +7,26 @@ license = "ISC"
packages = [
{include = "dns"}
]
+include = [
+ { path="LICENSE", format="sdist" },
+ { path="README.md", format="sdist" },
+ { path="examples/*.txt", format="sdist" },
+ { path="examples/*.py", format="sdist" },
+ { path="tests/*.txt", format="sdist" },
+ { path="tests/*.py", format="sdist" },
+ { path="tests/*.good", format="sdist" },
+ { path="tests/example", format="sdist" },
+ { path="tests/query", format="sdist" },
+ { path="tests/*.pickle", format="sdist" },
+ { path="tests/*.text", format="sdist" },
+ { path="tests/*.generic", format="sdist" },
+ { path="util/**", format="sdist" },
+ { path="setup.cfg", format="sdist" },
+]
+exclude = [
+ "**/.DS_Store",
+ "**/__pycache__/**",
+]
[tool.poetry.dependencies]
python = "^3.6"
@@ -40,4 +60,4 @@ curio = ['curio', 'sniffio']
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
-[tool.setuptools_scm] \ No newline at end of file
+[tool.setuptools_scm]
diff --git a/tests/test_async.py b/tests/test_async.py
index cad7e20..0782c7a 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -216,14 +216,6 @@ class AsyncTests(unittest.TestCase):
return await dns.asyncresolver.canonical_name(name)
self.assertEqual(self.async_run(run), cname)
- def testResolverBadScheme(self):
- res = dns.asyncresolver.Resolver(configure=False)
- res.nameservers = ['bogus://dns.google/dns-query']
- async def run():
- answer = await res.resolve('dns.google', 'A')
- def bad():
- self.async_run(run)
- self.assertRaises(dns.resolver.NoNameservers, bad)
def testZoneForName1(self):
async def run():
diff --git a/tests/test_dnssec.py b/tests/test_dnssec.py
index 6ea51dc..b018b86 100644
--- a/tests/test_dnssec.py
+++ b/tests/test_dnssec.py
@@ -499,13 +499,14 @@ class DNSSECMakeDSTestCase(unittest.TestCase):
def testInvalidDigestType(self): # type: () -> None
digest_type_errors = {
- 0: 'digest type 0 is reserved',
- 5: 'unknown digest type',
+ (dns.rdatatype.DS, 0): 'digest type 0 is reserved',
+ (dns.rdatatype.DS, 5): 'unknown digest type',
+ (dns.rdatatype.CDS, 5): 'unknown digest type',
}
- for digest_type, msg in digest_type_errors.items():
+ for (rdtype, digest_type), msg in digest_type_errors.items():
with self.assertRaises(dns.exception.SyntaxError) as cm:
dns.rdata.from_text(dns.rdataclass.IN,
- dns.rdatatype.DS,
+ rdtype,
f'18673 3 {digest_type} 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7')
self.assertEqual(msg, str(cm.exception))
@@ -526,6 +527,20 @@ class DNSSECMakeDSTestCase(unittest.TestCase):
self.assertEqual('digest length inconsistent with digest type', str(cm.exception))
+ def testInvalidDigestLengthCDS0(self): # type: () -> None
+ # Make sure the construction is working
+ dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CDS, f'0 0 0 00')
+
+ test_records = {
+ 'digest length inconsistent with digest type': ['0 0 0', '0 0 0 0000'],
+ 'Odd-length string': ['0 0 0 0', '0 0 0 000'],
+ }
+ for msg, records in test_records.items():
+ for record in records:
+ with self.assertRaises(dns.exception.SyntaxError) as cm:
+ dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CDS, record)
+ self.assertEqual(msg, str(cm.exception))
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_doh.py b/tests/test_doh.py
index 793a500..835e07d 100644
--- a/tests/test_doh.py
+++ b/tests/test_doh.py
@@ -139,12 +139,6 @@ class DNSOverHTTPSTestCase(unittest.TestCase):
self.assertTrue('8.8.8.8' in seen)
self.assertTrue('8.8.4.4' in seen)
- def test_resolver_bad_scheme(self):
- res = dns.resolver.Resolver(configure=False)
- res.nameservers = ['bogus://dns.google/dns-query']
- def bad():
- answer = res.resolve('dns.google', 'A')
- self.assertRaises(dns.resolver.NoNameservers, bad)
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_message.py b/tests/test_message.py
index 19738e6..ad30298 100644
--- a/tests/test_message.py
+++ b/tests/test_message.py
@@ -437,6 +437,16 @@ class MessageTestCase(unittest.TestCase):
self.assertEqual(q.payload, 4096)
self.assertEqual(q.options, ())
+ def test_setting_id(self):
+ q = dns.message.make_query('www.dnspython.org.', 'a', id=12345)
+ self.assertEqual(q.id, 12345)
+
+ def test_setting_flags(self):
+ q = dns.message.make_query('www.dnspython.org.', 'a',
+ flags=dns.flags.RD|dns.flags.CD)
+ self.assertEqual(q.flags, dns.flags.RD|dns.flags.CD)
+ self.assertEqual(q.flags, 0x0110)
+
def test_generic_message_class(self):
q1 = dns.message.Message(id=1)
q1.set_opcode(dns.opcode.NOTIFY)
@@ -558,7 +568,7 @@ www.example. IN CNAME
;AUTHORITY
example. 300 IN SOA . . 1 2 3 4 5
''')
- # passing is actuall not going into an infinite loop in this call
+ # passing is not going into an infinite loop in this call
result = r.resolve_chaining()
self.assertEqual(result.canonical_name,
dns.name.from_text('www.example.'))
@@ -680,6 +690,46 @@ flags QR
m = dns.message.from_wire(goodwire)
self.assertIsInstance(m.flags, dns.flags.Flag)
self.assertEqual(m.flags, dns.flags.Flag.RD)
+
+ def test_continue_on_error(self):
+ good_message = dns.message.from_text(
+"""id 1234
+opcode QUERY
+rcode NOERROR
+flags QR AA RD
+;QUESTION
+www.dnspython.org. IN SOA
+;ANSWER
+www.dnspython.org. 300 IN SOA . . 1 2 3 4 5
+www.dnspython.org. 300 IN A 1.2.3.4
+www.dnspython.org. 300 IN AAAA ::1
+""")
+ wire = good_message.to_wire()
+ # change ANCOUNT to 255
+ bad_wire = wire[:6] + b'\x00\xff' + wire[8:]
+ # change AAAA into rdata with rdlen 0
+ bad_wire = bad_wire[:-18] + b'\x00' * 2
+ # change SOA MINIMUM field to 0xffffffff (too large)
+ bad_wire = bad_wire.replace(b'\x00\x00\x00\x05', b'\xff' * 4)
+ m = dns.message.from_wire(bad_wire, continue_on_error=True)
+ self.assertEqual(len(m.errors), 3)
+ print(m.errors)
+ self.assertEqual(str(m.errors[0].exception), 'value too large')
+ self.assertEqual(str(m.errors[1].exception),
+ 'IPv6 addresses are 16 bytes long')
+ self.assertEqual(str(m.errors[2].exception),
+ 'DNS message is malformed.')
+ expected_message = dns.message.from_text(
+"""id 1234
+opcode QUERY
+rcode NOERROR
+flags QR AA RD
+;QUESTION
+www.dnspython.org. IN SOA
+;ANSWER
+www.dnspython.org. 300 IN A 1.2.3.4
+""")
+ self.assertEqual(m, expected_message)
if __name__ == '__main__':
diff --git a/tests/test_rdata.py b/tests/test_rdata.py
index 05ec6ca..f87ff56 100644
--- a/tests/test_rdata.py
+++ b/tests/test_rdata.py
@@ -696,7 +696,96 @@ class RdataTestCase(unittest.TestCase):
rr = dns.rdata.from_text('IN', 'DNSKEY', input_variation)
new_text = rr.to_text(chunksize=chunksize)
self.assertEqual(output, new_text)
+
+ def test_relative_vs_absolute_compare_unstrict(self):
+ try:
+ saved = dns.rdata._allow_relative_comparisons
+ dns.rdata._allow_relative_comparisons = True
+ r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.')
+ r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www')
+ self.assertFalse(r1 == r2)
+ self.assertTrue(r1 != r2)
+ self.assertFalse(r1 < r2)
+ self.assertFalse(r1 <= r2)
+ self.assertTrue(r1 > r2)
+ self.assertTrue(r1 >= r2)
+ self.assertTrue(r2 < r1)
+ self.assertTrue(r2 <= r1)
+ self.assertFalse(r2 > r1)
+ self.assertFalse(r2 >= r1)
+ finally:
+ dns.rdata._allow_relative_comparisons = saved
+
+ def test_relative_vs_absolute_compare_strict(self):
+ try:
+ saved = dns.rdata._allow_relative_comparisons
+ dns.rdata._allow_relative_comparisons = False
+ r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.')
+ r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www')
+ self.assertFalse(r1 == r2)
+ self.assertTrue(r1 != r2)
+ def bad1():
+ r1 < r2
+ def bad2():
+ r1 <= r2
+ def bad3():
+ r1 > r2
+ def bad4():
+ r1 >= r2
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad1)
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad2)
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad3)
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad4)
+ finally:
+ dns.rdata._allow_relative_comparisons = saved
+
+ def test_absolute_vs_absolute_compare(self):
+ r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.')
+ r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx.')
+ self.assertFalse(r1 == r2)
+ self.assertTrue(r1 != r2)
+ self.assertTrue(r1 < r2)
+ self.assertTrue(r1 <= r2)
+ self.assertFalse(r1 > r2)
+ self.assertFalse(r1 >= r2)
+ def test_relative_vs_relative_compare_unstrict(self):
+ try:
+ saved = dns.rdata._allow_relative_comparisons
+ dns.rdata._allow_relative_comparisons = True
+ r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www')
+ r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx')
+ self.assertFalse(r1 == r2)
+ self.assertTrue(r1 != r2)
+ self.assertTrue(r1 < r2)
+ self.assertTrue(r1 <= r2)
+ self.assertFalse(r1 > r2)
+ self.assertFalse(r1 >= r2)
+ finally:
+ dns.rdata._allow_relative_comparisons = saved
+
+ def test_relative_vs_relative_compare_strict(self):
+ try:
+ saved = dns.rdata._allow_relative_comparisons
+ dns.rdata._allow_relative_comparisons = False
+ r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www')
+ r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx')
+ self.assertFalse(r1 == r2)
+ self.assertTrue(r1 != r2)
+ def bad1():
+ r1 < r2
+ def bad2():
+ r1 <= r2
+ def bad3():
+ r1 > r2
+ def bad4():
+ r1 >= r2
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad1)
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad2)
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad3)
+ self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad4)
+ finally:
+ dns.rdata._allow_relative_comparisons = saved
class UtilTestCase(unittest.TestCase):
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index b2a47d2..ecd1bf2 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -700,6 +700,16 @@ class LiveResolverTests(unittest.TestCase):
cname = dns.name.from_text('dangling-target.dnspython.org')
self.assertEqual(dns.resolver.canonical_name(name), cname)
+ def testNameserverSetting(self):
+ res = dns.resolver.Resolver(configure=False)
+ ns = ['1.2.3.4', '::1', 'https://ns.example']
+ res.nameservers = ns[:]
+ self.assertEqual(res.nameservers, ns)
+ for ns in ['999.999.999.999', 'ns.example.', 'bogus://ns.example']:
+ with self.assertRaises(ValueError):
+ res.nameservers = [ns]
+
+
class PollingMonkeyPatchMixin(object):
def setUp(self):
self.__native_selector_class = dns.query._selector_class
diff --git a/tests/test_zonedigest.py b/tests/test_zonedigest.py
index f98e5f7..d94be24 100644
--- a/tests/test_zonedigest.py
+++ b/tests/test_zonedigest.py
@@ -176,3 +176,18 @@ class ZoneDigestTestCase(unittest.TestCase):
with self.assertRaises(dns.exception.SyntaxError):
dns.rdata.from_text('IN', 'ZONEMD', '100 1 0 ' + self.sha384_hash)
+ sorting_zone = textwrap.dedent('''
+ @ 86400 IN SOA ns1 admin 2018031900 (
+ 1800 900 604800 86400 )
+ 86400 IN NS ns1
+ 86400 IN NS ns2
+ 86400 IN RP n1.example. a.
+ 86400 IN RP n1. b.
+ ''')
+
+ def test_relative_zone_sorting(self):
+ z1 = dns.zone.from_text(self.sorting_zone, 'example.', relativize=True)
+ z2 = dns.zone.from_text(self.sorting_zone, 'example.', relativize=False)
+ zmd1 = z1.compute_digest(dns.zone.DigestHashAlgorithm.SHA384)
+ zmd2 = z2.compute_digest(dns.zone.DigestHashAlgorithm.SHA384)
+ self.assertEqual(zmd1, zmd2)