diff options
author | Bob Halley <halley@dnspython.org> | 2020-06-30 11:01:16 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-06-30 11:01:16 -0700 |
commit | 58d7ca7207b9a46b1238fce8bfa4325a400e4ac5 (patch) | |
tree | 9c540090184a5fd27ee06ba51dda5bc2a603bec3 | |
parent | 1e16e7fee759893b17abb9257bdbe96e7edc65f8 (diff) | |
parent | c8ec6ca47397420cb128db9aa4951245ac4ef07f (diff) | |
download | dnspython-58d7ca7207b9a46b1238fce8bfa4325a400e4ac5.tar.gz |
Merge pull request #525 from bwelling/tsig
Adds support for a TSIG record class.
-rw-r--r-- | dns/message.py | 205 | ||||
-rw-r--r-- | dns/message.pyi | 18 | ||||
-rw-r--r-- | dns/query.py | 5 | ||||
-rw-r--r-- | dns/rdtypes/ANY/TSIG.py | 112 | ||||
-rw-r--r-- | dns/rdtypes/ANY/__init__.py | 2 | ||||
-rw-r--r-- | dns/renderer.py | 42 | ||||
-rw-r--r-- | dns/tsig.py | 108 | ||||
-rw-r--r-- | tests/test_renderer.py | 48 | ||||
-rw-r--r-- | tests/test_tsig.py | 8 |
9 files changed, 331 insertions, 217 deletions
diff --git a/dns/message.py b/dns/message.py index 597b329..00359ef 100644 --- a/dns/message.py +++ b/dns/message.py @@ -38,6 +38,7 @@ import dns.renderer import dns.tsig import dns.wiredata import dns.rdtypes.ANY.OPT +import dns.rdtypes.ANY.TSIG class ShortHeader(dns.exception.FormError): @@ -109,20 +110,11 @@ class Message: self.opt = None self.request_payload = 0 self.keyring = None - self.keyname = None - self.keyalgorithm = dns.tsig.default_algorithm + self.tsig = None self.request_mac = b'' - self.other_data = b'' - self.tsig_error = 0 - self.fudge = 300 - self.original_id = self.id - self.mac = b'' self.xfr = False self.origin = None self.tsig_ctx = None - self.had_tsig = False - self.multi = False - self.first = True self.index = {} @property @@ -443,22 +435,31 @@ class Message: for rrset in self.additional: r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw) r.write_header() - if self.keyname is not None: + if self.tsig is not None: + (new_tsig, ctx) = dns.tsig.sign(r.get_wire(), + self.tsig.name, + self.tsig[0], + self.keyring[self.tsig.name], + int(time.time()), + self.request_mac, + tsig_ctx, + multi) + self.tsig.clear() + self.tsig.add(new_tsig) + r.add_rrset(dns.renderer.ADDITIONAL, self.tsig) + r.write_header() if multi: - ctx = r.add_multi_tsig(tsig_ctx, - self.keyname, self.keyring[self.keyname], - self.fudge, self.original_id, - self.tsig_error, self.other_data, - self.request_mac, self.keyalgorithm) self.tsig_ctx = ctx - else: - r.add_tsig(self.keyname, self.keyring[self.keyname], - self.fudge, self.original_id, self.tsig_error, - self.other_data, self.request_mac, - self.keyalgorithm) - self.mac = r.mac return r.get_wire() + @staticmethod + def _make_tsig(keyname, algorithm, time_signed, fudge, mac, original_id, + error, other): + tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, + algorithm, time_signed, fudge, mac, + original_id, error, other) + return dns.rrset.from_rdata(keyname, 0, tsig) + def use_tsig(self, keyring, keyname=None, fudge=300, original_id=None, tsig_error=0, other_data=b'', algorithm=dns.tsig.default_algorithm): @@ -492,19 +493,45 @@ class Message: self.keyring = keyring if keyname is None: - self.keyname = list(self.keyring.keys())[0] - else: - if isinstance(keyname, str): - keyname = dns.name.from_text(keyname) - self.keyname = keyname - self.keyalgorithm = algorithm - self.fudge = fudge + keyname = list(self.keyring.keys())[0] + elif isinstance(keyname, str): + keyname = dns.name.from_text(keyname) if original_id is None: - self.original_id = self.id + original_id = self.id + self.tsig = self._make_tsig(keyname, algorithm, 0, fudge, b'', + original_id, tsig_error, other_data) + + @property + def keyname(self): + if self.tsig: + return self.tsig.name + else: + return None + + @property + def keyalgorithm(self): + if self.tsig: + return self.tsig[0].algorithm + else: + return None + + @property + def mac(self): + if self.tsig: + return self.tsig[0].mac else: - self.original_id = original_id - self.tsig_error = tsig_error - self.other_data = other_data + return None + + @property + def tsig_error(self): + if self.tsig: + return self.tsig[0].error + else: + return None + + @property + def had_tsig(self): + return bool(self.tsig) @staticmethod def _make_opt(flags=0, payload=1280, options=None): @@ -649,11 +676,17 @@ class Message: raise dns.exception.FormError return (rdclass, rdtype, None, False) - def _parse_special_rr_header(self, section, name, rdclass, rdtype): + def _parse_special_rr_header(self, section, count, position, + name, rdclass, rdtype): if rdtype == dns.rdatatype.OPT: if section != MessageSection.ADDITIONAL or self.opt or \ name != dns.name.root: raise BadEDNS + elif rdtype == dns.rdatatype.TSIG: + if section != MessageSection.ADDITIONAL or \ + rdclass != dns.rdatatype.ANY or \ + position != count - 1: + raise BadTSIG return (rdclass, rdtype, None, False) @@ -691,11 +724,12 @@ class _WireReader: question_only: Are we only reading the question? one_rr_per_rrset: Put each RR into its own RRset? ignore_trailing: Ignore trailing junk at end of request? + multi: Is this message part of a multi-message sequence? DNS dynamic updates. """ def __init__(self, wire, initialize_message, question_only=False, - one_rr_per_rrset=False, ignore_trailing=False): + one_rr_per_rrset=False, ignore_trailing=False, multi=False): self.wire = dns.wiredata.maybe_wrap(wire) self.message = None self.current = 0 @@ -703,6 +737,7 @@ class _WireReader: self.question_only = question_only self.one_rr_per_rrset = one_rr_per_rrset self.ignore_trailing = ignore_trailing + self.multi = multi def _get_question(self, section_number, qcount): """Read the next *qcount* records from the wire data and add them to @@ -746,65 +781,55 @@ class _WireReader: struct.unpack('!HHIH', self.wire[self.current:self.current + 10]) self.current += 10 - if rdtype == dns.rdatatype.TSIG: - if not (section is self.message.additional and - i == (count - 1)): - raise BadTSIG + if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG): + (rdclass, rdtype, deleting, empty) = \ + self.message._parse_special_rr_header(section_number, + count, i, name, + rdclass, rdtype) + else: + (rdclass, rdtype, deleting, empty) = \ + self.message._parse_rr_header(section_number, + name, rdclass, rdtype) + if empty: + if rdlen > 0: + raise dns.exception.FormError + rd = None + covers = dns.rdatatype.NONE + else: + rd = dns.rdata.from_wire(rdclass, rdtype, + self.wire, self.current, rdlen, + self.message.origin) + covers = rd.covers() + if self.message.xfr and rdtype == dns.rdatatype.SOA: + force_unique = True + if rdtype == dns.rdatatype.OPT: + self.message.opt = dns.rrset.from_rdata(name, ttl, rd) + elif rdtype == dns.rdatatype.TSIG: if self.message.keyring is None: raise UnknownTSIGKey('got signed message without keyring') secret = self.message.keyring.get(absolute_name) if secret is None: raise UnknownTSIGKey("key '%s' unknown" % name) - self.message.keyname = absolute_name - (self.message.keyalgorithm, self.message.mac) = \ - dns.tsig.get_algorithm_and_mac(self.wire, self.current, - rdlen) self.message.tsig_ctx = \ dns.tsig.validate(self.wire, absolute_name, + rd, secret, int(time.time()), self.message.request_mac, rr_start, - self.current, - rdlen, self.message.tsig_ctx, - self.message.multi, - self.message.first) - self.message.had_tsig = True + self.multi) + self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd) else: - if rdtype == dns.rdatatype.OPT: - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_special_rr_header(section_number, - name, - rdclass, rdtype) - else: - (rdclass, rdtype, deleting, empty) = \ - self.message._parse_rr_header(section_number, - name, rdclass, rdtype) - if empty: - if rdlen > 0: - raise dns.exception.FormError - rd = None - covers = dns.rdatatype.NONE - else: - rd = dns.rdata.from_wire(rdclass, rdtype, - self.wire, self.current, rdlen, - self.message.origin) - covers = rd.covers() - if self.message.xfr and rdtype == dns.rdatatype.SOA: - force_unique = True - if rdtype == dns.rdatatype.OPT: - self.message.opt = dns.rrset.from_rdata(name, ttl, rd) - else: - rrset = self.message.find_rrset(section, name, - rdclass, rdtype, covers, - deleting, True, - force_unique) - if rd is not None: - if ttl > 0x7fffffff: - ttl = 0 - rrset.add(rd, ttl) + rrset = self.message.find_rrset(section, name, + rdclass, rdtype, covers, + deleting, True, + force_unique) + if rd is not None: + if ttl > 0x7fffffff: + ttl = 0 + rrset.add(rd, ttl) self.current += rdlen def read(self): @@ -831,14 +856,13 @@ class _WireReader: self._get_section(MessageSection.ADDITIONAL, adcount) if not self.ignore_trailing and self.current != l: raise TrailingJunk - if self.message.multi and self.message.tsig_ctx and \ - not self.message.had_tsig: + if self.multi and self.message.tsig_ctx and not self.message.had_tsig: self.message.tsig_ctx.update(self.wire) return self.message def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, - tsig_ctx=None, multi=False, first=True, + tsig_ctx=None, multi=False, question_only=False, one_rr_per_rrset=False, ignore_trailing=False, raise_on_truncation=False): """Convert a DNS wire format message into a message @@ -863,9 +887,6 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, *multi*, a ``bool``, should be set to ``True`` if this message is part of a multiple message sequence. - *first*, a ``bool``, should be set to ``True`` if this message is - stand-alone, or the first message in a multi-message sequence. - *question_only*, a ``bool``. If ``True``, read only up to the end of the question section. @@ -902,11 +923,9 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None, message.xfr = xfr message.origin = origin message.tsig_ctx = tsig_ctx - message.multi = multi - message.first = first reader = _WireReader(wire, initialize_message, question_only, - one_rr_per_rrset, ignore_trailing) + one_rr_per_rrset, ignore_trailing, multi) try: m = reader.read() except dns.exception.FormError: @@ -1291,7 +1310,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None, def make_response(query, recursion_available=False, our_payload=8192, - fudge=300): + fudge=300, tsig_error=0): """Make a message which is a response for the specified query. The message returned is really a response skeleton; it has all of the infrastructure required of a response, but none of the @@ -1310,6 +1329,8 @@ def make_response(query, recursion_available=False, our_payload=8192, *fudge*, an ``int``, the TSIG time fudge. + *tsig_error*, an ``int``, the TSIG error. + Returns a ``dns.message.Message`` object whose specific class is appropriate for the query. For example, if query is a ``dns.update.UpdateMessage``, response will be too. @@ -1327,7 +1348,7 @@ def make_response(query, recursion_available=False, our_payload=8192, if query.edns >= 0: response.use_edns(0, 0, our_payload, query.payload) if query.had_tsig: - response.use_tsig(query.keyring, query.keyname, fudge, None, 0, b'', - query.keyalgorithm) + response.use_tsig(query.keyring, query.keyname, fudge, None, + tsig_error, b'', query.keyalgorithm) response.request_mac = query.mac return response diff --git a/dns/message.pyi b/dns/message.pyi index 76af040..8829db3 100644 --- a/dns/message.pyi +++ b/dns/message.pyi @@ -16,26 +16,14 @@ class Message: self.answer : List[rrset.RRset] = [] self.authority : List[rrset.RRset] = [] self.additional : List[rrset.RRset] = [] - self.edns = -1 - self.ednsflags = 0 - self.payload = 0 - self.options : List[edns.Option] = [] + self.opt : rrset.RRset = None self.request_payload = 0 self.keyring = None - self.keyname = None - self.keyalgorithm = tsig.default_algorithm + self.tsig : rrset.RRset = None self.request_mac = b'' - self.other_data = b'' - self.tsig_error = 0 - self.fudge = 300 - self.original_id = self.id - self.mac = b'' self.xfr = False self.origin = None self.tsig_ctx = None - self.had_tsig = False - self.multi = False - self.first = True self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {} def is_response(self, other : Message) -> bool: @@ -45,7 +33,7 @@ def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message: ... def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None, - tsig_ctx : Optional[hmac.HMAC] = None, multi=False, first=True, + tsig_ctx : Optional[hmac.HMAC] = None, multi=False, question_only=False, one_rr_per_rrset=False, ignore_trailing=False) -> Message: ... diff --git a/dns/query.py b/dns/query.py index ae4258a..3404b91 100644 --- a/dns/query.py +++ b/dns/query.py @@ -920,7 +920,6 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, origin = None oname = zone tsig_ctx = None - first = True while not done: (_, mexpiration) = _compute_times(timeout) if mexpiration is None or \ @@ -937,13 +936,11 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, r = dns.message.from_wire(wire, keyring=q.keyring, request_mac=q.mac, xfr=True, origin=origin, tsig_ctx=tsig_ctx, - multi=True, first=first, - one_rr_per_rrset=is_ixfr) + multi=True, one_rr_per_rrset=is_ixfr) rcode = r.rcode() if rcode != dns.rcode.NOERROR: raise TransferError(rcode) tsig_ctx = r.tsig_ctx - first = False answer_index = 0 if soa_rrset is None: if not r.answer or r.answer[0].name != oname: diff --git a/dns/rdtypes/ANY/TSIG.py b/dns/rdtypes/ANY/TSIG.py new file mode 100644 index 0000000..002c2db --- /dev/null +++ b/dns/rdtypes/ANY/TSIG.py @@ -0,0 +1,112 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2001-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import struct + +import dns.exception +import dns.rdata + + +class TSIG(dns.rdata.Rdata): + + """TSIG record""" + + __slots__ = ['algorithm', 'time_signed', 'fudge', 'mac', + 'original_id', 'error', 'other'] + + def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac, + original_id, error, other): + """Initialize a TSIG rdata. + + *rdclass*, an ``int`` is the rdataclass of the Rdata. + + *rdtype*, an ``int`` is the rdatatype of the Rdata. + + *algorithm*, a ``dns.name.Name``. + + *time_signed*, an ``int``. + + *fudge*, an ``int`. + + *mac*, a ``bytes`` + + *original_id*, an ``int`` + + *error*, an ``int`` + + *other*, a ``bytes`` + """ + + super().__init__(rdclass, rdtype) + object.__setattr__(self, 'algorithm', algorithm) + object.__setattr__(self, 'time_signed', time_signed) + object.__setattr__(self, 'fudge', fudge) + object.__setattr__(self, 'mac', dns.rdata._constify(mac)) + object.__setattr__(self, 'original_id', original_id) + object.__setattr__(self, 'error', error) + object.__setattr__(self, 'other', dns.rdata._constify(other)) + + def to_text(self, origin=None, relativize=True, **kw): + algorithm = self.algorithm.choose_relativity(origin, relativize) + return f"{algorithm} {self.fudge} {self.time_signed} " + \ + f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 256)} " + \ + f"{self.original_id} {self.error} " + \ + f"{len(self.other)} {dns.rdata._base64ify(self.other, 256)}" + + def _to_wire(self, file, compress=None, origin=None, canonicalize=False): + self.algorithm.to_wire(file, None, origin, False) + file.write(struct.pack('!HIHH', + (self.time_signed >> 32) & 0xffff, + self.time_signed & 0xffffffff, + self.fudge, + len(self.mac))) + file.write(self.mac) + file.write(struct.pack('!HHH', self.original_id, self.error, + len(self.other))) + file.write(self.other) + + @classmethod + def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None): + (algorithm, cused) = dns.name.from_wire(wire[: current + rdlen], + current) + current += cused + rdlen -= cused + if rdlen < 10: + raise dns.exception.FormError + (time_hi, time_lo, fudge, mac_len) = \ + struct.unpack('!HIHH', wire[current: current + 10]) + current += 10 + rdlen -= 10 + time_signed = (time_hi << 32) + time_lo + if rdlen < mac_len: + raise dns.exception.FormError + mac = wire[current: current + mac_len].unwrap() + current += mac_len + rdlen -= mac_len + if rdlen < 6: + raise dns.exception.FormError + (original_id, error, other_len) = \ + struct.unpack('!HHH', wire[current: current + 6]) + current += 6 + rdlen -= 6 + if rdlen < other_len: + raise dns.exception.FormError + other = wire[current: current + other_len].unwrap() + current += other_len + rdlen -= other_len + return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac, + original_id, error, other) diff --git a/dns/rdtypes/ANY/__init__.py b/dns/rdtypes/ANY/__init__.py index ca41ef8..ea704c8 100644 --- a/dns/rdtypes/ANY/__init__.py +++ b/dns/rdtypes/ANY/__init__.py @@ -43,6 +43,7 @@ __all__ = [ 'NSEC3', 'NSEC3PARAM', 'OPENPGPKEY', + 'OPT', 'PTR', 'RP', 'RRSIG', @@ -51,6 +52,7 @@ __all__ = [ 'SPF', 'SSHFP', 'TLSA', + 'TSIG', 'TXT', 'URI', 'X25', diff --git a/dns/renderer.py b/dns/renderer.py index 8b25487..be57a62 100644 --- a/dns/renderer.py +++ b/dns/renderer.py @@ -178,17 +178,12 @@ class Renderer: """Add a TSIG signature to the message.""" s = self.output.getvalue() - (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s, - keyname, - secret, - int(time.time()), - fudge, - id, - tsig_error, - other_data, - request_mac, - algorithm=algorithm) - self._write_tsig(tsig_rdata, keyname) + + tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, + b'', id, tsig_error, other_data) + (tsig, _) = dns.tsig.sign(s, keyname, tsig[0], secret, + int(time.time()), request_mac) + self._write_tsig(tsig, keyname) def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error, other_data, request_mac, @@ -202,30 +197,23 @@ class Renderer: add_multi_tsig() call for the previous message.""" s = self.output.getvalue() - (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s, - keyname, - secret, - int(time.time()), - fudge, - id, - tsig_error, - other_data, - request_mac, - ctx=ctx, - first=ctx is None, - multi=True, - algorithm=algorithm) - self._write_tsig(tsig_rdata, keyname) + + tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge, + b'', id, tsig_error, other_data) + (tsig, ctx) = dns.tsig.sign(s, keyname, tsig[0], secret, + int(time.time()), request_mac, + ctx, True) + self._write_tsig(tsig, keyname) return ctx - def _write_tsig(self, tsig_rdata, keyname): + def _write_tsig(self, tsig, keyname): self._set_section(ADDITIONAL) with self._track_size(): keyname.to_wire(self.output, self.compress, self.origin) self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG, dns.rdataclass.ANY, 0, 0)) rdata_start = self.output.tell() - self.output.write(tsig_rdata) + tsig.to_wire(self.output) after = self.output.tell() self.output.seek(rdata_start - 2) diff --git a/dns/tsig.py b/dns/tsig.py index 5744f1a..12cbae6 100644 --- a/dns/tsig.py +++ b/dns/tsig.py @@ -85,61 +85,59 @@ BADTIME = 18 BADTRUNC = 22 -def sign(wire, keyname, secret, time, fudge, original_id, error, - other_data, request_mac, ctx=None, multi=False, first=True, - algorithm=default_algorithm): +def sign(wire, keyname, rdata, secret, time=None, request_mac=None, + ctx=None, multi=False): """Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata for the input parameters, the HMAC MAC calculated by applying the TSIG signature algorithm, and the TSIG digest context. - @rtype: (string, string, hmac.HMAC object) + @rtype: (string, hmac.HMAC object) @raises ValueError: I{other_data} is too long @raises NotImplementedError: I{algorithm} is not supported """ - if isinstance(other_data, str): - other_data = other_data.encode() - (algorithm_name, digestmod) = get_algorithm(algorithm) + first = not (ctx and multi) + (algorithm_name, digestmod) = get_algorithm(rdata.algorithm) if first: ctx = hmac.new(secret, digestmod=digestmod) - ml = len(request_mac) - if ml > 0: - ctx.update(struct.pack('!H', ml)) + if request_mac: + ctx.update(struct.pack('!H', len(request_mac))) ctx.update(request_mac) - id = struct.pack('!H', original_id) - ctx.update(id) + ctx.update(struct.pack('!H', rdata.original_id)) ctx.update(wire[2:]) if first: ctx.update(keyname.to_digestable()) ctx.update(struct.pack('!H', dns.rdataclass.ANY)) ctx.update(struct.pack('!I', 0)) + if time is None: + time = rdata.time_signed upper_time = (time >> 32) & 0xffff lower_time = time & 0xffffffff - time_mac = struct.pack('!HIH', upper_time, lower_time, fudge) - pre_mac = algorithm_name + time_mac - ol = len(other_data) - if ol > 65535: + time_encoded = struct.pack('!HIH', upper_time, lower_time, rdata.fudge) + other_len = len(rdata.other) + if other_len > 65535: raise ValueError('TSIG Other Data is > 65535 bytes') - post_mac = struct.pack('!HH', error, ol) + other_data if first: - ctx.update(pre_mac) - ctx.update(post_mac) + ctx.update(algorithm_name + time_encoded) + ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other) else: - ctx.update(time_mac) + ctx.update(time_encoded) mac = ctx.digest() - mpack = struct.pack('!H', len(mac)) - tsig_rdata = pre_mac + mpack + mac + id + post_mac if multi: ctx = hmac.new(secret, digestmod=digestmod) - ml = len(mac) - ctx.update(struct.pack('!H', ml)) + ctx.update(struct.pack('!H', len(mac))) ctx.update(mac) else: ctx = None - return (tsig_rdata, mac, ctx) + tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG, + rdata.algorithm, time, rdata.fudge, mac, + rdata.original_id, rdata.error, + rdata.other) + return (tsig, ctx) -def validate(wire, keyname, secret, now, request_mac, tsig_start, tsig_rdata, - tsig_rdlen, ctx=None, multi=False, first=True): + +def validate(wire, keyname, rdata, secret, now, request_mac, tsig_start, + ctx=None, multi=False): """Validate the specified TSIG rdata against the other input parameters. @raises FormError: The TSIG is badly formed. @@ -153,41 +151,22 @@ def validate(wire, keyname, secret, now, request_mac, tsig_start, tsig_rdata, raise dns.exception.FormError adcount -= 1 new_wire = wire[0:10] + struct.pack("!H", adcount) + wire[12:tsig_start] - current = tsig_rdata - (aname, used) = dns.name.from_wire(wire, current) - current = current + used - (upper_time, lower_time, fudge, mac_size) = \ - struct.unpack("!HIHH", wire[current:current + 10]) - time = (upper_time << 32) + lower_time - current += 10 - mac = wire[current:current + mac_size] - current += mac_size - (original_id, error, other_size) = \ - struct.unpack("!HHH", wire[current:current + 6]) - current += 6 - other_data = wire[current:current + other_size] - current += other_size - if current != tsig_rdata + tsig_rdlen: - raise dns.exception.FormError - if error != 0: - if error == BADSIG: + if rdata.error != 0: + if rdata.error == BADSIG: raise PeerBadSignature - elif error == BADKEY: + elif rdata.error == BADKEY: raise PeerBadKey - elif error == BADTIME: + elif rdata.error == BADTIME: raise PeerBadTime - elif error == BADTRUNC: + elif rdata.error == BADTRUNC: raise PeerBadTruncation else: - raise PeerError('unknown TSIG error code %d' % error) - time_low = time - fudge - time_high = time + fudge - if now < time_low or now > time_high: + raise PeerError('unknown TSIG error code %d' % rdata.error) + if abs(rdata.time_signed - now) > rdata.fudge: raise BadTime - (junk, our_mac, ctx) = sign(new_wire, keyname, secret, time, fudge, - original_id, error, other_data, - request_mac, ctx, multi, first, aname) - if our_mac != mac: + (our_rdata, ctx) = sign(new_wire, keyname, rdata, secret, None, request_mac, + ctx, multi) + if our_rdata.mac != rdata.mac: raise BadSignature return ctx @@ -208,20 +187,3 @@ def get_algorithm(algorithm): except KeyError: raise NotImplementedError("TSIG algorithm " + str(algorithm) + " is not supported") - - -def get_algorithm_and_mac(wire, tsig_rdata, tsig_rdlen): - """Return the tsig algorithm for the specified tsig_rdata - @raises FormError: The TSIG is badly formed. - """ - current = tsig_rdata - (aname, used) = dns.name.from_wire(wire, current) - current = current + used - (upper_time, lower_time, fudge, mac_size) = \ - struct.unpack("!HIHH", wire[current:current + 10]) - current += 10 - mac = wire[current:current + mac_size] - current += mac_size - if current > tsig_rdata + tsig_rdlen: - raise dns.exception.FormError - return (aname, mac) diff --git a/tests/test_renderer.py b/tests/test_renderer.py index 345ef82..c60ccf9 100644 --- a/tests/test_renderer.py +++ b/tests/test_renderer.py @@ -3,9 +3,11 @@ import unittest import dns.exception +import dns.flags import dns.message import dns.renderer -import dns.flags +import dns.tsig +import dns.tsigkeyring basic_answer = \ """flags QR @@ -35,6 +37,50 @@ class RendererTestCase(unittest.TestCase): expected.id = message.id self.assertEqual(message, expected) + def test_tsig(self): + r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512) + qname = dns.name.from_text('foo.example') + r.add_question(qname, dns.rdatatype.A) + keyring = dns.tsigkeyring.from_text({'key' : '12345678'}) + keyname = next(iter(keyring)) + r.write_header() + r.add_tsig(keyname, keyring[keyname], 300, r.id, 0, b'', b'', + dns.tsig.HMAC_SHA256) + wire = r.get_wire() + message = dns.message.from_wire(wire, keyring=keyring) + expected = dns.message.make_query(qname, dns.rdatatype.A) + expected.id = message.id + self.assertEqual(message, expected) + + def test_multi_tsig(self): + qname = dns.name.from_text('foo.example') + keyring = dns.tsigkeyring.from_text({'key' : '12345678'}) + keyname = next(iter(keyring)) + + r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512) + r.add_question(qname, dns.rdatatype.A) + r.write_header() + ctx = r.add_multi_tsig(None, keyname, keyring[keyname], 300, r.id, 0, + b'', b'', dns.tsig.HMAC_SHA256) + wire = r.get_wire() + message = dns.message.from_wire(wire, keyring=keyring, multi=True) + expected = dns.message.make_query(qname, dns.rdatatype.A) + expected.id = message.id + self.assertEqual(message, expected) + + r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512) + r.add_question(qname, dns.rdatatype.A) + r.write_header() + ctx = r.add_multi_tsig(ctx, keyname, keyring[keyname], 300, r.id, 0, + b'', b'', dns.tsig.HMAC_SHA256) + wire = r.get_wire() + message = dns.message.from_wire(wire, keyring=keyring, + tsig_ctx=message.tsig_ctx, multi=True) + expected = dns.message.make_query(qname, dns.rdatatype.A) + expected.id = message.id + self.assertEqual(message, expected) + + def test_going_backwards_fails(self): r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512) qname = dns.name.from_text('foo.example') diff --git a/tests/test_tsig.py b/tests/test_tsig.py index 037d5aa..2722e15 100644 --- a/tests/test_tsig.py +++ b/tests/test_tsig.py @@ -42,12 +42,11 @@ class TSIGTestCase(unittest.TestCase): # not raising is passing dns.message.from_wire(w, keyring) - def make_message_pair(self, qname='example', rdtype='A'): + def make_message_pair(self, qname='example', rdtype='A', tsig_error=0): q = dns.message.make_query(qname, rdtype) q.use_tsig(keyring=keyring, keyname=keyname) - q.had_tsig = True # so make_response() does the right thing q.to_wire() # to set q.mac - r = dns.message.make_response(q) + r = dns.message.make_response(q, tsig_error=tsig_error) return(q, r) def test_peer_errors(self): @@ -58,8 +57,7 @@ class TSIGTestCase(unittest.TestCase): (99, dns.tsig.PeerError), ] for err, ex in items: - q, r = self.make_message_pair() - r.tsig_error = err + q, r = self.make_message_pair(tsig_error=err) w = r.to_wire() def bad(): dns.message.from_wire(w, keyring=keyring, request_mac=q.mac) |