From 8d2f1ba94c573ea3791572c5e7565c5c9d82a80d Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Sat, 22 Aug 2020 10:42:12 -0700 Subject: finish type constructor type checking --- dns/rdata.py | 13 ++++++++----- dns/rdtypes/ANY/AMTRELAY.py | 8 ++++---- dns/rdtypes/ANY/CAA.py | 12 +++++------- dns/rdtypes/ANY/CERT.py | 8 ++++---- dns/rdtypes/ANY/GPOS.py | 18 ++++++------------ dns/rdtypes/ANY/HIP.py | 11 +++++------ dns/rdtypes/ANY/LOC.py | 39 +++++++++++++++++++++------------------ dns/rdtypes/ANY/NSEC3PARAM.py | 11 ++++------- dns/rdtypes/ANY/OPT.py | 5 ++++- dns/rdtypes/ANY/RRSIG.py | 18 +++++++++--------- dns/rdtypes/ANY/SSHFP.py | 6 +++--- dns/rdtypes/ANY/TKEY.py | 14 +++++++------- dns/rdtypes/ANY/TLSA.py | 8 ++++---- dns/rdtypes/ANY/TSIG.py | 15 ++++++++------- dns/rdtypes/ANY/URI.py | 11 ++++------- dns/rdtypes/txtbase.py | 7 ++----- 16 files changed, 98 insertions(+), 106 deletions(-) diff --git a/dns/rdata.py b/dns/rdata.py index ee26ceb..042623d 100644 --- a/dns/rdata.py +++ b/dns/rdata.py @@ -335,11 +335,6 @@ class Rdata: object.__setattr__(rd, 'rdcomment', rdcomment) return rd - def as_value(self, value): - # This is the "additional type checking" placeholder that actually - # doesn't do any additional checking. - return value - # Type checking and conversion helpers. These are class methods as # they don't touch object state and may be useful to others. @@ -396,6 +391,14 @@ class Rdata: raise ValueError('not a uint32') return value + @classmethod + def _as_uint48(cls, value): + if not isinstance(value, int): + raise ValueError('not an integer') + if value < 0 or value > 281474976710655: + raise ValueError('not a uint48') + return value + @classmethod def _as_int(cls, value, low=None, high=None): if not isinstance(value, int): diff --git a/dns/rdtypes/ANY/AMTRELAY.py b/dns/rdtypes/ANY/AMTRELAY.py index de6e99e..5a7eb91 100644 --- a/dns/rdtypes/ANY/AMTRELAY.py +++ b/dns/rdtypes/ANY/AMTRELAY.py @@ -38,10 +38,10 @@ class AMTRELAY(dns.rdata.Rdata): relay_type, relay): super().__init__(rdclass, rdtype) Relay(relay_type, relay).check() - self.precedence = self.as_value(precedence) - self.discovery_optional = self.as_value(discovery_optional) - self.relay_type = self.as_value(relay_type) - self.relay = self.as_value(relay) + self.precedence = self._as_uint8(precedence) + self.discovery_optional = self._as_bool(discovery_optional) + self.relay_type = self._as_uint8(relay_type) + self.relay = relay def to_text(self, origin=None, relativize=True, **kw): relay = Relay(self.relay_type, self.relay).to_text(origin, relativize) diff --git a/dns/rdtypes/ANY/CAA.py b/dns/rdtypes/ANY/CAA.py index 7c6dd01..c86b45e 100644 --- a/dns/rdtypes/ANY/CAA.py +++ b/dns/rdtypes/ANY/CAA.py @@ -34,9 +34,11 @@ class CAA(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, flags, tag, value): super().__init__(rdclass, rdtype) - self.flags = self.as_value(flags) - self.tag = self.as_value(tag) - self.value = self.as_value(value) + self.flags = self._as_uint8(flags) + self.tag = self._as_bytes(tag, True, 255) + if not tag.isalnum(): + raise ValueError("tag is not alphanumeric") + self.value = self._as_bytes(value) def to_text(self, origin=None, relativize=True, **kw): return '%u %s "%s"' % (self.flags, @@ -48,10 +50,6 @@ class CAA(dns.rdata.Rdata): relativize_to=None): flags = tok.get_uint8() tag = tok.get_string().encode() - if len(tag) > 255: - raise dns.exception.SyntaxError("tag too long") - if not tag.isalnum(): - raise dns.exception.SyntaxError("tag is not alphanumeric") value = tok.get_string().encode() return cls(rdclass, rdtype, flags, tag, value) diff --git a/dns/rdtypes/ANY/CERT.py b/dns/rdtypes/ANY/CERT.py index c78322a..6d663cc 100644 --- a/dns/rdtypes/ANY/CERT.py +++ b/dns/rdtypes/ANY/CERT.py @@ -67,10 +67,10 @@ class CERT(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, certificate_type, key_tag, algorithm, certificate): super().__init__(rdclass, rdtype) - self.certificate_type = self.as_value(certificate_type) - self.key_tag = self.as_value(key_tag) - self.algorithm = self.as_value(algorithm) - self.certificate = self.as_value(certificate) + self.certificate_type = self._as_uint16(certificate_type) + self.key_tag = self._as_uint16(key_tag) + self.algorithm = self._as_uint8(algorithm) + self.certificate = self._as_bytes(certificate) def to_text(self, origin=None, relativize=True, **kw): certificate_type = _ctype_to_text(self.certificate_type) diff --git a/dns/rdtypes/ANY/GPOS.py b/dns/rdtypes/ANY/GPOS.py index f9e3ed8..29fa8f8 100644 --- a/dns/rdtypes/ANY/GPOS.py +++ b/dns/rdtypes/ANY/GPOS.py @@ -42,12 +42,6 @@ def _validate_float_string(what): raise dns.exception.FormError -def _sanitize(value): - if isinstance(value, str): - return value.encode() - return value - - @dns.immutable.immutable class GPOS(dns.rdata.Rdata): @@ -68,15 +62,15 @@ class GPOS(dns.rdata.Rdata): if isinstance(altitude, float) or \ isinstance(altitude, int): altitude = str(altitude) - latitude = _sanitize(latitude) - longitude = _sanitize(longitude) - altitude = _sanitize(altitude) + latitude = self._as_bytes(latitude, True, 255) + longitude = self._as_bytes(longitude, True, 255) + altitude = self._as_bytes(altitude, True, 255) _validate_float_string(latitude) _validate_float_string(longitude) _validate_float_string(altitude) - self.latitude = self.as_value(latitude) - self.longitude = self.as_value(longitude) - self.altitude = self.as_value(altitude) + self.latitude = latitude + self.longitude = longitude + self.altitude = altitude flat = self.float_latitude if flat < -90.0 or flat > 90.0: raise dns.exception.FormError('bad latitude') diff --git a/dns/rdtypes/ANY/HIP.py b/dns/rdtypes/ANY/HIP.py index 4ed3507..610260d 100644 --- a/dns/rdtypes/ANY/HIP.py +++ b/dns/rdtypes/ANY/HIP.py @@ -36,10 +36,11 @@ class HIP(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, hit, algorithm, key, servers): super().__init__(rdclass, rdtype) - self.hit = self.as_value(hit) - self.algorithm = self.as_value(algorithm) - self.key = self.as_value(key) - self.servers = self.as_value(dns.rdata._constify(servers)) + self.hit = self._as_bytes(hit, True, 255) + self.algorithm = self._as_uint8(algorithm) + self.key = self._as_bytes(key, True) + self.servers = dns.rdata._constify([dns.rdata.Rdata._as_name(s) + for s in servers]) def to_text(self, origin=None, relativize=True, **kw): hit = binascii.hexlify(self.hit).decode() @@ -57,8 +58,6 @@ class HIP(dns.rdata.Rdata): relativize_to=None): algorithm = tok.get_uint8() hit = binascii.unhexlify(tok.get_string().encode()) - if len(hit) > 255: - raise dns.exception.SyntaxError("HIT too long") key = base64.b64decode(tok.get_string().encode()) servers = [] for token in tok.get_remaining(): diff --git a/dns/rdtypes/ANY/LOC.py b/dns/rdtypes/ANY/LOC.py index d2a7783..60b10b9 100644 --- a/dns/rdtypes/ANY/LOC.py +++ b/dns/rdtypes/ANY/LOC.py @@ -91,6 +91,19 @@ def _decode_size(what, desc): return base * pow(10, exponent) +def _check_coordinate_list(value, low, high): + if value[0] < low or value[0] > high: + raise ValueError(f'not in range [{low}, {high}]') + if value[1] < 0 or value[1] > 59: + raise ValueError('bad minutes value') + if value[2] < 0 or value[2] > 59: + raise ValueError('bad seconds value') + if value[3] < 0 or value[3] > 999: + raise ValueError('bad milliseconds value') + if value[4] != 1 and value[4] != -1: + raise ValueError('bad hemisphere value') + + @dns.immutable.immutable class LOC(dns.rdata.Rdata): @@ -117,16 +130,18 @@ class LOC(dns.rdata.Rdata): latitude = float(latitude) if isinstance(latitude, float): latitude = _float_to_tuple(latitude) - self.latitude = self.as_value(dns.rdata._constify(latitude)) + _check_coordinate_list(latitude, -90, 90) + self.latitude = dns.rdata._constify(latitude) if isinstance(longitude, int): longitude = float(longitude) if isinstance(longitude, float): longitude = _float_to_tuple(longitude) - self.longitude = self.as_value(dns.rdata._constify(longitude)) - self.altitude = self.as_value(float(altitude)) - self.size = self.as_value(float(size)) - self.horizontal_precision = self.as_value(float(hprec)) - self.vertical_precision = self.as_value(float(vprec)) + _check_coordinate_list(longitude, -180, 180) + self.longitude = dns.rdata._constify(longitude) + self.altitude = float(altitude) + self.size = float(size) + self.horizontal_precision = float(hprec) + self.vertical_precision = float(vprec) def to_text(self, origin=None, relativize=True, **kw): if self.latitude[4] > 0: @@ -165,13 +180,9 @@ class LOC(dns.rdata.Rdata): vprec = _default_vprec latitude[0] = tok.get_int() - if latitude[0] > 90: - raise dns.exception.SyntaxError('latitude >= 90') t = tok.get_string() if t.isdigit(): latitude[1] = int(t) - if latitude[1] >= 60: - raise dns.exception.SyntaxError('latitude minutes >= 60') t = tok.get_string() if '.' in t: (seconds, milliseconds) = t.split('.') @@ -179,8 +190,6 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError( 'bad latitude seconds value') latitude[2] = int(seconds) - if latitude[2] >= 60: - raise dns.exception.SyntaxError('latitude seconds >= 60') l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): raise dns.exception.SyntaxError( @@ -202,13 +211,9 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError('bad latitude hemisphere value') longitude[0] = tok.get_int() - if longitude[0] > 180: - raise dns.exception.SyntaxError('longitude > 180') t = tok.get_string() if t.isdigit(): longitude[1] = int(t) - if longitude[1] >= 60: - raise dns.exception.SyntaxError('longitude minutes >= 60') t = tok.get_string() if '.' in t: (seconds, milliseconds) = t.split('.') @@ -216,8 +221,6 @@ class LOC(dns.rdata.Rdata): raise dns.exception.SyntaxError( 'bad longitude seconds value') longitude[2] = int(seconds) - if longitude[2] >= 60: - raise dns.exception.SyntaxError('longitude seconds >= 60') l = len(milliseconds) if l == 0 or l > 3 or not milliseconds.isdigit(): raise dns.exception.SyntaxError( diff --git a/dns/rdtypes/ANY/NSEC3PARAM.py b/dns/rdtypes/ANY/NSEC3PARAM.py index d31116f..299bf6e 100644 --- a/dns/rdtypes/ANY/NSEC3PARAM.py +++ b/dns/rdtypes/ANY/NSEC3PARAM.py @@ -32,13 +32,10 @@ class NSEC3PARAM(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, algorithm, flags, iterations, salt): super().__init__(rdclass, rdtype) - self.algorithm = self.as_value(algorithm) - self.flags = self.as_value(flags) - self.iterations = self.as_value(iterations) - if isinstance(salt, str): - self.salt = self.as_value(salt.encode()) - else: - self.salt = self.as_value(salt) + self.algorithm = self._as_uint8(algorithm) + self.flags = self._as_uint8(flags) + self.iterations = self._as_uint16(iterations) + self.salt = self._as_bytes(salt, True, 255) def to_text(self, origin=None, relativize=True, **kw): if self.salt == b'': diff --git a/dns/rdtypes/ANY/OPT.py b/dns/rdtypes/ANY/OPT.py index d962689..1968ce2 100644 --- a/dns/rdtypes/ANY/OPT.py +++ b/dns/rdtypes/ANY/OPT.py @@ -45,7 +45,10 @@ class OPT(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) - self.options = self.as_value(dns.rdata._constify(options)) + for option in options: + if not isinstance(option, dns.edns.Option): + raise ValueError('option is not a dns.edns.option') + self.options = dns.rdata._constify(options) def _to_wire(self, file, compress=None, origin=None, canonicalize=False): for opt in self.options: diff --git a/dns/rdtypes/ANY/RRSIG.py b/dns/rdtypes/ANY/RRSIG.py index 53cc55a..93c7f10 100644 --- a/dns/rdtypes/ANY/RRSIG.py +++ b/dns/rdtypes/ANY/RRSIG.py @@ -64,15 +64,15 @@ class RRSIG(dns.rdata.Rdata): original_ttl, expiration, inception, key_tag, signer, signature): super().__init__(rdclass, rdtype) - self.type_covered = self.as_value(type_covered) - self.algorithm = self.as_value(algorithm) - self.labels = self.as_value(labels) - self.original_ttl = self.as_value(original_ttl) - self.expiration = self.as_value(expiration) - self.inception = self.as_value(inception) - self.key_tag = self.as_value(key_tag) - self.signer = self.as_value(signer) - self.signature = self.as_value(signature) + self.type_covered = self._as_rdatatype(type_covered) + self.algorithm = dns.dnssec.Algorithm.make(algorithm) + self.labels = self._as_uint8(labels) + self.original_ttl = self._as_ttl(original_ttl) + self.expiration = self._as_uint32(expiration) + self.inception = self._as_uint32(inception) + self.key_tag = self._as_uint16(key_tag) + self.signer = self._as_name(signer) + self.signature = self._as_bytes(signature) def covers(self): return self.type_covered diff --git a/dns/rdtypes/ANY/SSHFP.py b/dns/rdtypes/ANY/SSHFP.py index dd222b4..4fd917c 100644 --- a/dns/rdtypes/ANY/SSHFP.py +++ b/dns/rdtypes/ANY/SSHFP.py @@ -35,9 +35,9 @@ class SSHFP(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, algorithm, fp_type, fingerprint): super().__init__(rdclass, rdtype) - self.algorithm = self.as_value(algorithm) - self.fp_type = self.as_value(fp_type) - self.fingerprint = self.as_value(fingerprint) + self.algorithm = self._as_uint8(algorithm) + self.fp_type = self._as_uint8(fp_type) + self.fingerprint = self._as_bytes(fingerprint, True) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %s' % (self.algorithm, diff --git a/dns/rdtypes/ANY/TKEY.py b/dns/rdtypes/ANY/TKEY.py index 871578a..f8c4737 100644 --- a/dns/rdtypes/ANY/TKEY.py +++ b/dns/rdtypes/ANY/TKEY.py @@ -35,13 +35,13 @@ class TKEY(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, algorithm, inception, expiration, mode, error, key, other=b''): super().__init__(rdclass, rdtype) - self.algorithm = self.as_value(algorithm) - self.inception = self.as_value(inception) - self.expiration = self.as_value(expiration) - self.mode = self.as_value(mode) - self.error = self.as_value(error) - self.key = self.as_value(dns.rdata._constify(key)) - self.other = self.as_value(dns.rdata._constify(other)) + self.algorithm = self._as_name(algorithm) + self.inception = self._as_uint32(inception) + self.expiration = self._as_uint32(expiration) + self.mode = self._as_uint16(mode) + self.error = self._as_uint16(error) + self.key = self._as_bytes(key) + self.other = self._as_bytes(other) def to_text(self, origin=None, relativize=True, **kw): _algorithm = self.algorithm.choose_relativity(origin, relativize) diff --git a/dns/rdtypes/ANY/TLSA.py b/dns/rdtypes/ANY/TLSA.py index 5e7dc19..ad8dc8d 100644 --- a/dns/rdtypes/ANY/TLSA.py +++ b/dns/rdtypes/ANY/TLSA.py @@ -35,10 +35,10 @@ class TLSA(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, usage, selector, mtype, cert): super().__init__(rdclass, rdtype) - self.usage = self.as_value(usage) - self.selector = self.as_value(selector) - self.mtype = self.as_value(mtype) - self.cert = self.as_value(cert) + self.usage = self._as_uint8(usage) + self.selector = self._as_uint8(selector) + self.mtype = self._as_uint8(mtype) + self.cert = self._as_bytes(cert) def to_text(self, origin=None, relativize=True, **kw): return '%d %d %d %s' % (self.usage, diff --git a/dns/rdtypes/ANY/TSIG.py b/dns/rdtypes/ANY/TSIG.py index e179d62..e49bf73 100644 --- a/dns/rdtypes/ANY/TSIG.py +++ b/dns/rdtypes/ANY/TSIG.py @@ -20,6 +20,7 @@ import struct import dns.exception import dns.immutable +import dns.rcode import dns.rdata @@ -55,13 +56,13 @@ class TSIG(dns.rdata.Rdata): """ super().__init__(rdclass, rdtype) - self.algorithm = self.as_value(algorithm) - self.time_signed = self.as_value(time_signed) - self.fudge = self.as_value(fudge) - self.mac = self.as_value(dns.rdata._constify(mac)) - self.original_id = self.as_value(original_id) - self.error = self.as_value(error) - self.other = self.as_value(dns.rdata._constify(other)) + self.algorithm = self._as_name(algorithm) + self.time_signed = self._as_uint48(time_signed) + self.fudge = self._as_uint16(fudge) + self.mac = dns.rdata._constify(self._as_bytes(mac)) + self.original_id = self._as_uint16(original_id) + self.error = dns.rcode.Rcode.make(error) + self.other = self._as_bytes(other) def to_text(self, origin=None, relativize=True, **kw): algorithm = self.algorithm.choose_relativity(origin, relativize) diff --git a/dns/rdtypes/ANY/URI.py b/dns/rdtypes/ANY/URI.py index 0892bd8..60a43c8 100644 --- a/dns/rdtypes/ANY/URI.py +++ b/dns/rdtypes/ANY/URI.py @@ -35,14 +35,11 @@ class URI(dns.rdata.Rdata): def __init__(self, rdclass, rdtype, priority, weight, target): super().__init__(rdclass, rdtype) - self.priority = self.as_value(priority) - self.weight = self.as_value(weight) - if len(target) < 1: + self.priority = self._as_uint16(priority) + self.weight = self._as_uint16(weight) + self.target = self._as_bytes(target, True) + if len(self.target) == 0: raise dns.exception.SyntaxError("URI target cannot be empty") - if isinstance(target, str): - self.target = self.as_value(target.encode()) - else: - self.target = self.as_value(target) def to_text(self, origin=None, relativize=True, **kw): return '%d %d "%s"' % (self.priority, self.weight, diff --git a/dns/rdtypes/txtbase.py b/dns/rdtypes/txtbase.py index 6539c5a..a170ced 100644 --- a/dns/rdtypes/txtbase.py +++ b/dns/rdtypes/txtbase.py @@ -46,12 +46,9 @@ class TXTBase(dns.rdata.Rdata): strings = (strings,) encoded_strings = [] for string in strings: - if isinstance(string, str): - string = string.encode() - else: - string = dns.rdata._constify(string) + string = self._as_bytes(string, True, 255) encoded_strings.append(string) - self.strings = self.as_value(tuple(encoded_strings)) + self.strings = dns.rdata._constify(encoded_strings) def to_text(self, origin=None, relativize=True, **kw): txt = '' -- cgit v1.2.1