summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-06-30 11:01:16 -0700
committerGitHub <noreply@github.com>2020-06-30 11:01:16 -0700
commit58d7ca7207b9a46b1238fce8bfa4325a400e4ac5 (patch)
tree9c540090184a5fd27ee06ba51dda5bc2a603bec3
parent1e16e7fee759893b17abb9257bdbe96e7edc65f8 (diff)
parentc8ec6ca47397420cb128db9aa4951245ac4ef07f (diff)
downloaddnspython-58d7ca7207b9a46b1238fce8bfa4325a400e4ac5.tar.gz
Merge pull request #525 from bwelling/tsig
Adds support for a TSIG record class.
-rw-r--r--dns/message.py205
-rw-r--r--dns/message.pyi18
-rw-r--r--dns/query.py5
-rw-r--r--dns/rdtypes/ANY/TSIG.py112
-rw-r--r--dns/rdtypes/ANY/__init__.py2
-rw-r--r--dns/renderer.py42
-rw-r--r--dns/tsig.py108
-rw-r--r--tests/test_renderer.py48
-rw-r--r--tests/test_tsig.py8
9 files changed, 331 insertions, 217 deletions
diff --git a/dns/message.py b/dns/message.py
index 597b329..00359ef 100644
--- a/dns/message.py
+++ b/dns/message.py
@@ -38,6 +38,7 @@ import dns.renderer
import dns.tsig
import dns.wiredata
import dns.rdtypes.ANY.OPT
+import dns.rdtypes.ANY.TSIG
class ShortHeader(dns.exception.FormError):
@@ -109,20 +110,11 @@ class Message:
self.opt = None
self.request_payload = 0
self.keyring = None
- self.keyname = None
- self.keyalgorithm = dns.tsig.default_algorithm
+ self.tsig = None
self.request_mac = b''
- self.other_data = b''
- self.tsig_error = 0
- self.fudge = 300
- self.original_id = self.id
- self.mac = b''
self.xfr = False
self.origin = None
self.tsig_ctx = None
- self.had_tsig = False
- self.multi = False
- self.first = True
self.index = {}
@property
@@ -443,22 +435,31 @@ class Message:
for rrset in self.additional:
r.add_rrset(dns.renderer.ADDITIONAL, rrset, **kw)
r.write_header()
- if self.keyname is not None:
+ if self.tsig is not None:
+ (new_tsig, ctx) = dns.tsig.sign(r.get_wire(),
+ self.tsig.name,
+ self.tsig[0],
+ self.keyring[self.tsig.name],
+ int(time.time()),
+ self.request_mac,
+ tsig_ctx,
+ multi)
+ self.tsig.clear()
+ self.tsig.add(new_tsig)
+ r.add_rrset(dns.renderer.ADDITIONAL, self.tsig)
+ r.write_header()
if multi:
- ctx = r.add_multi_tsig(tsig_ctx,
- self.keyname, self.keyring[self.keyname],
- self.fudge, self.original_id,
- self.tsig_error, self.other_data,
- self.request_mac, self.keyalgorithm)
self.tsig_ctx = ctx
- else:
- r.add_tsig(self.keyname, self.keyring[self.keyname],
- self.fudge, self.original_id, self.tsig_error,
- self.other_data, self.request_mac,
- self.keyalgorithm)
- self.mac = r.mac
return r.get_wire()
+ @staticmethod
+ def _make_tsig(keyname, algorithm, time_signed, fudge, mac, original_id,
+ error, other):
+ tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
+ algorithm, time_signed, fudge, mac,
+ original_id, error, other)
+ return dns.rrset.from_rdata(keyname, 0, tsig)
+
def use_tsig(self, keyring, keyname=None, fudge=300,
original_id=None, tsig_error=0, other_data=b'',
algorithm=dns.tsig.default_algorithm):
@@ -492,19 +493,45 @@ class Message:
self.keyring = keyring
if keyname is None:
- self.keyname = list(self.keyring.keys())[0]
- else:
- if isinstance(keyname, str):
- keyname = dns.name.from_text(keyname)
- self.keyname = keyname
- self.keyalgorithm = algorithm
- self.fudge = fudge
+ keyname = list(self.keyring.keys())[0]
+ elif isinstance(keyname, str):
+ keyname = dns.name.from_text(keyname)
if original_id is None:
- self.original_id = self.id
+ original_id = self.id
+ self.tsig = self._make_tsig(keyname, algorithm, 0, fudge, b'',
+ original_id, tsig_error, other_data)
+
+ @property
+ def keyname(self):
+ if self.tsig:
+ return self.tsig.name
+ else:
+ return None
+
+ @property
+ def keyalgorithm(self):
+ if self.tsig:
+ return self.tsig[0].algorithm
+ else:
+ return None
+
+ @property
+ def mac(self):
+ if self.tsig:
+ return self.tsig[0].mac
else:
- self.original_id = original_id
- self.tsig_error = tsig_error
- self.other_data = other_data
+ return None
+
+ @property
+ def tsig_error(self):
+ if self.tsig:
+ return self.tsig[0].error
+ else:
+ return None
+
+ @property
+ def had_tsig(self):
+ return bool(self.tsig)
@staticmethod
def _make_opt(flags=0, payload=1280, options=None):
@@ -649,11 +676,17 @@ class Message:
raise dns.exception.FormError
return (rdclass, rdtype, None, False)
- def _parse_special_rr_header(self, section, name, rdclass, rdtype):
+ def _parse_special_rr_header(self, section, count, position,
+ name, rdclass, rdtype):
if rdtype == dns.rdatatype.OPT:
if section != MessageSection.ADDITIONAL or self.opt or \
name != dns.name.root:
raise BadEDNS
+ elif rdtype == dns.rdatatype.TSIG:
+ if section != MessageSection.ADDITIONAL or \
+ rdclass != dns.rdatatype.ANY or \
+ position != count - 1:
+ raise BadTSIG
return (rdclass, rdtype, None, False)
@@ -691,11 +724,12 @@ class _WireReader:
question_only: Are we only reading the question?
one_rr_per_rrset: Put each RR into its own RRset?
ignore_trailing: Ignore trailing junk at end of request?
+ multi: Is this message part of a multi-message sequence?
DNS dynamic updates.
"""
def __init__(self, wire, initialize_message, question_only=False,
- one_rr_per_rrset=False, ignore_trailing=False):
+ one_rr_per_rrset=False, ignore_trailing=False, multi=False):
self.wire = dns.wiredata.maybe_wrap(wire)
self.message = None
self.current = 0
@@ -703,6 +737,7 @@ class _WireReader:
self.question_only = question_only
self.one_rr_per_rrset = one_rr_per_rrset
self.ignore_trailing = ignore_trailing
+ self.multi = multi
def _get_question(self, section_number, qcount):
"""Read the next *qcount* records from the wire data and add them to
@@ -746,65 +781,55 @@ class _WireReader:
struct.unpack('!HHIH',
self.wire[self.current:self.current + 10])
self.current += 10
- if rdtype == dns.rdatatype.TSIG:
- if not (section is self.message.additional and
- i == (count - 1)):
- raise BadTSIG
+ if rdtype in (dns.rdatatype.OPT, dns.rdatatype.TSIG):
+ (rdclass, rdtype, deleting, empty) = \
+ self.message._parse_special_rr_header(section_number,
+ count, i, name,
+ rdclass, rdtype)
+ else:
+ (rdclass, rdtype, deleting, empty) = \
+ self.message._parse_rr_header(section_number,
+ name, rdclass, rdtype)
+ if empty:
+ if rdlen > 0:
+ raise dns.exception.FormError
+ rd = None
+ covers = dns.rdatatype.NONE
+ else:
+ rd = dns.rdata.from_wire(rdclass, rdtype,
+ self.wire, self.current, rdlen,
+ self.message.origin)
+ covers = rd.covers()
+ if self.message.xfr and rdtype == dns.rdatatype.SOA:
+ force_unique = True
+ if rdtype == dns.rdatatype.OPT:
+ self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
+ elif rdtype == dns.rdatatype.TSIG:
if self.message.keyring is None:
raise UnknownTSIGKey('got signed message without keyring')
secret = self.message.keyring.get(absolute_name)
if secret is None:
raise UnknownTSIGKey("key '%s' unknown" % name)
- self.message.keyname = absolute_name
- (self.message.keyalgorithm, self.message.mac) = \
- dns.tsig.get_algorithm_and_mac(self.wire, self.current,
- rdlen)
self.message.tsig_ctx = \
dns.tsig.validate(self.wire,
absolute_name,
+ rd,
secret,
int(time.time()),
self.message.request_mac,
rr_start,
- self.current,
- rdlen,
self.message.tsig_ctx,
- self.message.multi,
- self.message.first)
- self.message.had_tsig = True
+ self.multi)
+ self.message.tsig = dns.rrset.from_rdata(absolute_name, 0, rd)
else:
- if rdtype == dns.rdatatype.OPT:
- (rdclass, rdtype, deleting, empty) = \
- self.message._parse_special_rr_header(section_number,
- name,
- rdclass, rdtype)
- else:
- (rdclass, rdtype, deleting, empty) = \
- self.message._parse_rr_header(section_number,
- name, rdclass, rdtype)
- if empty:
- if rdlen > 0:
- raise dns.exception.FormError
- rd = None
- covers = dns.rdatatype.NONE
- else:
- rd = dns.rdata.from_wire(rdclass, rdtype,
- self.wire, self.current, rdlen,
- self.message.origin)
- covers = rd.covers()
- if self.message.xfr and rdtype == dns.rdatatype.SOA:
- force_unique = True
- if rdtype == dns.rdatatype.OPT:
- self.message.opt = dns.rrset.from_rdata(name, ttl, rd)
- else:
- rrset = self.message.find_rrset(section, name,
- rdclass, rdtype, covers,
- deleting, True,
- force_unique)
- if rd is not None:
- if ttl > 0x7fffffff:
- ttl = 0
- rrset.add(rd, ttl)
+ rrset = self.message.find_rrset(section, name,
+ rdclass, rdtype, covers,
+ deleting, True,
+ force_unique)
+ if rd is not None:
+ if ttl > 0x7fffffff:
+ ttl = 0
+ rrset.add(rd, ttl)
self.current += rdlen
def read(self):
@@ -831,14 +856,13 @@ class _WireReader:
self._get_section(MessageSection.ADDITIONAL, adcount)
if not self.ignore_trailing and self.current != l:
raise TrailingJunk
- if self.message.multi and self.message.tsig_ctx and \
- not self.message.had_tsig:
+ if self.multi and self.message.tsig_ctx and not self.message.had_tsig:
self.message.tsig_ctx.update(self.wire)
return self.message
def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
- tsig_ctx=None, multi=False, first=True,
+ tsig_ctx=None, multi=False,
question_only=False, one_rr_per_rrset=False,
ignore_trailing=False, raise_on_truncation=False):
"""Convert a DNS wire format message into a message
@@ -863,9 +887,6 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
*multi*, a ``bool``, should be set to ``True`` if this message is
part of a multiple message sequence.
- *first*, a ``bool``, should be set to ``True`` if this message is
- stand-alone, or the first message in a multi-message sequence.
-
*question_only*, a ``bool``. If ``True``, read only up to
the end of the question section.
@@ -902,11 +923,9 @@ def from_wire(wire, keyring=None, request_mac=b'', xfr=False, origin=None,
message.xfr = xfr
message.origin = origin
message.tsig_ctx = tsig_ctx
- message.multi = multi
- message.first = first
reader = _WireReader(wire, initialize_message, question_only,
- one_rr_per_rrset, ignore_trailing)
+ one_rr_per_rrset, ignore_trailing, multi)
try:
m = reader.read()
except dns.exception.FormError:
@@ -1291,7 +1310,7 @@ def make_query(qname, rdtype, rdclass=dns.rdataclass.IN, use_edns=None,
def make_response(query, recursion_available=False, our_payload=8192,
- fudge=300):
+ fudge=300, tsig_error=0):
"""Make a message which is a response for the specified query.
The message returned is really a response skeleton; it has all
of the infrastructure required of a response, but none of the
@@ -1310,6 +1329,8 @@ def make_response(query, recursion_available=False, our_payload=8192,
*fudge*, an ``int``, the TSIG time fudge.
+ *tsig_error*, an ``int``, the TSIG error.
+
Returns a ``dns.message.Message`` object whose specific class is
appropriate for the query. For example, if query is a
``dns.update.UpdateMessage``, response will be too.
@@ -1327,7 +1348,7 @@ def make_response(query, recursion_available=False, our_payload=8192,
if query.edns >= 0:
response.use_edns(0, 0, our_payload, query.payload)
if query.had_tsig:
- response.use_tsig(query.keyring, query.keyname, fudge, None, 0, b'',
- query.keyalgorithm)
+ response.use_tsig(query.keyring, query.keyname, fudge, None,
+ tsig_error, b'', query.keyalgorithm)
response.request_mac = query.mac
return response
diff --git a/dns/message.pyi b/dns/message.pyi
index 76af040..8829db3 100644
--- a/dns/message.pyi
+++ b/dns/message.pyi
@@ -16,26 +16,14 @@ class Message:
self.answer : List[rrset.RRset] = []
self.authority : List[rrset.RRset] = []
self.additional : List[rrset.RRset] = []
- self.edns = -1
- self.ednsflags = 0
- self.payload = 0
- self.options : List[edns.Option] = []
+ self.opt : rrset.RRset = None
self.request_payload = 0
self.keyring = None
- self.keyname = None
- self.keyalgorithm = tsig.default_algorithm
+ self.tsig : rrset.RRset = None
self.request_mac = b''
- self.other_data = b''
- self.tsig_error = 0
- self.fudge = 300
- self.original_id = self.id
- self.mac = b''
self.xfr = False
self.origin = None
self.tsig_ctx = None
- self.had_tsig = False
- self.multi = False
- self.first = True
self.index : Dict[Tuple[rrset.RRset, name.Name, int, int, Union[int,str], int], rrset.RRset] = {}
def is_response(self, other : Message) -> bool:
@@ -45,7 +33,7 @@ def from_text(a : str, idna_codec : Optional[name.IDNACodec] = None) -> Message:
...
def from_wire(wire, keyring : Optional[Dict[name.Name,bytes]] = None, request_mac = b'', xfr=False, origin=None,
- tsig_ctx : Optional[hmac.HMAC] = None, multi=False, first=True,
+ tsig_ctx : Optional[hmac.HMAC] = None, multi=False,
question_only=False, one_rr_per_rrset=False,
ignore_trailing=False) -> Message:
...
diff --git a/dns/query.py b/dns/query.py
index ae4258a..3404b91 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -920,7 +920,6 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
origin = None
oname = zone
tsig_ctx = None
- first = True
while not done:
(_, mexpiration) = _compute_times(timeout)
if mexpiration is None or \
@@ -937,13 +936,11 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
r = dns.message.from_wire(wire, keyring=q.keyring,
request_mac=q.mac, xfr=True,
origin=origin, tsig_ctx=tsig_ctx,
- multi=True, first=first,
- one_rr_per_rrset=is_ixfr)
+ multi=True, one_rr_per_rrset=is_ixfr)
rcode = r.rcode()
if rcode != dns.rcode.NOERROR:
raise TransferError(rcode)
tsig_ctx = r.tsig_ctx
- first = False
answer_index = 0
if soa_rrset is None:
if not r.answer or r.answer[0].name != oname:
diff --git a/dns/rdtypes/ANY/TSIG.py b/dns/rdtypes/ANY/TSIG.py
new file mode 100644
index 0000000..002c2db
--- /dev/null
+++ b/dns/rdtypes/ANY/TSIG.py
@@ -0,0 +1,112 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2001-2017 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+import struct
+
+import dns.exception
+import dns.rdata
+
+
+class TSIG(dns.rdata.Rdata):
+
+ """TSIG record"""
+
+ __slots__ = ['algorithm', 'time_signed', 'fudge', 'mac',
+ 'original_id', 'error', 'other']
+
+ def __init__(self, rdclass, rdtype, algorithm, time_signed, fudge, mac,
+ original_id, error, other):
+ """Initialize a TSIG rdata.
+
+ *rdclass*, an ``int`` is the rdataclass of the Rdata.
+
+ *rdtype*, an ``int`` is the rdatatype of the Rdata.
+
+ *algorithm*, a ``dns.name.Name``.
+
+ *time_signed*, an ``int``.
+
+ *fudge*, an ``int`.
+
+ *mac*, a ``bytes``
+
+ *original_id*, an ``int``
+
+ *error*, an ``int``
+
+ *other*, a ``bytes``
+ """
+
+ super().__init__(rdclass, rdtype)
+ object.__setattr__(self, 'algorithm', algorithm)
+ object.__setattr__(self, 'time_signed', time_signed)
+ object.__setattr__(self, 'fudge', fudge)
+ object.__setattr__(self, 'mac', dns.rdata._constify(mac))
+ object.__setattr__(self, 'original_id', original_id)
+ object.__setattr__(self, 'error', error)
+ object.__setattr__(self, 'other', dns.rdata._constify(other))
+
+ def to_text(self, origin=None, relativize=True, **kw):
+ algorithm = self.algorithm.choose_relativity(origin, relativize)
+ return f"{algorithm} {self.fudge} {self.time_signed} " + \
+ f"{len(self.mac)} {dns.rdata._base64ify(self.mac, 256)} " + \
+ f"{self.original_id} {self.error} " + \
+ f"{len(self.other)} {dns.rdata._base64ify(self.other, 256)}"
+
+ def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
+ self.algorithm.to_wire(file, None, origin, False)
+ file.write(struct.pack('!HIHH',
+ (self.time_signed >> 32) & 0xffff,
+ self.time_signed & 0xffffffff,
+ self.fudge,
+ len(self.mac)))
+ file.write(self.mac)
+ file.write(struct.pack('!HHH', self.original_id, self.error,
+ len(self.other)))
+ file.write(self.other)
+
+ @classmethod
+ def from_wire(cls, rdclass, rdtype, wire, current, rdlen, origin=None):
+ (algorithm, cused) = dns.name.from_wire(wire[: current + rdlen],
+ current)
+ current += cused
+ rdlen -= cused
+ if rdlen < 10:
+ raise dns.exception.FormError
+ (time_hi, time_lo, fudge, mac_len) = \
+ struct.unpack('!HIHH', wire[current: current + 10])
+ current += 10
+ rdlen -= 10
+ time_signed = (time_hi << 32) + time_lo
+ if rdlen < mac_len:
+ raise dns.exception.FormError
+ mac = wire[current: current + mac_len].unwrap()
+ current += mac_len
+ rdlen -= mac_len
+ if rdlen < 6:
+ raise dns.exception.FormError
+ (original_id, error, other_len) = \
+ struct.unpack('!HHH', wire[current: current + 6])
+ current += 6
+ rdlen -= 6
+ if rdlen < other_len:
+ raise dns.exception.FormError
+ other = wire[current: current + other_len].unwrap()
+ current += other_len
+ rdlen -= other_len
+ return cls(rdclass, rdtype, algorithm, time_signed, fudge, mac,
+ original_id, error, other)
diff --git a/dns/rdtypes/ANY/__init__.py b/dns/rdtypes/ANY/__init__.py
index ca41ef8..ea704c8 100644
--- a/dns/rdtypes/ANY/__init__.py
+++ b/dns/rdtypes/ANY/__init__.py
@@ -43,6 +43,7 @@ __all__ = [
'NSEC3',
'NSEC3PARAM',
'OPENPGPKEY',
+ 'OPT',
'PTR',
'RP',
'RRSIG',
@@ -51,6 +52,7 @@ __all__ = [
'SPF',
'SSHFP',
'TLSA',
+ 'TSIG',
'TXT',
'URI',
'X25',
diff --git a/dns/renderer.py b/dns/renderer.py
index 8b25487..be57a62 100644
--- a/dns/renderer.py
+++ b/dns/renderer.py
@@ -178,17 +178,12 @@ class Renderer:
"""Add a TSIG signature to the message."""
s = self.output.getvalue()
- (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
- keyname,
- secret,
- int(time.time()),
- fudge,
- id,
- tsig_error,
- other_data,
- request_mac,
- algorithm=algorithm)
- self._write_tsig(tsig_rdata, keyname)
+
+ tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
+ b'', id, tsig_error, other_data)
+ (tsig, _) = dns.tsig.sign(s, keyname, tsig[0], secret,
+ int(time.time()), request_mac)
+ self._write_tsig(tsig, keyname)
def add_multi_tsig(self, ctx, keyname, secret, fudge, id, tsig_error,
other_data, request_mac,
@@ -202,30 +197,23 @@ class Renderer:
add_multi_tsig() call for the previous message."""
s = self.output.getvalue()
- (tsig_rdata, self.mac, ctx) = dns.tsig.sign(s,
- keyname,
- secret,
- int(time.time()),
- fudge,
- id,
- tsig_error,
- other_data,
- request_mac,
- ctx=ctx,
- first=ctx is None,
- multi=True,
- algorithm=algorithm)
- self._write_tsig(tsig_rdata, keyname)
+
+ tsig = dns.message.Message._make_tsig(keyname, algorithm, 0, fudge,
+ b'', id, tsig_error, other_data)
+ (tsig, ctx) = dns.tsig.sign(s, keyname, tsig[0], secret,
+ int(time.time()), request_mac,
+ ctx, True)
+ self._write_tsig(tsig, keyname)
return ctx
- def _write_tsig(self, tsig_rdata, keyname):
+ def _write_tsig(self, tsig, keyname):
self._set_section(ADDITIONAL)
with self._track_size():
keyname.to_wire(self.output, self.compress, self.origin)
self.output.write(struct.pack('!HHIH', dns.rdatatype.TSIG,
dns.rdataclass.ANY, 0, 0))
rdata_start = self.output.tell()
- self.output.write(tsig_rdata)
+ tsig.to_wire(self.output)
after = self.output.tell()
self.output.seek(rdata_start - 2)
diff --git a/dns/tsig.py b/dns/tsig.py
index 5744f1a..12cbae6 100644
--- a/dns/tsig.py
+++ b/dns/tsig.py
@@ -85,61 +85,59 @@ BADTIME = 18
BADTRUNC = 22
-def sign(wire, keyname, secret, time, fudge, original_id, error,
- other_data, request_mac, ctx=None, multi=False, first=True,
- algorithm=default_algorithm):
+def sign(wire, keyname, rdata, secret, time=None, request_mac=None,
+ ctx=None, multi=False):
"""Return a (tsig_rdata, mac, ctx) tuple containing the HMAC TSIG rdata
for the input parameters, the HMAC MAC calculated by applying the
TSIG signature algorithm, and the TSIG digest context.
- @rtype: (string, string, hmac.HMAC object)
+ @rtype: (string, hmac.HMAC object)
@raises ValueError: I{other_data} is too long
@raises NotImplementedError: I{algorithm} is not supported
"""
- if isinstance(other_data, str):
- other_data = other_data.encode()
- (algorithm_name, digestmod) = get_algorithm(algorithm)
+ first = not (ctx and multi)
+ (algorithm_name, digestmod) = get_algorithm(rdata.algorithm)
if first:
ctx = hmac.new(secret, digestmod=digestmod)
- ml = len(request_mac)
- if ml > 0:
- ctx.update(struct.pack('!H', ml))
+ if request_mac:
+ ctx.update(struct.pack('!H', len(request_mac)))
ctx.update(request_mac)
- id = struct.pack('!H', original_id)
- ctx.update(id)
+ ctx.update(struct.pack('!H', rdata.original_id))
ctx.update(wire[2:])
if first:
ctx.update(keyname.to_digestable())
ctx.update(struct.pack('!H', dns.rdataclass.ANY))
ctx.update(struct.pack('!I', 0))
+ if time is None:
+ time = rdata.time_signed
upper_time = (time >> 32) & 0xffff
lower_time = time & 0xffffffff
- time_mac = struct.pack('!HIH', upper_time, lower_time, fudge)
- pre_mac = algorithm_name + time_mac
- ol = len(other_data)
- if ol > 65535:
+ time_encoded = struct.pack('!HIH', upper_time, lower_time, rdata.fudge)
+ other_len = len(rdata.other)
+ if other_len > 65535:
raise ValueError('TSIG Other Data is > 65535 bytes')
- post_mac = struct.pack('!HH', error, ol) + other_data
if first:
- ctx.update(pre_mac)
- ctx.update(post_mac)
+ ctx.update(algorithm_name + time_encoded)
+ ctx.update(struct.pack('!HH', rdata.error, other_len) + rdata.other)
else:
- ctx.update(time_mac)
+ ctx.update(time_encoded)
mac = ctx.digest()
- mpack = struct.pack('!H', len(mac))
- tsig_rdata = pre_mac + mpack + mac + id + post_mac
if multi:
ctx = hmac.new(secret, digestmod=digestmod)
- ml = len(mac)
- ctx.update(struct.pack('!H', ml))
+ ctx.update(struct.pack('!H', len(mac)))
ctx.update(mac)
else:
ctx = None
- return (tsig_rdata, mac, ctx)
+ tsig = dns.rdtypes.ANY.TSIG.TSIG(dns.rdataclass.ANY, dns.rdatatype.TSIG,
+ rdata.algorithm, time, rdata.fudge, mac,
+ rdata.original_id, rdata.error,
+ rdata.other)
+ return (tsig, ctx)
-def validate(wire, keyname, secret, now, request_mac, tsig_start, tsig_rdata,
- tsig_rdlen, ctx=None, multi=False, first=True):
+
+def validate(wire, keyname, rdata, secret, now, request_mac, tsig_start,
+ ctx=None, multi=False):
"""Validate the specified TSIG rdata against the other input parameters.
@raises FormError: The TSIG is badly formed.
@@ -153,41 +151,22 @@ def validate(wire, keyname, secret, now, request_mac, tsig_start, tsig_rdata,
raise dns.exception.FormError
adcount -= 1
new_wire = wire[0:10] + struct.pack("!H", adcount) + wire[12:tsig_start]
- current = tsig_rdata
- (aname, used) = dns.name.from_wire(wire, current)
- current = current + used
- (upper_time, lower_time, fudge, mac_size) = \
- struct.unpack("!HIHH", wire[current:current + 10])
- time = (upper_time << 32) + lower_time
- current += 10
- mac = wire[current:current + mac_size]
- current += mac_size
- (original_id, error, other_size) = \
- struct.unpack("!HHH", wire[current:current + 6])
- current += 6
- other_data = wire[current:current + other_size]
- current += other_size
- if current != tsig_rdata + tsig_rdlen:
- raise dns.exception.FormError
- if error != 0:
- if error == BADSIG:
+ if rdata.error != 0:
+ if rdata.error == BADSIG:
raise PeerBadSignature
- elif error == BADKEY:
+ elif rdata.error == BADKEY:
raise PeerBadKey
- elif error == BADTIME:
+ elif rdata.error == BADTIME:
raise PeerBadTime
- elif error == BADTRUNC:
+ elif rdata.error == BADTRUNC:
raise PeerBadTruncation
else:
- raise PeerError('unknown TSIG error code %d' % error)
- time_low = time - fudge
- time_high = time + fudge
- if now < time_low or now > time_high:
+ raise PeerError('unknown TSIG error code %d' % rdata.error)
+ if abs(rdata.time_signed - now) > rdata.fudge:
raise BadTime
- (junk, our_mac, ctx) = sign(new_wire, keyname, secret, time, fudge,
- original_id, error, other_data,
- request_mac, ctx, multi, first, aname)
- if our_mac != mac:
+ (our_rdata, ctx) = sign(new_wire, keyname, rdata, secret, None, request_mac,
+ ctx, multi)
+ if our_rdata.mac != rdata.mac:
raise BadSignature
return ctx
@@ -208,20 +187,3 @@ def get_algorithm(algorithm):
except KeyError:
raise NotImplementedError("TSIG algorithm " + str(algorithm) +
" is not supported")
-
-
-def get_algorithm_and_mac(wire, tsig_rdata, tsig_rdlen):
- """Return the tsig algorithm for the specified tsig_rdata
- @raises FormError: The TSIG is badly formed.
- """
- current = tsig_rdata
- (aname, used) = dns.name.from_wire(wire, current)
- current = current + used
- (upper_time, lower_time, fudge, mac_size) = \
- struct.unpack("!HIHH", wire[current:current + 10])
- current += 10
- mac = wire[current:current + mac_size]
- current += mac_size
- if current > tsig_rdata + tsig_rdlen:
- raise dns.exception.FormError
- return (aname, mac)
diff --git a/tests/test_renderer.py b/tests/test_renderer.py
index 345ef82..c60ccf9 100644
--- a/tests/test_renderer.py
+++ b/tests/test_renderer.py
@@ -3,9 +3,11 @@
import unittest
import dns.exception
+import dns.flags
import dns.message
import dns.renderer
-import dns.flags
+import dns.tsig
+import dns.tsigkeyring
basic_answer = \
"""flags QR
@@ -35,6 +37,50 @@ class RendererTestCase(unittest.TestCase):
expected.id = message.id
self.assertEqual(message, expected)
+ def test_tsig(self):
+ r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+ qname = dns.name.from_text('foo.example')
+ r.add_question(qname, dns.rdatatype.A)
+ keyring = dns.tsigkeyring.from_text({'key' : '12345678'})
+ keyname = next(iter(keyring))
+ r.write_header()
+ r.add_tsig(keyname, keyring[keyname], 300, r.id, 0, b'', b'',
+ dns.tsig.HMAC_SHA256)
+ wire = r.get_wire()
+ message = dns.message.from_wire(wire, keyring=keyring)
+ expected = dns.message.make_query(qname, dns.rdatatype.A)
+ expected.id = message.id
+ self.assertEqual(message, expected)
+
+ def test_multi_tsig(self):
+ qname = dns.name.from_text('foo.example')
+ keyring = dns.tsigkeyring.from_text({'key' : '12345678'})
+ keyname = next(iter(keyring))
+
+ r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+ r.add_question(qname, dns.rdatatype.A)
+ r.write_header()
+ ctx = r.add_multi_tsig(None, keyname, keyring[keyname], 300, r.id, 0,
+ b'', b'', dns.tsig.HMAC_SHA256)
+ wire = r.get_wire()
+ message = dns.message.from_wire(wire, keyring=keyring, multi=True)
+ expected = dns.message.make_query(qname, dns.rdatatype.A)
+ expected.id = message.id
+ self.assertEqual(message, expected)
+
+ r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+ r.add_question(qname, dns.rdatatype.A)
+ r.write_header()
+ ctx = r.add_multi_tsig(ctx, keyname, keyring[keyname], 300, r.id, 0,
+ b'', b'', dns.tsig.HMAC_SHA256)
+ wire = r.get_wire()
+ message = dns.message.from_wire(wire, keyring=keyring,
+ tsig_ctx=message.tsig_ctx, multi=True)
+ expected = dns.message.make_query(qname, dns.rdatatype.A)
+ expected.id = message.id
+ self.assertEqual(message, expected)
+
+
def test_going_backwards_fails(self):
r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
qname = dns.name.from_text('foo.example')
diff --git a/tests/test_tsig.py b/tests/test_tsig.py
index 037d5aa..2722e15 100644
--- a/tests/test_tsig.py
+++ b/tests/test_tsig.py
@@ -42,12 +42,11 @@ class TSIGTestCase(unittest.TestCase):
# not raising is passing
dns.message.from_wire(w, keyring)
- def make_message_pair(self, qname='example', rdtype='A'):
+ def make_message_pair(self, qname='example', rdtype='A', tsig_error=0):
q = dns.message.make_query(qname, rdtype)
q.use_tsig(keyring=keyring, keyname=keyname)
- q.had_tsig = True # so make_response() does the right thing
q.to_wire() # to set q.mac
- r = dns.message.make_response(q)
+ r = dns.message.make_response(q, tsig_error=tsig_error)
return(q, r)
def test_peer_errors(self):
@@ -58,8 +57,7 @@ class TSIGTestCase(unittest.TestCase):
(99, dns.tsig.PeerError),
]
for err, ex in items:
- q, r = self.make_message_pair()
- r.tsig_error = err
+ q, r = self.make_message_pair(tsig_error=err)
w = r.to_wire()
def bad():
dns.message.from_wire(w, keyring=keyring, request_mac=q.mac)