summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBrian Wellington <bwelling@xbill.org>2020-07-01 14:58:14 -0700
committerBrian Wellington <bwelling@xbill.org>2020-07-01 15:01:04 -0700
commit77e747da6bd5b9b7732d1a28ee0572bf9bcd827b (patch)
treedbb702bd8f07a2583c473d7e073b8a4eae9b2feb
parent0703a8e6d132d4c988d5886233e49d3740c0e6a8 (diff)
downloaddnspython-77e747da6bd5b9b7732d1a28ee0572bf9bcd827b.tar.gz
Better deal with backwards compatibility.
If dns.tsigkeyring.from_text() creates dns.tsig.Key objects with the default algorithm, that causes problems for code that specifies a different algorithm. There's no good way to handle this, so change dns.tsigkeyring.from_text() to not create dns.tsig.Key objects unless it knows the algorithm.
-rw-r--r--dns/tsig.py2
-rw-r--r--dns/tsigkeyring.py23
-rw-r--r--tests/test_resolution.py5
-rw-r--r--tests/test_tsigkeyring.py26
4 files changed, 31 insertions, 25 deletions
diff --git a/dns/tsig.py b/dns/tsig.py
index 89183cf..08ab41e 100644
--- a/dns/tsig.py
+++ b/dns/tsig.py
@@ -209,6 +209,8 @@ class Key:
if isinstance(secret, str):
secret = base64.decodebytes(secret.encode())
self.secret = secret
+ if isinstance(algorithm, str):
+ algorithm = dns.name.from_text(algorithm)
self.algorithm = algorithm
def __eq__(self, other):
diff --git a/dns/tsigkeyring.py b/dns/tsigkeyring.py
index b93bdb7..aa3cae9 100644
--- a/dns/tsigkeyring.py
+++ b/dns/tsigkeyring.py
@@ -24,40 +24,37 @@ import dns.name
def from_text(textring):
"""Convert a dictionary containing (textual DNS name, base64 secret)
- or (textual DNS name, (algorithm, base64 secret)) where algorithm
- can be a dns.name.Name or string into a binary keyring which has
- (dns.name.Name, dns.tsig.Key) pairs.
+ pairs into a binary keyring which has (dns.name.Name, bytes) pairs, or
+ a dictionary containing (textual DNS name, (algorithm, base64 secret))
+ pairs into a binary keyring which has (dns.name.Name, dns.tsig.Key) pairs.
@rtype: dict"""
keyring = {}
for (name, value) in textring.items():
name = dns.name.from_text(name)
if isinstance(value, str):
- algorithm = dns.tsig.default_algorithm
- secret = value
+ keyring[name] = dns.tsig.Key(name, value).secret
else:
(algorithm, secret) = value
- if isinstance(algorithm, str):
- algorithm = dns.name.from_text(algorithm)
- keyring[name] = dns.tsig.Key(name, secret, algorithm)
+ keyring[name] = dns.tsig.Key(name, secret, algorithm)
return keyring
def to_text(keyring):
"""Convert a dictionary containing (dns.name.Name, dns.tsig.Key) pairs
into a text keyring which has (textual DNS name, (textual algorithm,
- base64 secret)) pairs.
+ base64 secret)) pairs, or a dictionary containing (dns.name.Name, bytes)
+ pairs into a text keyring which has (textual DNS name, base64 secret) pairs.
@rtype: dict"""
textring = {}
+ b64encode = lambda secret: base64.encodebytes(secret).decode().rstrip()
for (name, key) in keyring.items():
name = name.to_text()
if isinstance(key, bytes):
- algorithm = dns.tsig.default_algorithm
- secret = key
+ textring[name] = b64encode(key)
else:
algorithm = key.algorithm
secret = key.secret
- textring[name] = (algorithm.to_text(),
- base64.encodebytes(secret).decode().rstrip())
+ textring[name] = (key.algorithm.to_text(), b64encode(key.secret))
return textring
diff --git a/tests/test_resolution.py b/tests/test_resolution.py
index aa1cd0c..9145f16 100644
--- a/tests/test_resolution.py
+++ b/tests/test_resolution.py
@@ -197,11 +197,12 @@ class ResolutionTestCase(unittest.TestCase):
self.resolver.keyring = dns.tsigkeyring.from_text({
'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
})
- key = next(iter(self.resolver.keyring.values()))
+ (keyname, secret) = next(iter(self.resolver.keyring.items()))
self.resolver.keyname = dns.name.from_text('keyname.')
(request, answer) = self.resn.next_request()
self.assertFalse(request is None)
- self.assertEqual(request.keyring, key)
+ self.assertEqual(request.keyring.name, keyname)
+ self.assertEqual(request.keyring.secret, secret)
def test_next_request_flags(self):
self.resolver.flags = dns.flags.RD | dns.flags.CD
diff --git a/tests/test_tsigkeyring.py b/tests/test_tsigkeyring.py
index 25c41cc..47f8806 100644
--- a/tests/test_tsigkeyring.py
+++ b/tests/test_tsigkeyring.py
@@ -10,14 +10,14 @@ text_keyring = {
'keyname.' : ('hmac-sha256.', 'NjHwPsMKjdN++dOfE5iAiQ==')
}
-old_text_keyring = {
- 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
-}
-
alt_text_keyring = {
'keyname.' : (dns.tsig.HMAC_SHA256, 'NjHwPsMKjdN++dOfE5iAiQ==')
}
+old_text_keyring = {
+ 'keyname.' : 'NjHwPsMKjdN++dOfE5iAiQ=='
+}
+
key = dns.tsig.Key('keyname.', 'NjHwPsMKjdN++dOfE5iAiQ==')
rich_keyring = { key.name : key }
@@ -31,16 +31,16 @@ class TSIGKeyRingTestCase(unittest.TestCase):
rkeyring = dns.tsigkeyring.from_text(text_keyring)
self.assertEqual(rkeyring, rich_keyring)
- def test_from_old_text(self):
- """old format text keyring -> rich keyring"""
- rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
- self.assertEqual(rkeyring, rich_keyring)
-
def test_from_alt_text(self):
"""alternate format text keyring -> rich keyring"""
rkeyring = dns.tsigkeyring.from_text(alt_text_keyring)
self.assertEqual(rkeyring, rich_keyring)
+ def test_from_old_text(self):
+ """old format text keyring -> rich keyring"""
+ rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
+ self.assertEqual(rkeyring, old_rich_keyring)
+
def test_to_text(self):
"""text keyring -> rich keyring -> text keyring"""
tkeyring = dns.tsigkeyring.to_text(rich_keyring)
@@ -49,10 +49,16 @@ class TSIGKeyRingTestCase(unittest.TestCase):
def test_old_to_text(self):
"""text keyring -> rich keyring -> text keyring"""
tkeyring = dns.tsigkeyring.to_text(old_rich_keyring)
- self.assertEqual(tkeyring, text_keyring)
+ self.assertEqual(tkeyring, old_text_keyring)
def test_from_and_to_text(self):
"""text keyring -> rich keyring -> text keyring"""
rkeyring = dns.tsigkeyring.from_text(text_keyring)
tkeyring = dns.tsigkeyring.to_text(rkeyring)
self.assertEqual(tkeyring, text_keyring)
+
+ def test_old_from_and_to_text(self):
+ """text keyring -> rich keyring -> text keyring"""
+ rkeyring = dns.tsigkeyring.from_text(old_text_keyring)
+ tkeyring = dns.tsigkeyring.to_text(rkeyring)
+ self.assertEqual(tkeyring, old_text_keyring)