From d23fc57dd79339eb14bf0983cf76ff7010d09b2a Mon Sep 17 00:00:00 2001 From: Nick Hall Date: Fri, 10 Jul 2020 22:48:37 +0100 Subject: Add a number of additional tests to improve TSIG test coverage relating to gss-tsig change and some associated refactoring. --- tests/test_tsig.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 3 deletions(-) diff --git a/tests/test_tsig.py b/tests/test_tsig.py index 2d4d92b..41b778e 100644 --- a/tests/test_tsig.py +++ b/tests/test_tsig.py @@ -2,11 +2,14 @@ import unittest from unittest.mock import Mock +import time import dns.rcode import dns.tsig import dns.tsigkeyring import dns.message +from dns.rdatatype import RdataType +from dns.rdataclass import RdataClass keyring = dns.tsigkeyring.from_text( { @@ -43,11 +46,56 @@ class TSIGTestCase(unittest.TestCase): m.use_tsig(keyring, keyname, tsig_error=dns.rcode.BADKEY) self.assertEqual(m.tsig_error, dns.rcode.BADKEY) + def test_verify_mac_for_context(self): + dummy_ctx = None + dummy_expected = None + key = dns.tsig.Key('foo.com', 'abcd', 'bogus') + with self.assertRaises(NotImplementedError): + dns.tsig._verify_mac_for_context(dummy_ctx, key, dummy_expected) + + key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') + ctx = dns.tsig.get_context(key) + bad_expected = b'xxxxxxxxxx' + with self.assertRaises(dns.tsig.BadSignature): + dns.tsig._verify_mac_for_context(ctx, key, bad_expected) + + def test_validate(self): + # make message and grab the TSIG + m = dns.message.make_query('example', 'a') + m.use_tsig(keyring, keyname, algorithm=dns.tsig.HMAC_SHA256) + w = m.to_wire() + tsig = m.tsig[0] + + # get the time and create a key with matching characteristics + now = int(time.time()) + key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha256') + + # add enough to the time to take it over the fudge amount + with self.assertRaises(dns.tsig.BadTime): + dns.tsig.validate(w, key, dns.name.from_text('foo.com'), + tsig, now + 1000, b'', 0) + + # change the key name + with self.assertRaises(dns.tsig.BadKey): + dns.tsig.validate(w, key, dns.name.from_text('bar.com'), + tsig, now, b'', 0) + + # change the key algorithm + key = dns.tsig.Key('foo.com', 'abcd', 'hmac-sha512') + with self.assertRaises(dns.tsig.BadAlgorithm): + dns.tsig.validate(w, key, dns.name.from_text('foo.com'), + tsig, now, b'', 0) + def test_gssapi_context(self): + def verify_signature(data, mac): + if data == b'throw': + raise Exception + return None + # mock out the gssapi context to return some dummy values gssapi_context_mock = Mock() gssapi_context_mock.get_signature.return_value = b'xxxxxxxxxxx' - gssapi_context_mock.verify_signature.return_value = None + gssapi_context_mock.verify_signature.side_effect = verify_signature # create the key and add it to the keyring key = dns.tsig.Key('gsstsigtest', gssapi_context_mock, 'gss-tsig') @@ -56,6 +104,17 @@ class TSIGTestCase(unittest.TestCase): gsskeyname = dns.name.from_text('gsstsigtest') keyring[gsskeyname] = key + # make sure we can get the keyring (no exception == success) + text = dns.tsigkeyring.to_text(keyring) + self.assertNotEqual(text, '') + + # test exceptional case for _verify_mac_for_context + with self.assertRaises(dns.tsig.BadSignature): + ctx.update(b'throw') + dns.tsig._verify_mac_for_context(ctx, key, 'bogus') + gssapi_context_mock.verify_signature.assert_called() + self.assertEqual(gssapi_context_mock.verify_signature.call_count, 1) + # create example message and go to/from wire to simulate sign/verify m = dns.message.make_query('example', 'a') m.use_tsig(keyring, gsskeyname) @@ -64,8 +123,10 @@ class TSIGTestCase(unittest.TestCase): dns.message.from_wire(w, keyring) # assertions to make sure the "gssapi" functions were called - gssapi_context_mock.get_signature.assert_called_once() - gssapi_context_mock.verify_signature.assert_called_once() + gssapi_context_mock.get_signature.assert_called() + self.assertEqual(gssapi_context_mock.get_signature.call_count, 1) + gssapi_context_mock.verify_signature.assert_called() + self.assertEqual(gssapi_context_mock.verify_signature.call_count, 2) def test_sign_and_validate(self): m = dns.message.make_query('example', 'a') -- cgit v1.2.1