From 84e076521246abd2ecdcb34b458adc402e14e3d6 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sun, 26 Jul 2020 17:48:14 -0700 Subject: Wrap exceptions from rdata from_text() and from_wire(). --- dns/exception.py | 14 ++++++++++++++ dns/rdata.py | 51 ++++++++++++++++++++++++++------------------------- tests/test_rdata.py | 26 ++++++++++++++++++-------- 3 files changed, 58 insertions(+), 33 deletions(-) diff --git a/dns/exception.py b/dns/exception.py index 8f1d488..9486f45 100644 --- a/dns/exception.py +++ b/dns/exception.py @@ -126,3 +126,17 @@ class Timeout(DNSException): """The DNS operation timed out.""" supp_kwargs = {'timeout'} fmt = "The DNS operation timed out after {timeout} seconds" + + +class ExceptionWrapper: + def __init__(self, exception_class): + self.exception_class = exception_class + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None and not isinstance(exc_val, + self.exception_class): + raise self.exception_class() from exc_val + return False diff --git a/dns/rdata.py b/dns/rdata.py index 2d08dcc..0daa08d 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -459,35 +459,35 @@ def from_text(rdclass, rdtype, tok, origin=None, relativize=True, Returns an instance of the chosen Rdata subclass. """ - if isinstance(tok, str): tok = dns.tokenizer.Tokenizer(tok, idna_codec=idna_codec) rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) - rdata = None - if cls != GenericRdata: - # peek at first token - token = tok.get() - tok.unget(token) - if token.is_identifier() and \ - token.value == r'\#': - # - # Known type using the generic syntax. Extract the - # wire form from the generic syntax, and then run - # from_wire on it. - # - grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, - relativize, relativize_to) - rdata = from_wire(rdclass, rdtype, grdata.data, 0, len(grdata.data), - origin) - if rdata is None: - rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize, - relativize_to) - token = tok.get_eol_as_token() - if token.comment is not None: - object.__setattr__(rdata, 'rdcomment', token.comment) - return rdata + with dns.exception.ExceptionWrapper(dns.exception.SyntaxError): + rdata = None + if cls != GenericRdata: + # peek at first token + token = tok.get() + tok.unget(token) + if token.is_identifier() and \ + token.value == r'\#': + # + # Known type using the generic syntax. Extract the + # wire form from the generic syntax, and then run + # from_wire on it. + # + grdata = GenericRdata.from_text(rdclass, rdtype, tok, origin, + relativize, relativize_to) + rdata = from_wire(rdclass, rdtype, grdata.data, 0, + len(grdata.data), origin) + if rdata is None: + rdata = cls.from_text(rdclass, rdtype, tok, origin, relativize, + relativize_to) + token = tok.get_eol_as_token() + if token.comment is not None: + object.__setattr__(rdata, 'rdcomment', token.comment) + return rdata def from_wire_parser(rdclass, rdtype, parser, origin=None): @@ -517,7 +517,8 @@ def from_wire_parser(rdclass, rdtype, parser, origin=None): rdclass = dns.rdataclass.RdataClass.make(rdclass) rdtype = dns.rdatatype.RdataType.make(rdtype) cls = get_rdata_class(rdclass, rdtype) - return cls.from_wire_parser(rdclass, rdtype, parser, origin) + with dns.exception.ExceptionWrapper(dns.exception.FormError): + return cls.from_wire_parser(rdclass, rdtype, parser, origin) def from_wire(rdclass, rdtype, wire, current, rdlen, origin=None): diff --git a/tests/test_rdata.py b/tests/test_rdata.py index 8d9937e..090ca9b 100644 --- a/tests/test_rdata.py +++ b/tests/test_rdata.py @@ -385,10 +385,12 @@ class RdataTestCase(unittest.TestCase): self.equal_wks('10.0.0.1 udp ( domain )', '10.0.0.1 17 ( 53 )') def test_misc_bad_WKS_text(self): - def bad(): + try: dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.WKS, '10.0.0.1 132 ( domain )') - self.assertRaises(NotImplementedError, bad) + self.assertTrue(False) # should not happen + except dns.exception.SyntaxError as e: + self.assertIsInstance(e.__cause__, NotImplementedError) def test_GPOS_float_converters(self): rd = dns.rdata.from_text('in', 'gpos', '49 0 0') @@ -426,7 +428,7 @@ class RdataTestCase(unittest.TestCase): '"0" "-180.1" "0"', ] for gpos in bad_gpos: - with self.assertRaises(dns.exception.FormError): + with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text(dns.rdataclass.IN, dns.rdatatype.GPOS, gpos) def test_bad_GPOS_wire(self): @@ -556,11 +558,11 @@ class RdataTestCase(unittest.TestCase): def test_CERT_algorithm(self): rd = dns.rdata.from_text('in', 'cert', 'SPKI 1 0 Ym9ndXM=') self.assertEqual(rd.algorithm, 0) - with self.assertRaises(ValueError): + with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text('in', 'cert', 'SPKI 1 -1 Ym9ndXM=') - with self.assertRaises(ValueError): + with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text('in', 'cert', 'SPKI 1 256 Ym9ndXM=') - with self.assertRaises(ValueError): + with self.assertRaises(dns.exception.SyntaxError): dns.rdata.from_text('in', 'cert', 'SPKI 1 BOGUS Ym9ndXM=') def test_bad_URI_text(self): @@ -603,16 +605,24 @@ class RdataTestCase(unittest.TestCase): ' Ym9ndXM=') def test_bad_sigtime(self): - with self.assertRaises(dns.rdtypes.ANY.RRSIG.BadSigTime): + try: dns.rdata.from_text('in', 'rrsig', 'NSEC 1 3 3600 ' + '202001010000000 20030101000000 ' + '2143 foo Ym9ndXM=') - with self.assertRaises(dns.rdtypes.ANY.RRSIG.BadSigTime): + self.assertTrue(False) # should not happen + except dns.exception.SyntaxError as e: + self.assertIsInstance(e.__cause__, + dns.rdtypes.ANY.RRSIG.BadSigTime) + try: dns.rdata.from_text('in', 'rrsig', 'NSEC 1 3 3600 ' + '20200101000000 2003010100000 ' + '2143 foo Ym9ndXM=') + self.assertTrue(False) # should not happen + except dns.exception.SyntaxError as e: + self.assertIsInstance(e.__cause__, + dns.rdtypes.ANY.RRSIG.BadSigTime) def test_empty_TXT(self): # hit too long -- cgit v1.2.1