diff options
-rw-r--r-- | Makefile | 7 | ||||
-rw-r--r-- | dns/dnssec.py | 8 | ||||
-rw-r--r-- | dns/message.py | 238 | ||||
-rw-r--r-- | dns/query.py | 4 | ||||
-rw-r--r-- | dns/rdata.py | 74 | ||||
-rw-r--r-- | dns/rdtypes/ANY/CDS.py | 5 | ||||
-rw-r--r-- | dns/rdtypes/dsbase.py | 24 | ||||
-rw-r--r-- | dns/resolver.py | 15 | ||||
-rw-r--r-- | dns/zone.py | 9 | ||||
-rw-r--r-- | doc/examples.rst | 10 | ||||
-rw-r--r-- | doc/manual.rst | 1 | ||||
-rwxr-xr-x | examples/edns.py | 52 | ||||
-rw-r--r-- | pyproject.toml | 22 | ||||
-rw-r--r-- | tests/test_async.py | 8 | ||||
-rw-r--r-- | tests/test_dnssec.py | 23 | ||||
-rw-r--r-- | tests/test_doh.py | 6 | ||||
-rw-r--r-- | tests/test_message.py | 52 | ||||
-rw-r--r-- | tests/test_rdata.py | 89 | ||||
-rw-r--r-- | tests/test_resolver.py | 10 | ||||
-rw-r--r-- | tests/test_zonedigest.py | 15 |
20 files changed, 522 insertions, 150 deletions
@@ -69,6 +69,11 @@ pocov: poetry run coverage html --include 'dns*' poetry run coverage report --include 'dns*' -pokit: +oldpokit: po run python setup.py sdist --formats=zip bdist_wheel +pokit: + po build + +findjunk: + find dns -type f | egrep -v '.*\.py' | egrep -v 'py\.typed' diff --git a/dns/dnssec.py b/dns/dnssec.py index f09ecd6..6e9946f 100644 --- a/dns/dnssec.py +++ b/dns/dnssec.py @@ -404,13 +404,13 @@ def _validate_rrsig(rrset, rrsig, keys, origin=None, now=None): rrnamebuf = rrname.to_digestable() rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass, rrsig.original_ttl) - for rr in sorted(rdataset): + rdatas = [rdata.to_digestable(origin) for rdata in rdataset] + for rdata in sorted(rdatas): data += rrnamebuf data += rrfixed - rrdata = rr.to_digestable(origin) - rrlen = struct.pack('!H', len(rrdata)) + rrlen = struct.pack('!H', len(rdata)) data += rrlen - data += rrdata + data += rdata chosen_hash = _make_hash(rrsig.algorithm) diff --git a/dns/message.py b/dns/message.py index 6fa90ca..1e67a17 100644 --- a/dns/message.py +++ b/dns/message.py @@ -108,6 +108,12 @@ class MessageSection(dns.enum.IntEnum): return 3 +class MessageError: + def __init__(self, exception, offset): + self.exception = exception + self.offset = offset + + DEFAULT_EDNS_PAYLOAD = 1232 MAX_CHAIN = 16 @@ -132,6 +138,7 @@ class Message: self.origin = None self.tsig_ctx = None self.index = {} + self.errors = [] @property def question(self): @@ -873,11 +880,14 @@ class _WireReader: ignore_trailing: Ignore trailing junk at end of request? multi: Is this message part of a multi-message sequence? DNS dynamic updates. + continue_on_error: try to extract as much information as possible from + the message, accumulating MessageErrors in the *errors* attribute instead of + raising them. """ def __init__(self, wire, initialize_message, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, - keyring=None, multi=False): + keyring=None, multi=False, continue_on_error=False): self.parser = dns.wire.Parser(wire) self.message = None self.initialize_message = initialize_message @@ -886,6 +896,8 @@ class _WireReader: self.ignore_trailing = ignore_trailing self.keyring = keyring self.multi = multi + self.continue_on_error = continue_on_error + self.errors = [] def _get_question(self, section_number, qcount): """Read the next *qcount* records from the wire data and add them to @@ -902,11 +914,14 @@ class _WireReader: self.message.find_rrset(section, qname, rdclass, rdtype, create=True, force_unique=True) + def _add_error(self, e): + self.errors.append(MessageError(e, self.parser.current)) + def _get_section(self, section_number, count): """Read the next I{count} records from the wire data and add them to the specified section. - section: the section of the message to which to add records + section_number: the section of the message to which to add records count: the number of records to read """ @@ -929,55 +944,65 @@ class _WireReader: (rdclass, rdtype, deleting, empty) = \ self.message._parse_rr_header(section_number, name, rdclass, rdtype) - if empty: - if rdlen > 0: - raise dns.exception.FormError - rd = None - covers = dns.rdatatype.NONE - else: - with self.parser.restrict_to(rdlen): - rd = dns.rdata.from_wire_parser(rdclass, rdtype, - self.parser, - self.message.origin) - covers = rd.covers() - if self.message.xfr and rdtype == dns.rdatatype.SOA: - force_unique = True - if rdtype == dns.rdatatype.OPT: - self.message.opt = dns.rrset.from_rdata(name, ttl, rd) - elif rdtype == dns.rdatatype.TSIG: - if self.keyring is None: - raise UnknownTSIGKey('got signed message without keyring') - if isinstance(self.keyring, dict): - key = self.keyring.get(absolute_name) - if isinstance(key, bytes): - key = dns.tsig.Key(absolute_name, key, rd.algorithm) - elif callable(self.keyring): - key = self.keyring(self.message, absolute_name) + try: + rdata_start = self.parser.current + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE else: - key = self.keyring - if key is None: - raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyring = key - self.message.tsig_ctx = \ - dns.tsig.validate(self.parser.wire, - key, - absolute_name, - rd, - int(time.time()), - self.message.request_mac, - rr_start, - self.message.tsig_ctx, - self.multi) - self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) - else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) - if rd is not None: - if ttl > 0x7fffffff: - ttl = 0 - rrset.add(rd, ttl) + with self.parser.restrict_to(rdlen): + rd = dns.rdata.from_wire_parser(rdclass, rdtype, + self.parser, + self.message.origin) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True + if rdtype == dns.rdatatype.OPT: + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) + elif rdtype == dns.rdatatype.TSIG: + if self.keyring is None: + raise UnknownTSIGKey('got signed message without ' + 'keyring') + if isinstance(self.keyring, dict): + key = self.keyring.get(absolute_name) + if isinstance(key, bytes): + key = dns.tsig.Key(absolute_name, key, rd.algorithm) + elif callable(self.keyring): + key = self.keyring(self.message, absolute_name) + else: + key = self.keyring + if key is None: + raise UnknownTSIGKey("key '%s' unknown" % name) + self.message.keyring = key + self.message.tsig_ctx = \ + dns.tsig.validate(self.parser.wire, + key, + absolute_name, + rd, + int(time.time()), + self.message.request_mac, + rr_start, + self.message.tsig_ctx, + self.multi) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, + rd) + else: + rrset = self.message.find_rrset(section, name, + rdclass, rdtype, covers, + deleting, True, + force_unique) + if rd is not None: + if ttl > 0x7fffffff: + ttl = 0 + rrset.add(rd, ttl) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + self.parser.seek(rdata_start + rdlen) + else: + raise def read(self): """Read a wire format DNS message and build a dns.message.Message @@ -993,69 +1018,82 @@ class _WireReader: self.initialize_message(self.message) self.one_rr_per_rrset = \ self.message._get_one_rr_per_rrset(self.one_rr_per_rrset) - self._get_question(MessageSection.QUESTION, qcount) - if self.question_only: - return self.message - self._get_section(MessageSection.ANSWER, ancount) - self._get_section(MessageSection.AUTHORITY, aucount) - self._get_section(MessageSection.ADDITIONAL, adcount) - if not self.ignore_trailing and self.parser.remaining() != 0: - raise TrailingJunk - if self.multi and self.message.tsig_ctx and not self.message.had_tsig: - self.message.tsig_ctx.update(self.parser.wire) + try: + self._get_question(MessageSection.QUESTION, qcount) + if self.question_only: + return self.message + self._get_section(MessageSection.ANSWER, ancount) + self._get_section(MessageSection.AUTHORITY, aucount) + self._get_section(MessageSection.ADDITIONAL, adcount) + if not self.ignore_trailing and self.parser.remaining() != 0: + raise TrailingJunk + if self.multi and self.message.tsig_ctx and \ + not self.message.had_tsig: + self.message.tsig_ctx.update(self.parser.wire) + except Exception as e: + if self.continue_on_error: + self._add_error(e) + else: + raise return self.message def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, - ignore_trailing=False, raise_on_truncation=False): - """Convert a DNS wire format message into a message - object. + ignore_trailing=False, raise_on_truncation=False, + continue_on_error=False): + """Convert a DNS wire format message into a message object. - *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use - if the message is signed. + *keyring*, a ``dns.tsig.Key`` or ``dict``, the key or keyring to use if the + message is signed. - *request_mac*, a ``bytes``. If the message is a response to a - TSIG-signed request, *request_mac* should be set to the MAC of - that request. + *request_mac*, a ``bytes``. If the message is a response to a TSIG-signed + request, *request_mac* should be set to the MAC of that request. - *xfr*, a ``bool``, should be set to ``True`` if this message is part of - a zone transfer. + *xfr*, a ``bool``, should be set to ``True`` if this message is part of a + zone transfer. - *origin*, a ``dns.name.Name`` or ``None``. If the message is part - of a zone transfer, *origin* should be the origin name of the - zone. If not ``None``, names will be relativized to the origin. + *origin*, a ``dns.name.Name`` or ``None``. If the message is part of a zone + transfer, *origin* should be the origin name of the zone. If not ``None``, + names will be relativized to the origin. *tsig_ctx*, a ``dns.tsig.HMACTSig`` or ``dns.tsig.GSSTSig`` object, the ongoing TSIG context, used when validating zone transfers. - *multi*, a ``bool``, should be set to ``True`` if this message is - part of a multiple message sequence. + *multi*, a ``bool``, should be set to ``True`` if this message is part of a + multiple message sequence. - *question_only*, a ``bool``. If ``True``, read only up to - the end of the question section. + *question_only*, a ``bool``. If ``True``, read only up to the end of the + question section. - *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its - own RRset. + *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own + RRset. - *ignore_trailing*, a ``bool``. If ``True``, ignore trailing - junk at end of the message. + *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of + the message. - *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if - the TC bit is set. + *raise_on_truncation*, a ``bool``. If ``True``, raise an exception if the + TC bit is set. + + *continue_on_error*, a ``bool``. If ``True``, try to continue parsing even + if errors occur. Erroneous rdata will be ignored. Errors will be + accumulated as a list of MessageError objects in the message's ``errors`` + attribute. This option is recommended only for DNS analysis tools, or for + use in a server as part of an error handling path. The default is + ``False``. Raises ``dns.message.ShortHeader`` if the message is less than 12 octets long. - Raises ``dns.message.TrailingJunk`` if there were octets in the message - past the end of the proper DNS message, and *ignore_trailing* is ``False``. + Raises ``dns.message.TrailingJunk`` if there were octets in the message past + the end of the proper DNS message, and *ignore_trailing* is ``False``. - Raises ``dns.message.BadEDNS`` if an OPT record was in the - wrong section, or occurred more than once. + Raises ``dns.message.BadEDNS`` if an OPT record was in the wrong section, or + occurred more than once. - Raises ``dns.message.BadTSIG`` if a TSIG record was not the last - record of the additional data section. + Raises ``dns.message.BadTSIG`` if a TSIG record was not the last record of + the additional data section. Raises ``dns.message.Truncated`` if the TC flag is set and *raise_on_truncation* is ``True``. @@ -1070,7 +1108,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, message.tsig_ctx = tsig_ctx reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing, keyring, multi) + one_rr_per_rrset, ignore_trailing, keyring, multi, + continue_on_error) try: m = reader.read() except dns.exception.FormError: @@ -1083,6 +1122,8 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, # have to do this check here too. if m.flags & dns.flags.TC and raise_on_truncation: raise Truncated(message=m) + if continue_on_error: + m.errors = reader.errors return m @@ -1383,7 +1424,8 @@ def from_file(f, idna_codec=None, one_rr_per_rrset=False): def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, want_dnssec=False, ednsflags=None, payload=None, - request_payload=None, options=None, idna_codec=None): + request_payload=None, options=None, idna_codec=None, + id=None, flags=dns.flags.RD): """Make a query message. The query name, type, and class may all be specified either @@ -1400,7 +1442,9 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, is class IN. *use_edns*, an ``int``, ``bool`` or ``None``. The EDNS level to use; the - default is None (no EDNS). + default is ``None``. If ``None``, EDNS will be enabled only if other + parameters (*ednsflags*, *payload*, *request_payload*, or *options*) are + set. See the description of dns.message.Message.use_edns() for the possible values for use_edns and their meanings. @@ -1423,6 +1467,12 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, encoder/decoder. If ``None``, the default IDNA 2003 encoder/decoder is used. + *id*, an ``int`` or ``None``, the desired query id. The default is + ``None``, which generates a random query id. + + *flags*, an ``int``, the desired query flags. The default is + ``dns.flags.RD``. + Returns a ``dns.message.QueryMessage`` """ @@ -1430,8 +1480,8 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, qname = dns.name.from_text(qname, idna_codec=idna_codec) rdtype = dns.rdatatype.RdataType.make(rdtype) rdclass = dns.rdataclass.RdataClass.make(rdclass) - m = QueryMessage() - m.flags |= dns.flags.RD + m = QueryMessage(id=id) + m.flags = dns.flags.Flag(flags) m.find_rrset(m.question, qname, rdclass, rdtype, create=True, force_unique=True) # only pass keywords on to use_edns if they have been set to a diff --git a/dns/query.py b/dns/query.py index 934bf41..fee5d6a 100644 --- a/dns/query.py +++ b/dns/query.py @@ -548,7 +548,7 @@ def udp_with_fallback(q, where, timeout=None, port=53, source=None, if a socket is provided, it must be a nonblocking datagram socket, and the *source* and *source_port* are ignored for the UDP query. - *tcp_sock*, a ``socket.socket``, or ``None``, the socket to use for the + *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for the TCP query. If ``None``, the default, a socket is created. Note that if a socket is provided, it must be a nonblocking connected stream socket, and *where*, *source* and *source_port* are ignored for the TCP @@ -702,7 +702,7 @@ def tcp(q, where, timeout=None, port=53, source=None, source_port=0, *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the received message. - *sock*, a ``socket.socket``, or ``None``, the socket to use for the + *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the query. If ``None``, the default, a socket is created. Note that if a socket is provided, it must be a nonblocking connected stream socket, and *where*, *port*, *source* and *source_port* are ignored. diff --git a/dns/rdata.py b/dns/rdata.py index 0831c41..624063e 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -38,6 +38,22 @@ import dns.ttl _chunksize = 32 +# We currently allow comparisons for rdata with relative names for backwards +# compatibility, but in the future we will not, as these kinds of comparisons +# can lead to subtle bugs if code is not carefully written. +# +# This switch allows the future behavior to be turned on so code can be +# tested with it. +_allow_relative_comparisons = True + + +class NoRelativeRdataOrdering(dns.exception.DNSException): + """An attempt was made to do an ordered comparison of one or more + rdata with relative names. The only reliable way of sorting rdata + is to use non-relativized rdata. + + """ + def _wordbreak(data, chunksize=_chunksize, separator=b' '): """Break a binary string into chunks of chunksize characters separated by @@ -232,12 +248,42 @@ class Rdata: """Compare an rdata with another rdata of the same rdtype and rdclass. - Return < 0 if self < other in the DNSSEC ordering, 0 if self - == other, and > 0 if self > other. - + For rdata with only absolute names: + Return < 0 if self < other in the DNSSEC ordering, 0 if self + == other, and > 0 if self > other. + For rdata with at least one relative names: + The rdata sorts before any rdata with only absolute names. + When compared with another relative rdata, all names are + made absolute as if they were relative to the root, as the + proper origin is not available. While this creates a stable + ordering, it is NOT guaranteed to be the DNSSEC ordering. + In the future, all ordering comparisons for rdata with + relative names will be disallowed. """ - our = self.to_digestable(dns.name.root) - their = other.to_digestable(dns.name.root) + try: + our = self.to_digestable() + our_relative = False + except dns.name.NeedAbsoluteNameOrOrigin: + if _allow_relative_comparisons: + our = self.to_digestable(dns.name.root) + our_relative = True + try: + their = other.to_digestable() + their_relative = False + except dns.name.NeedAbsoluteNameOrOrigin: + if _allow_relative_comparisons: + their = other.to_digestable(dns.name.root) + their_relative = True + if _allow_relative_comparisons: + if our_relative != their_relative: + # For the purpose of comparison, all rdata with at least one + # relative name is less than an rdata with only absolute names. + if our_relative: + return -1 + else: + return 1 + elif our_relative or their_relative: + raise NoRelativeRdataOrdering if our == their: return 0 elif our > their: @@ -250,14 +296,28 @@ class Rdata: return False if self.rdclass != other.rdclass or self.rdtype != other.rdtype: return False - return self._cmp(other) == 0 + our_relative = False + their_relative = False + try: + our = self.to_digestable() + except dns.name.NeedAbsoluteNameOrOrigin: + our = self.to_digestable(dns.name.root) + our_relative = True + try: + their = other.to_digestable() + except dns.name.NeedAbsoluteNameOrOrigin: + their = other.to_digestable(dns.name.root) + their_relative = True + if our_relative != their_relative: + return False + return our == their def __ne__(self, other): if not isinstance(other, Rdata): return True if self.rdclass != other.rdclass or self.rdtype != other.rdtype: return True - return self._cmp(other) != 0 + return not self.__eq__(other) def __lt__(self, other): if not isinstance(other, Rdata) or \ diff --git a/dns/rdtypes/ANY/CDS.py b/dns/rdtypes/ANY/CDS.py index 39e3556..094de12 100644 --- a/dns/rdtypes/ANY/CDS.py +++ b/dns/rdtypes/ANY/CDS.py @@ -23,3 +23,8 @@ import dns.immutable class CDS(dns.rdtypes.dsbase.DSBase): """CDS record""" + + _digest_length_by_type = { + **dns.rdtypes.dsbase.DSBase._digest_length_by_type, + 0: 1, # delete, RFC 8078 Sec. 4 (including Errata ID 5049) + } diff --git a/dns/rdtypes/dsbase.py b/dns/rdtypes/dsbase.py index d125db2..403e937 100644 --- a/dns/rdtypes/dsbase.py +++ b/dns/rdtypes/dsbase.py @@ -24,15 +24,6 @@ import dns.rdata import dns.rdatatype -# Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml -_digest_length_by_type = { - 1: 20, # SHA-1, RFC 3658 Sec. 2.4 - 2: 32, # SHA-256, RFC 4509 Sec. 2.2 - 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1 - 4: 48, # SHA-384, RFC 6605 Sec. 2 -} - - @dns.immutable.immutable class DSBase(dns.rdata.Rdata): @@ -40,6 +31,14 @@ class DSBase(dns.rdata.Rdata): __slots__ = ['key_tag', 'algorithm', 'digest_type', 'digest'] + # Digest types registry: https://www.iana.org/assignments/ds-rr-types/ds-rr-types.xhtml + _digest_length_by_type = { + 1: 20, # SHA-1, RFC 3658 Sec. 2.4 + 2: 32, # SHA-256, RFC 4509 Sec. 2.2 + 3: 32, # GOST R 34.11-94, RFC 5933 Sec. 4 in conjunction with RFC 4490 Sec. 2.1 + 4: 48, # SHA-384, RFC 6605 Sec. 2 + } + def __init__(self, rdclass, rdtype, key_tag, algorithm, digest_type, digest): super().__init__(rdclass, rdtype) @@ -49,13 +48,12 @@ class DSBase(dns.rdata.Rdata): self.digest = self._as_bytes(digest) try: + if len(self.digest) != self._digest_length_by_type[self.digest_type]: + raise ValueError('digest length inconsistent with digest type') + except KeyError: if self.digest_type == 0: # reserved, RFC 3658 Sec. 2.4 raise ValueError('digest type 0 is reserved') - expected_length = _digest_length_by_type[self.digest_type] - except KeyError: raise ValueError('unknown digest type') - if len(self.digest) != expected_length: - raise ValueError('digest length inconsistent with digest type') def to_text(self, origin=None, relativize=True, **kw): kw = kw.copy() diff --git a/dns/resolver.py b/dns/resolver.py index 6a9974d..108dd52 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -807,7 +807,7 @@ class BaseResolver: f = stack.enter_context(open(f)) except OSError: # /etc/resolv.conf doesn't exist, can't be read, etc. - raise NoResolverConfiguration + raise NoResolverConfiguration(f'cannot open {f}') for l in f: if len(l) == 0 or l[0] == '#' or l[0] == ';': @@ -848,7 +848,7 @@ class BaseResolver: except (ValueError, IndexError): pass if len(self.nameservers) == 0: - raise NoResolverConfiguration + raise NoResolverConfiguration('no nameservers') def _determine_split_char(self, entry): # @@ -1120,6 +1120,14 @@ class BaseResolver: ``list``. """ if isinstance(nameservers, list): + for nameserver in nameservers: + if not dns.inet.is_address(nameserver): + try: + if urlparse(nameserver).scheme != 'https': + raise NotImplementedError + except Exception: + raise ValueError(f'nameserver {nameserver} is not an ' + 'IP address or valid https URL') self._nameservers = nameservers else: raise ValueError('nameservers must be a list' @@ -1219,9 +1227,6 @@ class Resolver(BaseResolver): source_port=source_port, raise_on_truncation=True) else: - protocol = urlparse(nameserver).scheme - if protocol != 'https': - raise NotImplementedError response = dns.query.https(request, nameserver, timeout=timeout) except Exception as ex: diff --git a/dns/zone.py b/dns/zone.py index 9c3204b..d154928 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -733,10 +733,11 @@ class Zone(dns.transaction.TransactionManager): continue rrfixed = struct.pack('!HHI', rdataset.rdtype, rdataset.rdclass, rdataset.ttl) - for rr in sorted(rdataset): - rrdata = rr.to_digestable(self.origin) - rrlen = struct.pack('!H', len(rrdata)) - hasher.update(rrnamebuf + rrfixed + rrlen + rrdata) + rdatas = [rdata.to_digestable(self.origin) + for rdata in rdataset] + for rdata in sorted(rdatas): + rrlen = struct.pack('!H', len(rdata)) + hasher.update(rrnamebuf + rrfixed + rrlen + rdata) return hasher.digest() def compute_digest(self, hash_algorithm, scheme=DigestScheme.SIMPLE): diff --git a/doc/examples.rst b/doc/examples.rst new file mode 100644 index 0000000..4811b48 --- /dev/null +++ b/doc/examples.rst @@ -0,0 +1,10 @@ +.. examples: + +Examples +-------- + +The dnspython source comes with example programs that show how +to use dnspython in practice. You can clone the dnspython source +from GitHub: + git clone https://github.com/rthalley/dnspython.git +The example prgrams are in the ``examples/`` directory. diff --git a/doc/manual.rst b/doc/manual.rst index 0ebbd03..d5ed014 100644 --- a/doc/manual.rst +++ b/doc/manual.rst @@ -17,3 +17,4 @@ Dnspython Manual utilities typing threads + examples diff --git a/examples/edns.py b/examples/edns.py new file mode 100755 index 0000000..a130f85 --- /dev/null +++ b/examples/edns.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 + +import dns.edns +import dns.message +import dns.query +import dns.resolver + +n = '.' +t = dns.rdatatype.SOA +l = '199.7.83.42' # Address of l.root-servers.net +i = '149.20.1.73' # Address of ns1.isc.org, for COOKIEs + +q_list = [] + +# A query without EDNS0 +q_list.append((l, dns.message.make_query(n, t))) + +# The same query, but with EDNS0 turned on with no options +q_list.append((l,dns.message.make_query(n, t, use_edns=0))) + +# Use use_edns() to specify EDNS0 options, such as buffer size +this_q = dns.message.make_query(n, t) +this_q.use_edns(0, payload=2000) +q_list.append((l, this_q)) + +# With an NSID option +# use_edns=0 is not needed if options are specified) +q_list.append((l, dns.message.make_query(n, t,\ + options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')]))) + +# With an NSID option, but with use_edns() to specify the options +this_q = dns.message.make_query(n, t) +this_q.use_edns(0, options=[dns.edns.GenericOption(dns.edns.OptionType.NSID, b'')]) +q_list.append((l, this_q)) + +# With a COOKIE +q_list.append((i, dns.message.make_query(n, t,\ + options=[dns.edns.GenericOption(dns.edns.OptionType.COOKIE, b'0xfe11ac99bebe3322')]))) + +# With an ECS option using dns.edns.ECSOption to form the option +q_list.append((l, dns.message.make_query(n, t,\ + options=[dns.edns.ECSOption('192.168.0.0', 20)]))) + +for (addr, q) in q_list: + r = dns.query.udp(q, addr) + if not r.options: + print('No EDNS options returned') + else: + for o in r.options: + print(o.otype.value, o.data) + print() + diff --git a/pyproject.toml b/pyproject.toml index ae662ed..51bfbae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,26 @@ license = "ISC" packages = [ {include = "dns"} ] +include = [ + { path="LICENSE", format="sdist" }, + { path="README.md", format="sdist" }, + { path="examples/*.txt", format="sdist" }, + { path="examples/*.py", format="sdist" }, + { path="tests/*.txt", format="sdist" }, + { path="tests/*.py", format="sdist" }, + { path="tests/*.good", format="sdist" }, + { path="tests/example", format="sdist" }, + { path="tests/query", format="sdist" }, + { path="tests/*.pickle", format="sdist" }, + { path="tests/*.text", format="sdist" }, + { path="tests/*.generic", format="sdist" }, + { path="util/**", format="sdist" }, + { path="setup.cfg", format="sdist" }, +] +exclude = [ + "**/.DS_Store", + "**/__pycache__/**", +] [tool.poetry.dependencies] python = "^3.6" @@ -40,4 +60,4 @@ curio = ['curio', 'sniffio'] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -[tool.setuptools_scm]
\ No newline at end of file +[tool.setuptools_scm] diff --git a/tests/test_async.py b/tests/test_async.py index cad7e20..0782c7a 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -216,14 +216,6 @@ class AsyncTests(unittest.TestCase): return await dns.asyncresolver.canonical_name(name) self.assertEqual(self.async_run(run), cname) - def testResolverBadScheme(self): - res = dns.asyncresolver.Resolver(configure=False) - res.nameservers = ['bogus://dns.google/dns-query'] - async def run(): - answer = await res.resolve('dns.google', 'A') - def bad(): - self.async_run(run) - self.assertRaises(dns.resolver.NoNameservers, bad) def testZoneForName1(self): async def run(): diff --git a/tests/test_dnssec.py b/tests/test_dnssec.py index 6ea51dc..b018b86 100644 --- a/tests/test_dnssec.py +++ b/tests/test_dnssec.py @@ -499,13 +499,14 @@ class DNSSECMakeDSTestCase(unittest.TestCase): def testInvalidDigestType(self): # type: () -> None digest_type_errors = { - 0: 'digest type 0 is reserved', - 5: 'unknown digest type', + (dns.rdatatype.DS, 0): 'digest type 0 is reserved', + (dns.rdatatype.DS, 5): 'unknown digest type', + (dns.rdatatype.CDS, 5): 'unknown digest type', } - for digest_type, msg in digest_type_errors.items(): + for (rdtype, digest_type), msg in digest_type_errors.items(): with self.assertRaises(dns.exception.SyntaxError) as cm: dns.rdata.from_text(dns.rdataclass.IN, - dns.rdatatype.DS, + rdtype, f'18673 3 {digest_type} 71b71d4f3e11bbd71b4eff12cde69f7f9215bbe7') self.assertEqual(msg, str(cm.exception)) @@ -526,6 +527,20 @@ class DNSSECMakeDSTestCase(unittest.TestCase): self.assertEqual('digest length inconsistent with digest type', str(cm.exception)) + def testInvalidDigestLengthCDS0(self): # type: () -> None + # Make sure the construction is working + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CDS, f'0 0 0 00') + + test_records = { + 'digest length inconsistent with digest type': ['0 0 0', '0 0 0 0000'], + 'Odd-length string': ['0 0 0 0', '0 0 0 000'], + } + for msg, records in test_records.items(): + for record in records: + with self.assertRaises(dns.exception.SyntaxError) as cm: + dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.CDS, record) + self.assertEqual(msg, str(cm.exception)) + if __name__ == '__main__': unittest.main() diff --git a/tests/test_doh.py b/tests/test_doh.py index 793a500..835e07d 100644 --- a/tests/test_doh.py +++ b/tests/test_doh.py @@ -139,12 +139,6 @@ class DNSOverHTTPSTestCase(unittest.TestCase): self.assertTrue('8.8.8.8' in seen) self.assertTrue('8.8.4.4' in seen) - def test_resolver_bad_scheme(self): - res = dns.resolver.Resolver(configure=False) - res.nameservers = ['bogus://dns.google/dns-query'] - def bad(): - answer = res.resolve('dns.google', 'A') - self.assertRaises(dns.resolver.NoNameservers, bad) if __name__ == '__main__': unittest.main() diff --git a/tests/test_message.py b/tests/test_message.py index 19738e6..ad30298 100644 --- a/tests/test_message.py +++ b/tests/test_message.py @@ -437,6 +437,16 @@ class MessageTestCase(unittest.TestCase): self.assertEqual(q.payload, 4096) self.assertEqual(q.options, ()) + def test_setting_id(self): + q = dns.message.make_query('www.dnspython.org.', 'a', id=12345) + self.assertEqual(q.id, 12345) + + def test_setting_flags(self): + q = dns.message.make_query('www.dnspython.org.', 'a', + flags=dns.flags.RD|dns.flags.CD) + self.assertEqual(q.flags, dns.flags.RD|dns.flags.CD) + self.assertEqual(q.flags, 0x0110) + def test_generic_message_class(self): q1 = dns.message.Message(id=1) q1.set_opcode(dns.opcode.NOTIFY) @@ -558,7 +568,7 @@ www.example. IN CNAME ;AUTHORITY example. 300 IN SOA . . 1 2 3 4 5 ''') - # passing is actuall not going into an infinite loop in this call + # passing is not going into an infinite loop in this call result = r.resolve_chaining() self.assertEqual(result.canonical_name, dns.name.from_text('www.example.')) @@ -680,6 +690,46 @@ flags QR m = dns.message.from_wire(goodwire) self.assertIsInstance(m.flags, dns.flags.Flag) self.assertEqual(m.flags, dns.flags.Flag.RD) + + def test_continue_on_error(self): + good_message = dns.message.from_text( +"""id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD +;QUESTION +www.dnspython.org. IN SOA +;ANSWER +www.dnspython.org. 300 IN SOA . . 1 2 3 4 5 +www.dnspython.org. 300 IN A 1.2.3.4 +www.dnspython.org. 300 IN AAAA ::1 +""") + wire = good_message.to_wire() + # change ANCOUNT to 255 + bad_wire = wire[:6] + b'\x00\xff' + wire[8:] + # change AAAA into rdata with rdlen 0 + bad_wire = bad_wire[:-18] + b'\x00' * 2 + # change SOA MINIMUM field to 0xffffffff (too large) + bad_wire = bad_wire.replace(b'\x00\x00\x00\x05', b'\xff' * 4) + m = dns.message.from_wire(bad_wire, continue_on_error=True) + self.assertEqual(len(m.errors), 3) + print(m.errors) + self.assertEqual(str(m.errors[0].exception), 'value too large') + self.assertEqual(str(m.errors[1].exception), + 'IPv6 addresses are 16 bytes long') + self.assertEqual(str(m.errors[2].exception), + 'DNS message is malformed.') + expected_message = dns.message.from_text( +"""id 1234 +opcode QUERY +rcode NOERROR +flags QR AA RD +;QUESTION +www.dnspython.org. IN SOA +;ANSWER +www.dnspython.org. 300 IN A 1.2.3.4 +""") + self.assertEqual(m, expected_message) if __name__ == '__main__': diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 05ec6ca..f87ff56 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -696,7 +696,96 @@ class RdataTestCase(unittest.TestCase): rr = dns.rdata.from_text('IN', 'DNSKEY', input_variation) new_text = rr.to_text(chunksize=chunksize) self.assertEqual(output, new_text) + + def test_relative_vs_absolute_compare_unstrict(self): + try: + saved = dns.rdata._allow_relative_comparisons + dns.rdata._allow_relative_comparisons = True + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.') + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') + self.assertFalse(r1 == r2) + self.assertTrue(r1 != r2) + self.assertFalse(r1 < r2) + self.assertFalse(r1 <= r2) + self.assertTrue(r1 > r2) + self.assertTrue(r1 >= r2) + self.assertTrue(r2 < r1) + self.assertTrue(r2 <= r1) + self.assertFalse(r2 > r1) + self.assertFalse(r2 >= r1) + finally: + dns.rdata._allow_relative_comparisons = saved + + def test_relative_vs_absolute_compare_strict(self): + try: + saved = dns.rdata._allow_relative_comparisons + dns.rdata._allow_relative_comparisons = False + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.') + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') + self.assertFalse(r1 == r2) + self.assertTrue(r1 != r2) + def bad1(): + r1 < r2 + def bad2(): + r1 <= r2 + def bad3(): + r1 > r2 + def bad4(): + r1 >= r2 + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad1) + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad2) + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad3) + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad4) + finally: + dns.rdata._allow_relative_comparisons = saved + + def test_absolute_vs_absolute_compare(self): + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www.') + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx.') + self.assertFalse(r1 == r2) + self.assertTrue(r1 != r2) + self.assertTrue(r1 < r2) + self.assertTrue(r1 <= r2) + self.assertFalse(r1 > r2) + self.assertFalse(r1 >= r2) + def test_relative_vs_relative_compare_unstrict(self): + try: + saved = dns.rdata._allow_relative_comparisons + dns.rdata._allow_relative_comparisons = True + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx') + self.assertFalse(r1 == r2) + self.assertTrue(r1 != r2) + self.assertTrue(r1 < r2) + self.assertTrue(r1 <= r2) + self.assertFalse(r1 > r2) + self.assertFalse(r1 >= r2) + finally: + dns.rdata._allow_relative_comparisons = saved + + def test_relative_vs_relative_compare_strict(self): + try: + saved = dns.rdata._allow_relative_comparisons + dns.rdata._allow_relative_comparisons = False + r1 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'www') + r2 = dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.NS, 'xxx') + self.assertFalse(r1 == r2) + self.assertTrue(r1 != r2) + def bad1(): + r1 < r2 + def bad2(): + r1 <= r2 + def bad3(): + r1 > r2 + def bad4(): + r1 >= r2 + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad1) + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad2) + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad3) + self.assertRaises(dns.rdata.NoRelativeRdataOrdering, bad4) + finally: + dns.rdata._allow_relative_comparisons = saved class UtilTestCase(unittest.TestCase): diff --git a/tests/test_resolver.py b/tests/test_resolver.py index b2a47d2..ecd1bf2 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -700,6 +700,16 @@ class LiveResolverTests(unittest.TestCase): cname = dns.name.from_text('dangling-target.dnspython.org') self.assertEqual(dns.resolver.canonical_name(name), cname) + def testNameserverSetting(self): + res = dns.resolver.Resolver(configure=False) + ns = ['1.2.3.4', '::1', 'https://ns.example'] + res.nameservers = ns[:] + self.assertEqual(res.nameservers, ns) + for ns in ['999.999.999.999', 'ns.example.', 'bogus://ns.example']: + with self.assertRaises(ValueError): + res.nameservers = [ns] + + class PollingMonkeyPatchMixin(object): def setUp(self): self.__native_selector_class = dns.query._selector_class diff --git a/tests/test_zonedigest.py b/tests/test_zonedigest.py index f98e5f7..d94be24 100644 --- a/tests/test_zonedigest.py +++ b/tests/test_zonedigest.py @@ -176,3 +176,18 @@ class ZoneDigestTestCase(unittest.TestCase): with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text('IN', 'ZONEMD', '100 1 0 ' + self.sha384_hash) + sorting_zone = textwrap.dedent(''' + @ 86400 IN SOA ns1 admin 2018031900 ( + 1800 900 604800 86400 ) + 86400 IN NS ns1 + 86400 IN NS ns2 + 86400 IN RP n1.example. a. + 86400 IN RP n1. b. + ''') + + def test_relative_zone_sorting(self): + z1 = dns.zone.from_text(self.sorting_zone, 'example.', relativize=True) + z2 = dns.zone.from_text(self.sorting_zone, 'example.', relativize=False) + zmd1 = z1.compute_digest(dns.zone.DigestHashAlgorithm.SHA384) + zmd2 = z2.compute_digest(dns.zone.DigestHashAlgorithm.SHA384) + self.assertEqual(zmd1, zmd2) |