summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-06-30 09:34:41 -0700
committerBrian Wellington <bwelling@xbill.org>2020-06-30 09:34:41 -0700
commitcb559f1e1e255a64201a5f86646cc5a894872313 (patch)
tree2a9675f86d04d0df258ae683deb4c15c3dd68498
parentbfdcb567502dcb1e4de443479547a2e26a4547f7 (diff)
downloaddnspython-cb559f1e1e255a64201a5f86646cc5a894872313.tar.gz
Test (and fix) renderer.add_multi_tsig().
-rw-r--r--dns/renderer.py2
-rw-r--r--tests/test_renderer.py29
2 files changed, 30 insertions, 1 deletions
diff --git a/dns/renderer.py b/dns/renderer.py
index 6e50d27..be57a62 100644
--- a/dns/renderer.py
+++ b/dns/renderer.py
@@ -202,7 +202,7 @@ class Renderer:
b'', id, tsig_error, other_data)
(tsig, ctx) = dns.tsig.sign(s, keyname, tsig[0], secret,
int(time.time()), request_mac,
- ctx, True, ctx is None)
+ ctx, True)
self._write_tsig(tsig, keyname)
return ctx
diff --git a/tests/test_renderer.py b/tests/test_renderer.py
index db9d0f3..c60ccf9 100644
--- a/tests/test_renderer.py
+++ b/tests/test_renderer.py
@@ -52,6 +52,35 @@ class RendererTestCase(unittest.TestCase):
expected.id = message.id
self.assertEqual(message, expected)
+ def test_multi_tsig(self):
+ qname = dns.name.from_text('foo.example')
+ keyring = dns.tsigkeyring.from_text({'key' : '12345678'})
+ keyname = next(iter(keyring))
+
+ r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+ r.add_question(qname, dns.rdatatype.A)
+ r.write_header()
+ ctx = r.add_multi_tsig(None, keyname, keyring[keyname], 300, r.id, 0,
+ b'', b'', dns.tsig.HMAC_SHA256)
+ wire = r.get_wire()
+ message = dns.message.from_wire(wire, keyring=keyring, multi=True)
+ expected = dns.message.make_query(qname, dns.rdatatype.A)
+ expected.id = message.id
+ self.assertEqual(message, expected)
+
+ r = dns.renderer.Renderer(flags=dns.flags.RD, max_size=512)
+ r.add_question(qname, dns.rdatatype.A)
+ r.write_header()
+ ctx = r.add_multi_tsig(ctx, keyname, keyring[keyname], 300, r.id, 0,
+ b'', b'', dns.tsig.HMAC_SHA256)
+ wire = r.get_wire()
+ message = dns.message.from_wire(wire, keyring=keyring,
+ tsig_ctx=message.tsig_ctx, multi=True)
+ expected = dns.message.make_query(qname, dns.rdatatype.A)
+ expected.id = message.id
+ self.assertEqual(message, expected)
+
+
def test_going_backwards_fails(self):
r = dns.renderer.Renderer(flags=dns.flags.QR, max_size=512)
qname = dns.name.from_text('foo.example')