summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-07-26 17:48:14 -0700
committerBob Halley <halley@dnspython.org>2020-07-26 17:48:14 -0700
commit84e076521246abd2ecdcb34b458adc402e14e3d6 (patch)
tree255f5e876aa6161b40780b0681b0f1ac9c953a47
parentcb49bfc57cbf68f0e31f0c2f541eb64a06463eca (diff)
downloaddnspython-84e076521246abd2ecdcb34b458adc402e14e3d6.tar.gz
Wrap exceptions from rdata from_text() and from_wire().wrap
-rw-r--r--dns/exception.py14
-rw-r--r--dns/rdata.py51
-rw-r--r--tests/test_rdata.py26
3 files changed, 58 insertions, 33 deletions
diff --git a/dns/exception.py b/dns/exception.py
index 8f1d488..9486f45 100644
--- a/dns/exception.py
+++ b/dns/exception.py
@@ -126,3 +126,17 @@ class Timeout(DNSException):
"""The DNS operation timed out."""
supp_kwargs = {'timeout'}
fmt = "The DNS operation timed out after {timeout} seconds"
+
+
+class ExceptionWrapper:
+ def __init__(self, exception_class):
+ self.exception_class = exception_class
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is not None and not isinstance(exc_val,
+ self.exception_class):
+ raise self.exception_class() from exc_val
+ return False
diff --git a/dns/rdata.py b/dns/rdata.py
index 2d08dcc..0daa08d 100644
--- a/dns/rdata.py
+++ b/dns/rdata.py
@@ -459,35 +459,35 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True,
Returns an instance of the chosen Rdata subclass.
"""
-
if isinstance(tok, str):
tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec)
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
- rdata = None
- if cls != GenericRdata:
- # peek at first token
- token = tok.get()
- tok.unget(token)
- if token.is_identifier() and \
- token.value == r'\#':
- #
- # Known type using the generic syntax. Extract the
- # wire form from the generic syntax, and then run
- # from_wire on it.
- #
- grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
- relativize, relativize_to)
- rdata = from_wire(rdclass, rdtype, grdata.data, 0, len(grdata.data),
- origin)
- if rdata is None:
- rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
- relativize_to)
- token = tok.get_eol_as_token()
- if token.comment is not None:
- object.__setattr__(rdata, 'rdcomment', token.comment)
- return rdata
+ with dns.exception.ExceptionWrapper(dns.exception.SyntaxError):
+ rdata = None
+ if cls != GenericRdata:
+ # peek at first token
+ token = tok.get()
+ tok.unget(token)
+ if token.is_identifier() and \
+ token.value == r'\#':
+ #
+ # Known type using the generic syntax. Extract the
+ # wire form from the generic syntax, and then run
+ # from_wire on it.
+ #
+ grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin,
+ relativize, relativize_to)
+ rdata = from_wire(rdclass, rdtype, grdata.data, 0,
+ len(grdata.data), origin)
+ if rdata is None:
+ rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize,
+ relativize_to)
+ token = tok.get_eol_as_token()
+ if token.comment is not None:
+ object.__setattr__(rdata, 'rdcomment', token.comment)
+ return rdata
def from_wire_parser(rdclass, rdtype, parser, origin=None):
@@ -517,7 +517,8 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None):
rdclass = dns.rdataclass.RdataClass.make(rdclass)
rdtype = dns.rdatatype.RdataType.make(rdtype)
cls = get_rdata_class(rdclass, rdtype)
- return cls.from_wire_parser(rdclass, rdtype, parser, origin)
+ with dns.exception.ExceptionWrapper(dns.exception.FormError):
+ return cls.from_wire_parser(rdclass, rdtype, parser, origin)
def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None):
diff --git a/tests/test_rdata.py b/tests/test_rdata.py
index 8d9937e..090ca9b 100644
--- a/tests/test_rdata.py
+++ b/tests/test_rdata.py
@@ -385,10 +385,12 @@ class RdataTestCase(unittest.TestCase):
self.equal_wks('10.0.0.1 udp ( domain )', '10.0.0.1 17 ( 53 )')
def test_misc_bad_WKS_text(self):
- def bad():
+ try:
dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.WKS,
'10.0.0.1 132 ( domain )')
- self.assertRaises(NotImplementedError, bad)
+ self.assertTrue(False) # should not happen
+ except dns.exception.SyntaxError as e:
+ self.assertIsInstance(e.__cause__, NotImplementedError)
def test_GPOS_float_converters(self):
rd = dns.rdata.from_text('in', 'gpos', '49 0 0')
@@ -426,7 +428,7 @@ class RdataTestCase(unittest.TestCase):
'"0" "-180.1" "0"',
]
for gpos in bad_gpos:
- with self.assertRaises(dns.exception.FormError):
+ with self.assertRaises(dns.exception.SyntaxError):
dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.GPOS, gpos)
def test_bad_GPOS_wire(self):
@@ -556,11 +558,11 @@ class RdataTestCase(unittest.TestCase):
def test_CERT_algorithm(self):
rd = dns.rdata.from_text('in', 'cert', 'SPKI 1 0 Ym9ndXM=')
self.assertEqual(rd.algorithm, 0)
- with self.assertRaises(ValueError):
+ with self.assertRaises(dns.exception.SyntaxError):
dns.rdata.from_text('in', 'cert', 'SPKI 1 -1 Ym9ndXM=')
- with self.assertRaises(ValueError):
+ with self.assertRaises(dns.exception.SyntaxError):
dns.rdata.from_text('in', 'cert', 'SPKI 1 256 Ym9ndXM=')
- with self.assertRaises(ValueError):
+ with self.assertRaises(dns.exception.SyntaxError):
dns.rdata.from_text('in', 'cert', 'SPKI 1 BOGUS Ym9ndXM=')
def test_bad_URI_text(self):
@@ -603,16 +605,24 @@ class RdataTestCase(unittest.TestCase):
' Ym9ndXM=')
def test_bad_sigtime(self):
- with self.assertRaises(dns.rdtypes.ANY.RRSIG.BadSigTime):
+ try:
dns.rdata.from_text('in', 'rrsig',
'NSEC 1 3 3600 ' +
'202001010000000 20030101000000 ' +
'2143 foo Ym9ndXM=')
- with self.assertRaises(dns.rdtypes.ANY.RRSIG.BadSigTime):
+ self.assertTrue(False) # should not happen
+ except dns.exception.SyntaxError as e:
+ self.assertIsInstance(e.__cause__,
+ dns.rdtypes.ANY.RRSIG.BadSigTime)
+ try:
dns.rdata.from_text('in', 'rrsig',
'NSEC 1 3 3600 ' +
'20200101000000 2003010100000 ' +
'2143 foo Ym9ndXM=')
+ self.assertTrue(False) # should not happen
+ except dns.exception.SyntaxError as e:
+ self.assertIsInstance(e.__cause__,
+ dns.rdtypes.ANY.RRSIG.BadSigTime)
def test_empty_TXT(self):
# hit too long