From 54deb97c2a5331fe99a12d720f24fb481ec31576 Mon Sep 17 00:00:00 2001 From: Bob Halley Date: Thu, 30 Jul 2020 09:21:03 -0700 Subject: txn checkpoint --- dns/__init__.py | 4 + dns/_immutable_attr.py | 64 ++++++ dns/_immutable_ctx.py | 58 +++++ dns/exception.py | 2 +- dns/immutable.py | 17 +- dns/masterfile.py | 404 +++++++++++++++++++++++++++++++++++ dns/node.py | 30 +++ dns/rdataset.py | 47 +++++ dns/transaction.py | 383 +++++++++++++++++++++++++++++++++ dns/versioned.py | 392 ++++++++++++++++++++++++++++++++++ dns/zone.py | 525 ++++++++++------------------------------------ tests/test_immutable.py | 134 +++++++++++- tests/test_rdataset.py | 29 +++ tests/test_transaction.py | 451 +++++++++++++++++++++++++++++++++++++++ 14 files changed, 2115 insertions(+), 425 deletions(-) create mode 100644 dns/_immutable_attr.py create mode 100644 dns/_immutable_ctx.py create mode 100644 dns/masterfile.py create mode 100644 dns/transaction.py create mode 100644 dns/versioned.py create mode 100644 tests/test_transaction.py diff --git a/dns/__init__.py b/dns/__init__.py index b944701..eafdcc4 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -27,9 +27,11 @@ __all__ = [ 'entropy', 'exception', 'flags', + 'immutable', 'inet', 'ipv4', 'ipv6', + 'masterfile', 'message', 'name', 'namedict', @@ -48,12 +50,14 @@ __all__ = [ 'serial', 'set', 'tokenizer', + 'transaction', 'tsig', 'tsigkeyring', 'ttl', 'rdtypes', 'update', 'version', + 'versioned', 'wire', 'zone', ] diff --git a/dns/_immutable_attr.py b/dns/_immutable_attr.py new file mode 100644 index 0000000..4221967 --- /dev/null +++ b/dns/_immutable_attr.py @@ -0,0 +1,64 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This implementation of the immutable decorator is for python 3.6, +# which doesn't have Context Variables. This implementation is somewhat +# costly for classes with slots, as it adds a __dict__ to them. + +class _Immutable: + """Immutable mixin class""" + + # Note we MUST NOT have __slots__ as that causes + # + # TypeError: multiple bases have instance lay-out conflict + # + # when we get mixed in with another class with slots. When we + # get mixed into something with slots, it effectively adds __dict__ to + # the slots of the other class, which allows attribute setting to work, + # albeit at the cost of the dictionary. + + def __setattr__(self, name, value): + if not hasattr(self, '_immutable_init') or \ + self._immutable_init is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if not hasattr(self, '_immutable_init') or \ + self._immutable_init is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__delattr__(name) + + +def _immutable_init(f): + def nf(*args, **kwargs): + try: + # Are we already initializing an immutable class? + previous = args[0]._immutable_init + except AttributeError: + # We are the first! + previous = None + object.__setattr__(args[0], '_immutable_init', args[0]) + # call the actual __init__ + f(*args, **kwargs) + if not previous: + # If we started the initialzation, establish immutability + # by removing the attribute that allows mutation + object.__delattr__(args[0], '_immutable_init') + return nf + + +def immutable(cls): + if _Immutable in cls.__mro__: + # Some ancestor already has the mixin, so just make sure we keep + # following the __init__ protocol. + cls.__init__ = _immutable_init(cls.__init__) + ncls = cls + else: + # Mixin the Immutable class and follow the __init__ protocol. + class ncls(_Immutable, cls): + @_immutable_init + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + return ncls diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py new file mode 100644 index 0000000..017310d --- /dev/null +++ b/dns/_immutable_ctx.py @@ -0,0 +1,58 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# This implementation of the immutable decorator requires python >= +# 3.7, and is significantly more storage efficient when making classes +# with slots immutable. It's also faster. + +import contextvars + +_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False) + + +class _Immutable: + """Immutable mixin class""" + + # We set slots to the empty list to say "we don't have any attributes". + # We do this so that if we're mixed in with a class with __slots__, we + # don't cause a __dict__ to be added which would waste space. + + __slots__ = () + + def __setattr__(self, name, value): + if _in__init__.get() is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if _in__init__.get() is not self: + raise TypeError("object doesn't support attribute assignment") + else: + super().__delattr__(name) + + +def _immutable_init(f): + def nf(*args, **kwargs): + previous = _in__init__.set(args[0]) + # call the actual __init__ + f(*args, **kwargs) + _in__init__.reset(previous) + return nf + + +def immutable(cls): + if _Immutable in cls.__mro__: + # Some ancestor already has the mixin, so just make sure we keep + # following the __init__ protocol. + cls.__init__ = _immutable_init(cls.__init__) + ncls = cls + else: + # Mixin the Immutable class and follow the __init__ protocol. + class ncls(_Immutable, cls): + # We have to do the __slots__ declaration here too! + __slots__ = () + + @_immutable_init + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + return ncls diff --git a/dns/exception.py b/dns/exception.py index 9486f45..9392373 100644 --- a/dns/exception.py +++ b/dns/exception.py @@ -138,5 +138,5 @@ class ExceptionWrapper: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None and not isinstance(exc_val, self.exception_class): - raise self.exception_class() from exc_val + raise self.exception_class(str(exc_val)) from exc_val return False diff --git a/dns/immutable.py b/dns/immutable.py index dc48fe8..7cc39dd 100644 --- a/dns/immutable.py +++ b/dns/immutable.py @@ -3,13 +3,19 @@ import collections.abc import sys +# pylint: disable=unused-import if sys.version_info >= (3, 7): odict = dict + from dns._immutable_ctx import immutable else: - from collections import OrderedDict as odict # pragma: no cover + # pragma: no cover + from collections import OrderedDict as odict + from dns._immutable_attr import immutable # noqa +# pylint: enable=unused-import -class ImmutableDict(collections.abc.Mapping): +@immutable +class Dict(collections.abc.Mapping): def __init__(self, dictionary, no_copy=False): """Make an immutable dictionary from the specified dictionary. @@ -28,9 +34,10 @@ class ImmutableDict(collections.abc.Mapping): def __hash__(self): if self._hash is None: - self._hash = 0 + h = 0 for key in sorted(self._odict.keys()): - self._hash ^= hash(key) + h ^= hash(key) + object.__setattr__(self, '_hash', h) return self._hash def __len__(self): @@ -58,5 +65,5 @@ def constify(o): cdict = odict() for k, v in o.items(): cdict[k] = constify(v) - return ImmutableDict(cdict, True) + return Dict(cdict, True) return o diff --git a/dns/masterfile.py b/dns/masterfile.py new file mode 100644 index 0000000..30553b5 --- /dev/null +++ b/dns/masterfile.py @@ -0,0 +1,404 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +"""DNS Zones.""" + +import re +import sys + +import dns.exception +import dns.name +import dns.node +import dns.rdataclass +import dns.rdatatype +import dns.rdata +import dns.rdtypes.ANY.SOA +import dns.rrset +import dns.tokenizer +import dns.transaction +import dns.ttl +import dns.grange + + +class UnknownOrigin(dns.exception.DNSException): + """Unknown origin""" + + +class Reader: + + """Read a DNS master file into a transaction.""" + + def __init__(self, tok, origin, rdclass, relativize, txn, + allow_include=False): + if isinstance(origin, str): + origin = dns.name.from_text(origin) + self.tok = tok + self.current_origin = origin + self.relativize = relativize + self.last_ttl = 0 + self.last_ttl_known = False + self.default_ttl = 0 + self.default_ttl_known = False + self.last_name = self.current_origin + self.zone_origin = origin + self.zone_rdclass = rdclass + self.txn = txn + self.saved_state = [] + self.current_file = None + self.allow_include = allow_include + + def _eat_line(self): + while 1: + token = self.tok.get() + if token.is_eol_or_eof(): + break + + def _rr_line(self): + """Process one line from a DNS master file.""" + # Name + if self.current_origin is None: + raise UnknownOrigin + token = self.tok.get(want_leading=True) + if not token.is_whitespace(): + self.last_name = self.tok.as_name(token, self.current_origin) + else: + token = self.tok.get() + if token.is_eol_or_eof(): + # treat leading WS followed by EOL/EOF as if they were EOL/EOF. + return + self.tok.unget(token) + name = self.last_name + if not name.is_subdomain(self.zone_origin): + self._eat_line() + return + if self.relativize: + name = name.relativize(self.zone_origin) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + + # TTL + ttl = None + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.ttl.BadTTL: + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + + # Class + try: + rdclass = dns.rdataclass.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise + except Exception: + rdclass = self.zone_rdclass + if rdclass != self.zone_rdclass: + raise dns.exception.SyntaxError("RR class is not zone's class") + # Type + try: + rdtype = dns.rdatatype.from_text(token.value) + except Exception: + raise dns.exception.SyntaxError( + "unknown rdatatype '%s'" % token.value) + try: + rd = dns.rdata.from_text(rdclass, rdtype, self.tok, + self.current_origin, self.relativize, + self.zone_origin) + except dns.exception.SyntaxError: + # Catch and reraise. + raise + except Exception: + # All exceptions that occur in the processing of rdata + # are treated as syntax errors. This is not strictly + # correct, but it is correct almost all of the time. + # We convert them to syntax errors so that we can emit + # helpful filename:line info. + (ty, va) = sys.exc_info()[:2] + raise dns.exception.SyntaxError( + "caught exception {}: {}".format(str(ty), str(va))) + + if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: + # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default + # TTL from the SOA minttl if no $TTL statement is present before the + # SOA is parsed. + self.default_ttl = rd.minimum + self.default_ttl_known = True + if ttl is None: + # if we didn't have a TTL on the SOA, set it! + ttl = rd.minimum + + # TTL check. We had to wait until now to do this as the SOA RR's + # own TTL can be inferred from its minimum. + if ttl is None: + raise dns.exception.SyntaxError("Missing default TTL value") + + self.txn.add(name, ttl, rd) + + def _parse_modify(self, side): + # Here we catch everything in '{' '}' in a group so we can replace it + # with ''. + is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") + is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$") + is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$") + # Sometimes there are modifiers in the hostname. These come after + # the dollar sign. They are in the form: ${offset[,width[,base]]}. + # Make names + g1 = is_generate1.match(side) + if g1: + mod, sign, offset, width, base = g1.groups() + if sign == '': + sign = '+' + g2 = is_generate2.match(side) + if g2: + mod, sign, offset = g2.groups() + if sign == '': + sign = '+' + width = 0 + base = 'd' + g3 = is_generate3.match(side) + if g3: + mod, sign, offset, width = g3.groups() + if sign == '': + sign = '+' + base = 'd' + + if not (g1 or g2 or g3): + mod = '' + sign = '+' + offset = 0 + width = 0 + base = 'd' + + if base != 'd': + raise NotImplementedError() + + return mod, sign, offset, width, base + + def _generate_line(self): + # range lhs [ttl] [class] type rhs [ comment ] + """Process one line containing the GENERATE statement from a DNS + master file.""" + if self.current_origin is None: + raise UnknownOrigin + + token = self.tok.get() + # Range (required) + try: + start, stop, step = dns.grange.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError + + # lhs (required) + try: + lhs = token.value + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError + + # TTL + try: + ttl = dns.ttl.from_text(token.value) + self.last_ttl = ttl + self.last_ttl_known = True + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.ttl.BadTTL: + if not (self.last_ttl_known or self.default_ttl_known): + raise dns.exception.SyntaxError("Missing default TTL value") + if self.default_ttl_known: + ttl = self.default_ttl + elif self.last_ttl_known: + ttl = self.last_ttl + # Class + try: + rdclass = dns.rdataclass.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except dns.exception.SyntaxError: + raise dns.exception.SyntaxError + except Exception: + rdclass = self.zone_rdclass + if rdclass != self.zone_rdclass: + raise dns.exception.SyntaxError("RR class is not zone's class") + # Type + try: + rdtype = dns.rdatatype.from_text(token.value) + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError + except Exception: + raise dns.exception.SyntaxError("unknown rdatatype '%s'" % + token.value) + + # rhs (required) + rhs = token.value + + # The code currently only supports base 'd', so the last value + # in the tuple _parse_modify returns is ignored + lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs) + rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs) + for i in range(start, stop + 1, step): + # +1 because bind is inclusive and python is exclusive + + if lsign == '+': + lindex = i + int(loffset) + elif lsign == '-': + lindex = i - int(loffset) + + if rsign == '-': + rindex = i - int(roffset) + elif rsign == '+': + rindex = i + int(roffset) + + lzfindex = str(lindex).zfill(int(lwidth)) + rzfindex = str(rindex).zfill(int(rwidth)) + + name = lhs.replace('$%s' % (lmod), lzfindex) + rdata = rhs.replace('$%s' % (rmod), rzfindex) + + self.last_name = dns.name.from_text(name, self.current_origin, + self.tok.idna_codec) + name = self.last_name + if not name.is_subdomain(self.zone_origin): + self._eat_line() + return + if self.relativize: + name = name.relativize(self.zone_origin) + + try: + rd = dns.rdata.from_text(rdclass, rdtype, rdata, + self.current_origin, self.relativize, + self.zone_origin) + except dns.exception.SyntaxError: + # Catch and reraise. + raise + except Exception: + # All exceptions that occur in the processing of rdata + # are treated as syntax errors. This is not strictly + # correct, but it is correct almost all of the time. + # We convert them to syntax errors so that we can emit + # helpful filename:line info. + (ty, va) = sys.exc_info()[:2] + raise dns.exception.SyntaxError("caught exception %s: %s" % + (str(ty), str(va))) + + self.txn.add(name, ttl, rd) + + def read(self): + """Read a DNS master file and build a zone object. + + @raises dns.zone.NoSOA: No SOA RR was found at the zone origin + @raises dns.zone.NoNS: No NS RRset was found at the zone origin + """ + + try: + while 1: + token = self.tok.get(True, True) + if token.is_eof(): + if self.current_file is not None: + self.current_file.close() + if len(self.saved_state) > 0: + (self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known) = self.saved_state.pop(-1) + continue + break + elif token.is_eol(): + continue + elif token.is_comment(): + self.tok.get_eol() + continue + elif token.value[0] == '$': + c = token.value.upper() + if c == '$TTL': + token = self.tok.get() + if not token.is_identifier(): + raise dns.exception.SyntaxError("bad $TTL") + self.default_ttl = dns.ttl.from_text(token.value) + self.default_ttl_known = True + self.tok.get_eol() + elif c == '$ORIGIN': + self.current_origin = self.tok.get_name() + self.tok.get_eol() + if self.zone_origin is None: + self.zone_origin = self.current_origin + self.txn._set_origin(self.current_origin) + elif c == '$INCLUDE' and self.allow_include: + token = self.tok.get() + filename = token.value + token = self.tok.get() + if token.is_identifier(): + new_origin =\ + dns.name.from_text(token.value, + self.current_origin, + self.tok.idna_codec) + self.tok.get_eol() + elif not token.is_eol_or_eof(): + raise dns.exception.SyntaxError( + "bad origin in $INCLUDE") + else: + new_origin = self.current_origin + self.saved_state.append((self.tok, + self.current_origin, + self.last_name, + self.current_file, + self.last_ttl, + self.last_ttl_known, + self.default_ttl, + self.default_ttl_known)) + self.current_file = open(filename, 'r') + self.tok = dns.tokenizer.Tokenizer(self.current_file, + filename) + self.current_origin = new_origin + elif c == '$GENERATE': + self._generate_line() + else: + raise dns.exception.SyntaxError( + "Unknown master file directive '" + c + "'") + continue + self.tok.unget(token) + self._rr_line() + except dns.exception.SyntaxError as detail: + (filename, line_number) = self.tok.where() + if detail is None: + detail = "syntax error" + ex = dns.exception.SyntaxError( + "%s:%d: %s" % (filename, line_number, detail)) + tb = sys.exc_info()[2] + raise ex.with_traceback(tb) from None diff --git a/dns/node.py b/dns/node.py index b7e21b5..8e1451f 100644 --- a/dns/node.py +++ b/dns/node.py @@ -183,3 +183,33 @@ class Node: self.delete_rdataset(replacement.rdclass, replacement.rdtype, replacement.covers) self.rdatasets.append(replacement) + + +@dns.immutable.immutable +class ImmutableNode(Node): + + """An ImmutableNode is an immutable set of rdatasets.""" + + def __init__(self, node): + super().__init__() + self.rdatasets = tuple( + [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets] + ) + + def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().find_rdataset(rdclass, rdtype, covers, False) + + def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise TypeError("immutable") + return super().get_rdataset(rdclass, rdtype, covers, False) + + def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE): + raise TypeError("immutable") + + def replace_rdataset(self, replacement): + raise TypeError("immutable") diff --git a/dns/rdataset.py b/dns/rdataset.py index ba93ab4..1f372cd 100644 --- a/dns/rdataset.py +++ b/dns/rdataset.py @@ -22,6 +22,7 @@ import random import struct import dns.exception +import dns.immutable import dns.rdatatype import dns.rdataclass import dns.rdata @@ -306,6 +307,52 @@ class Rdataset(dns.set.Set): return False +@dns.immutable.immutable +class ImmutableRdataset(Rdataset): + + """An immutable DNS rdataset.""" + + def __init__(self, rdataset): + """Create an immutable rdataset from the specified rdataset.""" + + super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers, + rdataset.ttl) + self.items = dns.immutable.Dict(rdataset.items) + + def update_ttl(self, ttl): + raise TypeError('immutable') + + def add(self, rd, ttl=None): + raise TypeError('immutable') + + def union_update(self, other): + raise TypeError('immutable') + + def intersection_update(self, other): + raise TypeError('immutable') + + def update(self, other): + raise TypeError('immutable') + + def __delitem__(self, i): + raise TypeError('immutable') + + def __ior__(self, other): + raise TypeError('immutable') + + def __iand__(self, other): + raise TypeError('immutable') + + def __iadd__(self, other): + raise TypeError('immutable') + + def __isub__(self, other): + raise TypeError('immutable') + + def clear(self): + raise TypeError('immutable') + + def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None, origin=None, relativize=True, relativize_to=None): """Create an rdataset with the specified class, type, and TTL, and with diff --git a/dns/transaction.py b/dns/transaction.py new file mode 100644 index 0000000..20d6939 --- /dev/null +++ b/dns/transaction.py @@ -0,0 +1,383 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import collections + +import dns.exception +import dns.name +import dns.rdataclass +import dns.rdataset +import dns.rdatatype +import dns.rrset +import dns.ttl + + +class TransactionManager: + def reader(self): + """Begin a read-only transaction.""" + raise NotImplementedError # pragma: no cover + + def writer(self, replacement=False): + """Begin a writable transaction. + + *replacement*, a `bool`. If `True`, the content of the + transaction completely replaces any prior content. If False, + the default, then the content of the transaction updates the + existing content. + """ + raise NotImplementedError # pragma: no cover + + +class DeleteNotExact(dns.exception.DNSException): + """Existing data did not match data specified by an exact delete.""" + + +class ReadOnly(dns.exception.DNSException): + """Tried to write to a read-only transaction.""" + + +class Transaction: + + def __init__(self, replacement=False, read_only=False): + self.replacement = replacement + self.read_only = read_only + + # + # This is the high level API + # + + def get(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE): + """Return the rdataset associated with *name*, *rdclass*, *rdtype*, + and *covers*, or `None` if not found. + + Note that the returned rdataset is immutable. + """ + if isinstance(name, str): + name = dns.name.from_text(name, None) + rdclass = dns.rdataclass.RdataClass.make(rdclass) + rdtype = dns.rdatatype.RdataType.make(rdtype) + rdataset = self._get_rdataset(name, rdclass, rdtype, covers) + if rdataset is not None and \ + not isinstance(rdataset, dns.rdataset.ImmutableRdataset): + rdataset = dns.rdataset.ImmutableRdataset(rdataset) + return rdataset + + def _check_read_only(self): + if self.read_only: + raise ReadOnly + + def add(self, *args): + """Add records. + + The arguments may be: + + - rrset + + - name, rdataset... + + - name, ttl, rdata... + """ + self._check_read_only() + return self._add(False, args) + + def replace(self, *args): + """Replace the existing rdataset at the name with the specified + rdataset, or add the specified rdataset if there was no existing + rdataset. + + The arguments may be: + + - rrset + + - name, rdataset... + + - name, ttl, rdata... + + Note that if you want to replace the entire node, you should do + a delete of the name followed by one or more calls to add() or + replace(). + """ + self._check_read_only() + return self._add(True, args) + + def delete(self, *args): + """Delete records. + + It is not an error if some of the records are not in the existing + set. + + The arguments may be: + + - rrset + + - name + + - name, rdataclass, rdatatype, [covers] + + - name, rdataset... + + - name, rdata... + """ + self._check_read_only() + return self._delete(False, args) + + def delete_exact(self, *args): + """Delete records. + + The arguments may be: + + - rrset + + - name + + - name, rdataclass, rdatatype, [covers] + + - name, rdataset... + + - name, rdata... + + Raises dns.transaction.DeleteNotExact if some of the records + are not in the existing set. + + """ + self._check_read_only() + return self._delete(True, args) + + def name_exists(self, name): + """Does the specified name exist?""" + if isinstance(name, str): + name = dns.name.from_text(name, None) + return self._name_exists(name) + + def set_serial(self, increment=1, value=None, name=dns.name.empty, + rdclass=dns.rdataclass.IN): + if isinstance(name, str): + name = dns.name.from_text(name, None) + rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA, + dns.rdatatype.NONE) + if rdataset is None or len(rdataset) == 0: + raise KeyError + if value is not None: + serial = value + else: + serial = rdataset[0].serial + serial += increment + if serial > 0xffffffff or serial < 1: + serial = 1 + rdata = rdataset[0].replace(serial=serial) + new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata) + self.replace(name, new_rdataset) + + def __iter__(self): + return self._iterate_rdatasets() + + # + # Helper methods + # + + def _raise_if_not_empty(self, method, args): + if len(args) != 0: + raise TypeError(f'extra parameters to {method}') + + def _rdataset_from_args(self, method, deleting, args): + try: + arg = args.popleft() + if isinstance(arg, dns.rdataset.Rdataset): + rdataset = arg + else: + if deleting: + ttl = 0 + else: + if isinstance(arg, int): + ttl = arg + if ttl > dns.ttl.MAX_TTL: + raise ValueError(f'{method}: TTL value too big') + else: + raise TypeError(f'{method}: expected a TTL') + arg = args.popleft() + if isinstance(arg, dns.rdata.Rdata): + rdataset = dns.rdataset.from_rdata(ttl, arg) + else: + raise TypeError(f'{method}: expected an Rdata') + return rdataset + except IndexError: + if deleting: + return None + else: + # reraise + raise TypeError(f'{method}: expected more arguments') + + def _add(self, replace, args): + try: + args = collections.deque(args) + if replace: + method = 'replace()' + else: + method = 'add()' + arg = args.popleft() + if isinstance(arg, str): + arg = dns.name.from_text(arg, None) + if isinstance(arg, dns.name.Name): + name = arg + rdataset = self._rdataset_from_args(method, False, args) + elif isinstance(arg, dns.rrset.RRset): + rrset = arg + name = rrset.name + # rrsets are also rdatasets, but they don't print the + # same, so convert. + rdataset = dns.rdataset.Rdataset(rrset.rdclass, rrset.rdtype, + rrset.covers, rrset.ttl) + rdataset.union_update(rrset) + else: + raise TypeError(f'{method} requires a name or RRset ' + + 'as the first argument') + self._raise_if_not_empty(method, args) + if not replace: + existing = self._get_rdataset(name, rdataset.rdclass, + rdataset.rdtype, rdataset.covers) + if existing is not None: + if isinstance(existing, dns.rdataset.ImmutableRdataset): + trds = dns.rdataset.Rdataset(existing.rdclass, + existing.rdtype, + existing.covers) + trds.update(existing) + existing = trds + rdataset = existing.union(rdataset) + self._put_rdataset(name, rdataset) + except IndexError: + raise TypeError(f'not enough parameters to {method}') + + def _delete(self, exact, args): + try: + args = collections.deque(args) + if exact: + method = 'delete_exact()' + else: + method = 'delete()' + arg = args.popleft() + if isinstance(arg, str): + arg = dns.name.from_text(arg, None) + if isinstance(arg, dns.name.Name): + name = arg + if len(args) > 0 and isinstance(args[0], int): + # deleting by type and class + rdclass = dns.rdataclass.RdataClass.make(args.popleft()) + rdtype = dns.rdatatype.RdataType.make(args.popleft()) + if len(args) > 0: + covers = dns.rdatatype.RdataType.make(args.popleft()) + else: + covers = dns.rdatatype.NONE + self._raise_if_not_empty(method, args) + existing = self._get_rdataset(name, rdclass, rdtype, covers) + if existing is None: + if exact: + raise DeleteNotExact(f'{method}: missing rdataset') + else: + self._delete_rdataset(name, rdclass, rdtype, covers) + return + else: + rdataset = self._rdataset_from_args(method, True, args) + elif isinstance(arg, dns.rrset.RRset): + rdataset = arg # rrsets are also rdatasets + name = rdataset.name + else: + raise TypeError(f'{method} requires a name or RRset ' + + 'as the first argument') + self._raise_if_not_empty(method, args) + if rdataset: + existing = self._get_rdataset(name, rdataset.rdclass, + rdataset.rdtype, rdataset.covers) + if existing is not None: + if exact: + intersection = existing.intersection(rdataset) + if intersection != rdataset: + raise DeleteNotExact(f'{method}: missing rdatas') + rdataset = existing.difference(rdataset) + if len(rdataset) == 0: + self._delete_rdataset(name, rdataset.rdclass, + rdataset.rdtype, rdataset.covers) + else: + self._put_rdataset(name, rdataset) + elif exact: + raise DeleteNotExact(f'{method}: missing rdataset') + else: + if exact and not self._name_exists(name): + raise DeleteNotExact(f'{method}: name not known') + self._delete_name(name) + except IndexError: + raise TypeError(f'not enough parameters to {method}') + + # + # Transactions are context managers. + # + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._end_transaction(True) + else: + self._end_transaction(False) + return False + + # + # This is the low level API, which must be implemented by subclasses + # of Transaction. + # + + def _get_rdataset(self, name, rdclass, rdtype, covers): + """Return the rdataset associated with *name*, *rdclass*, *rdtype*, + and *covers*, or `None` if not found.""" + raise NotImplementedError # pragma: no cover + + def _put_rdataset(self, name, rdataset): + """Store the rdataset.""" + raise NotImplementedError # pragma: no cover + + def _delete_name(self, name): + """Delete all data associated with *name*. + + It is not an error if the rdataset does not exist. + """ + raise NotImplementedError # pragma: no cover + + def _delete_rdataset(self, name, rdclass, rdtype, covers): + """Delete all data associated with *name*, *rdclass*, *rdtype*, and + *covers*. + + It is not an error if the rdataset does not exist. + """ + raise NotImplementedError # pragma: no cover + + def _name_exists(self, name): + """Does name exist? + + Returns a bool. + """ + raise NotImplementedError # pragma: no cover + + def _end_transaction(self, commit): + """End the transaction. + + *commit*, a bool. If ``True``, commit the transaction, otherwise + roll it back. + + Raises an exception if committing failed. + """ + raise NotImplementedError # pragma: no cover + + def _set_origin(self, origin): + """Set the origin. + + This method is called when reading a possibly relativized + source, and an origin setting operation occurs (e.g. $ORIGIN + in a masterfile). + """ + raise NotImplementedError # pragma: no cover + + def _iterate_rdatasets(self): + """Return an iterator that yields (name, rdataset) tuples. + + Not all Transaction subclasses implement this. + """ + raise NotImplementedError # pragma: no cover diff --git a/dns/versioned.py b/dns/versioned.py new file mode 100644 index 0000000..6f911e1 --- /dev/null +++ b/dns/versioned.py @@ -0,0 +1,392 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +"""DNS Versioned Zones.""" + +import collections +try: + import threading as _threading +except ImportError: # pragma: no cover + import dummy_threading as _threading # type: ignore + +import dns.exception +import dns.immutable +import dns.name +import dns.node +import dns.rdataclass +import dns.rdatatype +import dns.rdata +import dns.rdtypes.ANY.SOA +import dns.transaction +import dns.zone + + +class UseTransaction(dns.exception.DNSException): + """To alter a versioned zone, use a transaction.""" + + +class Version: + def __init__(self, zone, id): + self.zone = zone + self.id = id + self.nodes = {} + + def _validate_name(self, name): + if name.is_absolute(): + if not name.is_subdomain(self.zone.origin): + raise KeyError("name is not a subdomain of the zone origin") + if self.zone.relativize: + name = name.relativize(self.origin) + return name + + def get_node(self, name): + name = self._validate_name(name) + return self.nodes.get(name) + + def get_rdataset(self, name, rdtype, covers): + node = self.get_node(name) + if node is None: + return None + return node.get_rdataset(self.zone.rdclass, rdtype, covers) + + def items(self): + return self.nodes.items() # pylint: disable=dict-items-not-iterating + + def _print(self): # pragma: no cover + # XXXRTH This is for debugging + print('VERSION', self.id) + for (name, node) in self.nodes.items(): + for rdataset in node: + print(rdataset.to_text(name)) + + +class WritableVersion(Version): + def __init__(self, zone, replacement=False): + if len(zone.versions) > 0: + id = zone.versions[-1].id + 1 + else: + id = 1 + super().__init__(zone, id) + if not replacement: + # We copy the map, because that gives us a simple and thread-safe + # way of doing versions, and we have a garbage collector to help + # us. We only make new node objects if we actually change the + # node. + self.nodes.update(zone.nodes) + # We have to copy the zone origin as it may be None in the first + # version, and we don't want to mutate the zone until we commit. + self.origin = zone.origin + self.changed = set() + + def _validate_name(self, name): + if name.is_absolute(): + if not name.is_subdomain(self.origin): + raise KeyError("name is not a subdomain of the zone origin") + if self.zone.relativize: + name = name.relativize(self.origin) + return name + + def _maybe_cow(self, name): + name = self._validate_name(name) + node = self.nodes.get(name) + if node is None or node.id != self.id: + new_node = self.zone.node_factory() + new_node.id = self.id + if node is not None: + # moo! copy on write! + new_node.rdatasets.extend(node.rdatasets) + self.nodes[name] = new_node + self.changed.add(name) + return new_node + else: + return node + + def delete_node(self, name): + name = self._validate_name(name) + if name in self.nodes: + del self.nodes[name] + return True + return False + + def put_rdataset(self, name, rdataset): + node = self._maybe_cow(name) + node.replace_rdataset(rdataset) + + def delete_rdataset(self, name, rdtype, covers): + node = self._maybe_cow(name) + if not node.get_rdataset(self.zone.rdclass, rdtype, covers): + return False + node.delete_rdataset(self.zone.rdclass, rdtype, covers) + if len(node) == 0: + del self.nodes[name] + return True + + +@dns.immutable.immutable +class ImmutableVersion(Version): + def __init__(self, version): + # We tell super() that it's a replacement as we don't want it + # to copy the nodes, as we're about to do that with an + # immutable Dict. + super().__init__(version.zone, True) + # set the right id! + self.id = version.id + # Make changed nodes immutable + for name in version.changed: + node = version.nodes.get(name) + # it might not exist if we deleted it in the version + if node: + version.nodes[name] = ImmutableNode(node) + self.nodes = dns.immutable.Dict(version.nodes, True) + + +# A node with a version id. + +class Node(dns.node.Node): + __slots__ = ['id'] + + def __init__(self): + super().__init__() + # A proper id will get set by the Version + self.id = 0 + + +# It would be nice if this were a subclass of Node (just above) but it's +# less code duplication this way as we inherit all of the method disabling +# code. + +@dns.immutable.immutable +class ImmutableNode(dns.node.ImmutableNode): + __slots__ = ['id'] + + def __init__(self, node): + super().__init__(node) + self.id = node.id + + +class Zone(dns.zone.Zone): + + __slots__ = ['versions', '_write_txn', '_write_waiters', '_write_event', + '_pruning_policy'] + + node_factory = Node + + def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True, + pruning_policy=None): + """Initialize a versioned zone object. + + *origin* is the origin of the zone. It may be a ``dns.name.Name``, + a ``str``, or ``None``. If ``None``, then the zone's origin will + be set by the first ``$ORIGIN`` line in a masterfile. + + *rdclass*, an ``int``, the zone's rdata class; the default is class IN. + + *relativize*, a ``bool``, determine's whether domain names are + relativized to the zone's origin. The default is ``True``. + + *pruning policy*, a function taking a `Version` and returning + a `bool`, or `None`. Should the version be pruned? If `None`, + the default policy, which retains one version is used. + """ + super().__init__(origin, rdclass, relativize) + self.versions = collections.deque() + self.version_lock = _threading.Lock() + if pruning_policy is None: + self._pruning_policy = self._default_pruning_policy + else: + self._pruning_policy = pruning_policy + self._write_txn = None + self._write_event = None + self._write_waiters = collections.deque() + self._commit_version_unlocked(WritableVersion(self), origin) + + def reader(self): + with self.version_lock: + return Transaction(False, self, self.versions[-1]) + + def writer(self, replacement=False): + event = None + while True: + with self.version_lock: + # Checking event == self._write_event ensures that either + # no one was waiting before we got lucky and found no write + # txn, or we were the one who was waiting and got woken up. + # This prevents "taking cuts" when creating a write txn. + if self._write_txn is None and event == self._write_event: + # Creating the transaction defers version setup + # (i.e. copying the nodes dictionary) until we + # give up the lock, so that we hold the lock as + # short a time as possible. This is why we call + # _setup_version() below. + self._write_txn = Transaction(replacement, self) + # give up our exclusive right to make a Transaction + self._write_event = None + break + # Someone else is writing already, so we will have to + # wait, but we want to do the actual wait outside the + # lock. + event = _threading.Event() + self._write_waiters.append(event) + # wait (note we gave up the lock!) + # + # We only wake one sleeper at a time, so it's important + # that no event waiter can exit this method (e.g. via + # cancelation) without returning a transaction or waking + # someone else up. + # + # This is not a problem with Threading module threads as + # they cannot be canceled, but could be an issue with trio + # or curio tasks when we do the async version of writer(). + # I.e. we'd need to do something like: + # + # try: + # event.wait() + # except trio.Cancelled: + # with self.version_lock: + # self._maybe_wakeup_one_waiter_unlocked() + # raise + # + event.wait() + # Do the deferred version setup. + self._write_txn._setup_version() + return self._write_txn + + def _maybe_wakeup_one_waiter_unlocked(self): + if len(self._write_waiters) > 0: + self._write_event = self._write_waiters.popleft() + self._write_event.set() + + # pylint: disable=unused-argument + def _default_pruning_policy(self, zone, version): + return True + # pylint: enable=unused-argument + + def _prune_versions_unlocked(self): + while len(self.versions) > 1 and \ + self._pruning_policy(self, self.versions[0]): + self.versions.popleft() + + def set_max_versions(self, max_versions): + """Set a pruning policy that retains up to the specified number + of versions + """ + if max_versions is not None and max_versions < 1: + raise ValueError('max versions must be at least 1') + if max_versions is None: + def policy(*_): + return False + else: + def policy(zone, _): + return len(zone.versions) > max_versions + self.set_pruning_policy(policy) + + def set_pruning_policy(self, policy): + """Set the pruning policy for the zone. + + The *policy* function takes a `Version` and returns `True` if + the version should be pruned, and `False` otherwise. `None` + may also be specified for policy, in which case the default policy + is used. + + Pruning checking proceeds from the least version and the first + time the function returns `False`, the checking stops. I.e. the + retained versions are always a consecutive sequence. + """ + if policy is None: + policy = self._default_pruning_policy + with self.version_lock: + self._pruning_policy = policy + self._prune_versions_unlocked() + + def _commit_version_unlocked(self, version, origin): + self.versions.append(version) + self._prune_versions_unlocked() + self.nodes = version.nodes + if self.origin is None: + self.origin = origin + self._write_txn = None + self._maybe_wakeup_one_waiter_unlocked() + + def _commit_version(self, version, origin): + with self.version_lock: + self._commit_version_unlocked(version, origin) + + def find_node(self, name, create=False): + if create: + raise UseTransaction + return super().find_node(name) + + def delete_node(self, name): + raise UseTransaction + + def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise UseTransaction + rdataset = super().find_rdataset(name, rdtype, covers) + return dns.rdataset.ImmutableRdataset(rdataset) + + def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE, + create=False): + if create: + raise UseTransaction + rdataset = super().get_rdataset(name, rdtype, covers) + return dns.rdataset.ImmutableRdataset(rdataset) + + def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE): + raise UseTransaction + + def replace_rdataset(self, name, replacement): + raise UseTransaction + + +class Transaction(dns.transaction.Transaction): + + def __init__(self, replacement, zone, version=None): + read_only = version is not None + super().__init__(replacement, read_only) + self.zone = zone + self.version = version + + def _setup_version(self): + assert self.version is None + self.version = WritableVersion(self.zone, self.replacement) + + def _get_rdataset(self, name, rdclass, rdtype, covers): + if rdclass != self.zone.rdclass: + raise ValueError(f'class {rdclass} != ' + + f'zone class {self.zone.rdclass}') + return self.version.get_rdataset(name, rdtype, covers) + + def _put_rdataset(self, name, rdataset): + assert not self.read_only + if rdataset.rdclass != self.zone.rdclass: + raise ValueError(f'rdataset class {rdataset.rdclass} != ' + + f'zone class {self.zone.rdclass}') + self.version.put_rdataset(name, rdataset) + + def _delete_name(self, name): + assert not self.read_only + self.version.delete_node(name) + + def _delete_rdataset(self, name, rdclass, rdtype, covers): + assert not self.read_only + self.version.delete_rdataset(name, rdtype, covers) + + def _name_exists(self, name): + return self.version.get_node(name) is not None + + def _end_transaction(self, commit): + if self.read_only: + return + if commit and len(self.version.changed) > 0: + self.zone._commit_version(ImmutableVersion(self.version), + self.version.origin) + + def _set_origin(self, origin): + if self.version.origin is None: + self.version.origin = origin + + def _iterate_rdatasets(self): + for (name, node) in self.version.items(): + for rdataset in node: + yield (name, rdataset) diff --git a/dns/zone.py b/dns/zone.py index d5bb305..2ca9bc2 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -20,10 +20,9 @@ import contextlib import io import os -import re -import sys import dns.exception +import dns.masterfile import dns.name import dns.node import dns.rdataclass @@ -32,6 +31,7 @@ import dns.rdata import dns.rdtypes.ANY.SOA import dns.rrset import dns.tokenizer +import dns.transaction import dns.ttl import dns.grange @@ -56,7 +56,7 @@ class UnknownOrigin(BadZone): """The DNS zone's origin is unknown.""" -class Zone: +class Zone(dns.transaction.TransactionManager): """A DNS zone. @@ -642,415 +642,108 @@ class Zone: if self.get_rdataset(name, dns.rdatatype.NS) is None: raise NoNS + def reader(self): + return Transaction(False, True, self) -class _MasterReader: - - """Read a DNS master file - - @ivar tok: The tokenizer - @type tok: dns.tokenizer.Tokenizer object - @ivar last_ttl: The last seen explicit TTL for an RR - @type last_ttl: int - @ivar last_ttl_known: Has last TTL been detected - @type last_ttl_known: bool - @ivar default_ttl: The default TTL from a $TTL directive or SOA RR - @type default_ttl: int - @ivar default_ttl_known: Has default TTL been detected - @type default_ttl_known: bool - @ivar last_name: The last name read - @type last_name: dns.name.Name object - @ivar current_origin: The current origin - @type current_origin: dns.name.Name object - @ivar relativize: should names in the zone be relativized? - @type relativize: bool - @ivar zone: the zone - @type zone: dns.zone.Zone object - @ivar saved_state: saved reader state (used when processing $INCLUDE) - @type saved_state: list of (tokenizer, current_origin, last_name, file, - last_ttl, last_ttl_known, default_ttl, default_ttl_known) tuples. - @ivar current_file: the file object of the $INCLUDed file being parsed - (None if no $INCLUDE is active). - @ivar allow_include: is $INCLUDE allowed? - @type allow_include: bool - @ivar check_origin: should sanity checks of the origin node be done? - The default is True. - @type check_origin: bool - """ + def writer(self, replacement=False): + return Transaction(replacement, False, self) - def __init__(self, tok, origin, rdclass, relativize, zone_factory=Zone, - allow_include=False, check_origin=True): - if isinstance(origin, str): - origin = dns.name.from_text(origin) - self.tok = tok - self.current_origin = origin - self.relativize = relativize - self.last_ttl = 0 - self.last_ttl_known = False - self.default_ttl = 0 - self.default_ttl_known = False - self.last_name = self.current_origin - self.zone = zone_factory(origin, rdclass, relativize=relativize) - self.saved_state = [] - self.current_file = None - self.allow_include = allow_include - self.check_origin = check_origin - - def _eat_line(self): - while 1: - token = self.tok.get() - if token.is_eol_or_eof(): - break - - def _rr_line(self): - """Process one line from a DNS master file.""" - # Name - if self.current_origin is None: - raise UnknownOrigin - token = self.tok.get(want_leading=True) - if not token.is_whitespace(): - self.last_name = self.tok.as_name(token, self.current_origin) - else: - token = self.tok.get() - if token.is_eol_or_eof(): - # treat leading WS followed by EOL/EOF as if they were EOL/EOF. - return - self.tok.unget(token) - name = self.last_name - if not name.is_subdomain(self.zone.origin): - self._eat_line() - return - if self.relativize: - name = name.relativize(self.zone.origin) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - # TTL - ttl = None - try: - ttl = dns.ttl.from_text(token.value) - self.last_ttl = ttl - self.last_ttl_known = True - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.ttl.BadTTL: - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl - - # Class - try: - rdclass = dns.rdataclass.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.exception.SyntaxError: - raise - except Exception: - rdclass = self.zone.rdclass - if rdclass != self.zone.rdclass: - raise dns.exception.SyntaxError("RR class is not zone's class") - # Type - try: - rdtype = dns.rdatatype.from_text(token.value) - except Exception: - raise dns.exception.SyntaxError( - "unknown rdatatype '%s'" % token.value) - n = self.zone.nodes.get(name) - if n is None: - n = self.zone.node_factory() - self.zone.nodes[name] = n - try: - rd = dns.rdata.from_text(rdclass, rdtype, self.tok, - self.current_origin, self.relativize, - self.zone.origin) - except dns.exception.SyntaxError: - # Catch and reraise. - raise - except Exception: - # All exceptions that occur in the processing of rdata - # are treated as syntax errors. This is not strictly - # correct, but it is correct almost all of the time. - # We convert them to syntax errors so that we can emit - # helpful filename:line info. - (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError( - "caught exception {}: {}".format(str(ty), str(va))) - - if not self.default_ttl_known and rdtype == dns.rdatatype.SOA: - # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default - # TTL from the SOA minttl if no $TTL statement is present before the - # SOA is parsed. - self.default_ttl = rd.minimum - self.default_ttl_known = True - if ttl is None: - # if we didn't have a TTL on the SOA, set it! - ttl = rd.minimum - - # TTL check. We had to wait until now to do this as the SOA RR's - # own TTL can be inferred from its minimum. - if ttl is None: - raise dns.exception.SyntaxError("Missing default TTL value") - - covers = rd.covers() - rds = n.find_rdataset(rdclass, rdtype, covers, True) - rds.add(rd, ttl) - - def _parse_modify(self, side): - # Here we catch everything in '{' '}' in a group so we can replace it - # with ''. - is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$") - is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$") - is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$") - # Sometimes there are modifiers in the hostname. These come after - # the dollar sign. They are in the form: ${offset[,width[,base]]}. - # Make names - g1 = is_generate1.match(side) - if g1: - mod, sign, offset, width, base = g1.groups() - if sign == '': - sign = '+' - g2 = is_generate2.match(side) - if g2: - mod, sign, offset = g2.groups() - if sign == '': - sign = '+' - width = 0 - base = 'd' - g3 = is_generate3.match(side) - if g3: - mod, sign, offset, width = g3.groups() - if sign == '': - sign = '+' - base = 'd' - - if not (g1 or g2 or g3): - mod = '' - sign = '+' - offset = 0 - width = 0 - base = 'd' - - if base != 'd': - raise NotImplementedError() - - return mod, sign, offset, width, base - - def _generate_line(self): - # range lhs [ttl] [class] type rhs [ comment ] - """Process one line containing the GENERATE statement from a DNS - master file.""" - if self.current_origin is None: - raise UnknownOrigin - - token = self.tok.get() - # Range (required) - try: - start, stop, step = dns.grange.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except Exception: - raise dns.exception.SyntaxError - - # lhs (required) - try: - lhs = token.value - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except Exception: - raise dns.exception.SyntaxError - - # TTL - try: - ttl = dns.ttl.from_text(token.value) - self.last_ttl = ttl - self.last_ttl_known = True - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.ttl.BadTTL: - if not (self.last_ttl_known or self.default_ttl_known): - raise dns.exception.SyntaxError("Missing default TTL value") - if self.default_ttl_known: - ttl = self.default_ttl - elif self.last_ttl_known: - ttl = self.last_ttl - # Class - try: - rdclass = dns.rdataclass.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except dns.exception.SyntaxError: - raise dns.exception.SyntaxError - except Exception: - rdclass = self.zone.rdclass +class Transaction(dns.transaction.Transaction): + + _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY, + dns.rdatatype.ANY) + + def __init__(self, replacement, read_only, zone): + super().__init__(replacement, read_only) + self.zone = zone + self.rdatasets = {} + + def _get_rdataset(self, name, rdclass, rdtype, covers): if rdclass != self.zone.rdclass: - raise dns.exception.SyntaxError("RR class is not zone's class") - # Type - try: - rdtype = dns.rdatatype.from_text(token.value) - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError - except Exception: - raise dns.exception.SyntaxError("unknown rdatatype '%s'" % - token.value) - - # rhs (required) - rhs = token.value - - # The code currently only supports base 'd', so the last value - # in the tuple _parse_modify returns is ignored - lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs) - rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs) - for i in range(start, stop + 1, step): - # +1 because bind is inclusive and python is exclusive - - if lsign == '+': - lindex = i + int(loffset) - elif lsign == '-': - lindex = i - int(loffset) - - if rsign == '-': - rindex = i - int(roffset) - elif rsign == '+': - rindex = i + int(roffset) - - lzfindex = str(lindex).zfill(int(lwidth)) - rzfindex = str(rindex).zfill(int(rwidth)) - - name = lhs.replace('$%s' % (lmod), lzfindex) - rdata = rhs.replace('$%s' % (rmod), rzfindex) - - self.last_name = dns.name.from_text(name, self.current_origin, - self.tok.idna_codec) - name = self.last_name - if not name.is_subdomain(self.zone.origin): - self._eat_line() - return - if self.relativize: - name = name.relativize(self.zone.origin) - - n = self.zone.nodes.get(name) - if n is None: - n = self.zone.node_factory() - self.zone.nodes[name] = n - try: - rd = dns.rdata.from_text(rdclass, rdtype, rdata, - self.current_origin, self.relativize, - self.zone.origin) - except dns.exception.SyntaxError: - # Catch and reraise. - raise - except Exception: - # All exceptions that occur in the processing of rdata - # are treated as syntax errors. This is not strictly - # correct, but it is correct almost all of the time. - # We convert them to syntax errors so that we can emit - # helpful filename:line info. - (ty, va) = sys.exc_info()[:2] - raise dns.exception.SyntaxError("caught exception %s: %s" % - (str(ty), str(va))) - - covers = rd.covers() - rds = n.find_rdataset(rdclass, rdtype, covers, True) - rds.add(rd, ttl) - - def read(self): - """Read a DNS master file and build a zone object. - - @raises dns.zone.NoSOA: No SOA RR was found at the zone origin - @raises dns.zone.NoNS: No NS RRset was found at the zone origin - """ + raise ValueError(f'class {rdclass} != ' + + f'zone class {self.zone.rdclass}') + rdataset = self.rdatasets.get((name, rdtype, covers)) + if rdataset is self._deleted_rdataset: + return None + elif rdataset is None: + rdataset = self.zone.get_rdataset(name, rdtype, covers) + return rdataset + def _put_rdataset(self, name, rdataset): + assert not self.read_only + self.zone._validate_name(name) + if rdataset.rdclass != self.zone.rdclass: + raise ValueError(f'rdataset class {rdataset.rdclass} != ' + + f'zone class {self.zone.rdclass}') + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset + + def _delete_name(self, name): + assert not self.read_only + # First remove any changes involving the name + remove = [] + for key in self.rdatasets: + if key[0] == name: + remove.append(key) + if len(remove) > 0: + for key in remove: + del self.rdatasets[key] + # Next add deletion records for any rdatasets matching the + # name in the zone + node = self.zone.get_node(name) + if node is not None: + for rdataset in node.rdatasets: + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \ + self._deleted_rdataset + + def _delete_rdataset(self, name, rdclass, rdtype, covers): + assert not self.read_only + # The high-level code always does a _get_rdataset() before any + # situation where it would call _delete_rdataset(), so we don't + # need to check if rdclass != self.zone.rdclass. try: - while 1: - token = self.tok.get(True, True) - if token.is_eof(): - if self.current_file is not None: - self.current_file.close() - if len(self.saved_state) > 0: - (self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known) = self.saved_state.pop(-1) - continue - break - elif token.is_eol(): - continue - elif token.is_comment(): - self.tok.get_eol() - continue - elif token.value[0] == '$': - c = token.value.upper() - if c == '$TTL': - token = self.tok.get() - if not token.is_identifier(): - raise dns.exception.SyntaxError("bad $TTL") - self.default_ttl = dns.ttl.from_text(token.value) - self.default_ttl_known = True - self.tok.get_eol() - elif c == '$ORIGIN': - self.current_origin = self.tok.get_name() - self.tok.get_eol() - if self.zone.origin is None: - self.zone.origin = self.current_origin - elif c == '$INCLUDE' and self.allow_include: - token = self.tok.get() - filename = token.value - token = self.tok.get() - if token.is_identifier(): - new_origin =\ - dns.name.from_text(token.value, - self.current_origin, - self.tok.idna_codec) - self.tok.get_eol() - elif not token.is_eol_or_eof(): - raise dns.exception.SyntaxError( - "bad origin in $INCLUDE") - else: - new_origin = self.current_origin - self.saved_state.append((self.tok, - self.current_origin, - self.last_name, - self.current_file, - self.last_ttl, - self.last_ttl_known, - self.default_ttl, - self.default_ttl_known)) - self.current_file = open(filename, 'r') - self.tok = dns.tokenizer.Tokenizer(self.current_file, - filename) - self.current_origin = new_origin - elif c == '$GENERATE': - self._generate_line() - else: - raise dns.exception.SyntaxError( - "Unknown master file directive '" + c + "'") - continue - self.tok.unget(token) - self._rr_line() - except dns.exception.SyntaxError as detail: - (filename, line_number) = self.tok.where() - if detail is None: - detail = "syntax error" - ex = dns.exception.SyntaxError( - "%s:%d: %s" % (filename, line_number, detail)) - tb = sys.exc_info()[2] - raise ex.with_traceback(tb) from None - - # Now that we're done reading, do some basic checking of the zone. - if self.check_origin: - self.zone.check_origin() + del self.rdatasets[(name, rdtype, covers)] + except KeyError: + pass + rdataset = self.zone.get_rdataset(name, rdtype, covers) + if rdataset is not None: + self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \ + self._deleted_rdataset + + def _name_exists(self, name): + for key, rdataset in self.rdatasets.items(): + if key[0] == name: + if rdataset != self._deleted_rdataset: + return True + else: + return None + self.zone._validate_name(name) + if self.zone.get_node(name): + return True + return False + + def _end_transaction(self, commit): + if commit and not self.read_only: + for (name, rdtype, covers), rdataset in \ + self.rdatasets.items(): + if rdataset is self._deleted_rdataset: + self.zone.delete_rdataset(name, rdtype, covers) + else: + self.zone.replace_rdataset(name, rdataset) + + def _set_origin(self, origin): + if self.zone.origin is None: + self.zone.origin = origin + + def _iterate_rdatasets(self): + # Expensive but simple! Use a versioned zone for efficient txn + # iteration. + rdatasets = {} + for (name, rdataset) in self.zone.iterate_rdatasets(): + rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset + rdatasets.update(self.rdatasets) + for (name, _, _), rdataset in rdatasets.items(): + yield (name, rdataset) def from_text(text, origin=None, rdclass=dns.rdataclass.IN, @@ -1103,12 +796,20 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN, if filename is None: filename = '' - tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) - reader = _MasterReader(tok, origin, rdclass, relativize, zone_factory, - allow_include=allow_include, - check_origin=check_origin) - reader.read() - return reader.zone + zone = zone_factory(origin, rdclass, relativize=relativize) + with zone.writer(True) as txn: + tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec) + reader = dns.masterfile.Reader(tok, origin, rdclass, relativize, txn, + allow_include=allow_include) + try: + reader.read() + except dns.masterfile.UnknownOrigin: + # for backwards compatibility + raise dns.zone.UnknownOrigin + # Now that we're done reading, do some basic checking of the zone. + if check_origin: + zone.check_origin() + return zone def from_file(f, origin=None, rdclass=dns.rdataclass.IN, diff --git a/tests/test_immutable.py b/tests/test_immutable.py index 0385fc9..1a70e3d 100644 --- a/tests/test_immutable.py +++ b/tests/test_immutable.py @@ -3,20 +3,30 @@ import unittest import dns.immutable +import dns._immutable_attr + +try: + import dns._immutable_ctx as immutable_ctx + _have_contextvars = True +except ImportError: + _have_contextvars = False + + class immutable_ctx: + pass class ImmutableTestCase(unittest.TestCase): - def test_ImmutableDict_hash(self): - d1 = dns.immutable.ImmutableDict({'a': 1, 'b': 2}) - d2 = dns.immutable.ImmutableDict({'b': 2, 'a': 1}) + def test_immutable_dict_hash(self): + d1 = dns.immutable.Dict({'a': 1, 'b': 2}) + d2 = dns.immutable.Dict({'b': 2, 'a': 1}) d3 = {'b': 2, 'a': 1} self.assertEqual(d1, d2) self.assertEqual(d2, d3) self.assertEqual(hash(d1), hash(d2)) - def test_ImmutableDict_hash_cache(self): - d = dns.immutable.ImmutableDict({'a': 1, 'b': 2}) + def test_immutable_dict_hash_cache(self): + d = dns.immutable.Dict({'a': 1, 'b': 2}) self.assertEqual(d._hash, None) h1 = hash(d) self.assertEqual(d._hash, h1) @@ -30,11 +40,121 @@ class ImmutableTestCase(unittest.TestCase): ((1, [2], 3), (1, (2,), 3)), ([1, 2, 3], (1, 2, 3)), ([1, {'a': [1, 2]}], - (1, dns.immutable.ImmutableDict({'a': (1, 2)}))), + (1, dns.immutable.Dict({'a': (1, 2)}))), ('hi', 'hi'), (b'hi', b'hi'), ) for input, expected in items: self.assertEqual(dns.immutable.constify(input), expected) self.assertIsInstance(dns.immutable.constify({'a': 1}), - dns.immutable.ImmutableDict) + dns.immutable.Dict) + + +class DecoratorTestCase(unittest.TestCase): + + immutable_module = dns._immutable_attr + + def make_classes(self): + class A: + def __init__(self, a, akw=10): + self.a = a + self.akw = akw + + class B(A): + def __init__(self, a, b): + super().__init__(a, akw=20) + self.b = b + B = self.immutable_module.immutable(B) + + # note C is immutable by inheritance + class C(B): + def __init__(self, a, b, c): + super().__init__(a, b) + self.c = c + C = self.immutable_module.immutable(C) + + class SA: + __slots__ = ('a', 'akw') + def __init__(self, a, akw=10): + self.a = a + self.akw = akw + + class SB(A): + __slots__ = ('b') + def __init__(self, a, b): + super().__init__(a, akw=20) + self.b = b + SB = self.immutable_module.immutable(SB) + + # note SC is immutable by inheritance and has no slots of its own + class SC(SB): + def __init__(self, a, b, c): + super().__init__(a, b) + self.c = c + SC = self.immutable_module.immutable(SC) + + return ((A, B, C), (SA, SB, SC)) + + def test_basic(self): + for A, B, C in self.make_classes(): + a = A(1) + self.assertEqual(a.a, 1) + self.assertEqual(a.akw, 10) + b = B(11, 21) + self.assertEqual(b.a, 11) + self.assertEqual(b.akw, 20) + self.assertEqual(b.b, 21) + c = C(111, 211, 311) + self.assertEqual(c.a, 111) + self.assertEqual(c.akw, 20) + self.assertEqual(c.b, 211) + self.assertEqual(c.c, 311) + # changing A is ok! + a.a = 11 + self.assertEqual(a.a, 11) + # changing B is not! + with self.assertRaises(TypeError): + b.a = 11 + with self.assertRaises(TypeError): + del b.a + + def test_constructor_deletes_attribute(self): + class A: + def __init__(self, a): + self.a = a + self.b = a + del self.b + A = self.immutable_module.immutable(A) + a = A(10) + self.assertEqual(a.a, 10) + self.assertFalse(hasattr(a, 'b')) + + def test_no_collateral_damage(self): + + # A and B are immutable but not related. The magic that lets + # us write to immutable things while initializing B should not let + # B mess with A. + + class A: + def __init__(self, a): + self.a = a + A = self.immutable_module.immutable(A) + + class B: + def __init__(self, a, b): + self.b = a.a + b + # rudely attempt to mutate innocent immutable bystander 'a' + a.a = 1000 + B = self.immutable_module.immutable(B) + + a = A(10) + self.assertEqual(a.a, 10) + with self.assertRaises(TypeError): + B(a, 20) + self.assertEqual(a.a, 10) + + +@unittest.skipIf(not _have_contextvars, "contextvars not available") +class CtxDecoratorTestCase(DecoratorTestCase): + + immutable_module = immutable_ctx diff --git a/tests/test_rdataset.py b/tests/test_rdataset.py index a80d650..88b4840 100644 --- a/tests/test_rdataset.py +++ b/tests/test_rdataset.py @@ -122,5 +122,34 @@ class RdatasetTestCase(unittest.TestCase): ' 0: + for key in remove: + del self.rdatasets[key] + + def _delete_rdataset(self, name, rdclass, rdtype, covers): + del self.rdatasets[(name, rdclass, rdtype, covers)] + + def _name_exists(self, name): + for key in self.rdatasets.keys(): + if key[0] == name: + return True + return False + + def _end_transaction(self, commit): + if commit: + self.db.rdatasets = self.rdatasets + + def _set_origin(self, origin): + pass + +@pytest.fixture +def db(): + db = DB() + rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content') + db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] = rrset + return db + +def test_basic(db): + # successful txn + with db.writer() as txn: + rrset = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.1', '10.0.0.2') + txn.add(rrset) + assert txn.name_exists(rrset.name) + assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + rrset + # rollback + with pytest.raises(Exception): + with db.writer() as txn: + rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.3', '10.0.0.4') + txn.add(rrset2) + raise Exception() + assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + rrset + with db.writer() as txn: + txn.delete(rrset.name) + assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \ + is None + +def test_get(db): + with db.writer() as txn: + content = dns.name.from_text('content', None) + rdataset = txn.get(content, dns.rdataclass.IN, dns.rdatatype.TXT) + assert rdataset is not None + assert rdataset[0].strings == (b'content',) + assert isinstance(rdataset, dns.rdataset.ImmutableRdataset) + +def test_add(db): + with db.writer() as txn: + rrset = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.1', '10.0.0.2') + txn.add(rrset) + rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.3', '10.0.0.4') + txn.add(rrset2) + expected = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.1', '10.0.0.2', + '10.0.0.3', '10.0.0.4') + assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + expected + +def test_replacement(db): + with db.writer() as txn: + rrset = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.1', '10.0.0.2') + txn.add(rrset) + rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a', + '10.0.0.3', '10.0.0.4') + txn.replace(rrset2) + assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \ + rrset2 + +def test_delete(db): + with db.writer() as txn: + txn.delete(dns.name.from_text('nonexistent', None)) + content = dns.name.from_text('content', None) + content2 = dns.name.from_text('content2', None) + txn.delete(content) + assert not txn.name_exists(content) + txn.delete(content2, dns.rdataclass.IN, dns.rdatatype.TXT) + rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content') + txn.add(rrset) + assert txn.name_exists(content) + txn.delete(content, dns.rdataclass.IN, dns.rdatatype.TXT) + assert not txn.name_exists(content) + rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content') + txn.delete(rrset) + content_keys = [k for k in db.rdatasets if k[0] == content] + assert len(content_keys) == 0 + +def test_delete_exact(db): + with db.writer() as txn: + rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'bad-content') + with pytest.raises(dns.transaction.DeleteNotExact): + txn.delete_exact(rrset) + rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'bad-content') + with pytest.raises(dns.transaction.DeleteNotExact): + txn.delete_exact(rrset) + with pytest.raises(dns.transaction.DeleteNotExact): + txn.delete_exact(rrset.name) + with pytest.raises(dns.transaction.DeleteNotExact): + txn.delete_exact(rrset.name, dns.rdataclass.IN, dns.rdatatype.TXT) + rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content') + txn.delete_exact(rrset) + assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \ + is None + +def test_parameter_forms(db): + with db.writer() as txn: + foo = dns.name.from_text('foo', None) + rdataset = dns.rdataset.from_text('in', 'a', 300, + '10.0.0.1', '10.0.0.2') + rdata1 = dns.rdata.from_text('in', 'a', '10.0.0.3') + rdata2 = dns.rdata.from_text('in', 'a', '10.0.0.4') + txn.add(foo, rdataset) + txn.add(foo, 100, rdata1) + txn.add(foo, 30, rdata2) + expected = dns.rrset.from_text('foo', 30, 'in', 'a', + '10.0.0.1', '10.0.0.2', + '10.0.0.3', '10.0.0.4') + assert db.rdatasets[(foo, rdataset.rdclass, rdataset.rdtype, 0)] == \ + expected + with db.writer() as txn: + txn.delete(foo, rdataset) + txn.delete(foo, rdata1) + txn.delete(foo, rdata2) + assert db.rdatasets.get((foo, rdataset.rdclass, rdataset.rdtype, 0)) \ + is None + +def test_bad_parameters(db): + with db.writer() as txn: + with pytest.raises(TypeError): + txn.add(1) + with pytest.raises(TypeError): + rrset = dns.rrset.from_text('bar', 300, 'in', 'txt', 'bar') + txn.add(rrset, 1) + with pytest.raises(ValueError): + foo = dns.name.from_text('foo', None) + rdata = dns.rdata.from_text('in', 'a', '10.0.0.3') + txn.add(foo, 0x80000000, rdata) + with pytest.raises(TypeError): + txn.add(foo) + with pytest.raises(TypeError): + txn.add() + with pytest.raises(TypeError): + txn.add(foo, 300) + with pytest.raises(TypeError): + txn.add(foo, 300, 'hi') + with pytest.raises(TypeError): + txn.add(foo, 'hi') + with pytest.raises(TypeError): + txn.delete() + with pytest.raises(TypeError): + txn.delete(1) + +example_text = """$TTL 3600 +$ORIGIN example. +@ soa foo bar 1 2 3 4 5 +@ ns ns1 +@ ns ns2 +ns1 a 10.0.0.1 +ns2 a 10.0.0.2 +$TTL 300 +$ORIGIN foo.example. +bar mx 0 blaz +""" + +example_text_output = """@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +@ 3600 IN NS ns3 +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +ns3 3600 IN A 10.0.0.3 +""" + +@pytest.fixture(params=[dns.zone.Zone, dns.versioned.Zone]) +def zone(request): + return dns.zone.from_text(example_text, zone_factory=request.param) + +def test_zone_basic(zone): + with zone.writer() as txn: + txn.delete(dns.name.from_text('bar.foo', None)) + rd = dns.rdata.from_text('in', 'ns', 'ns3') + txn.add(dns.name.empty, 3600, rd) + rd = dns.rdata.from_text('in', 'a', '10.0.0.3') + txn.add(dns.name.from_text('ns3', None), 3600, rd) + output = zone.to_text() + assert output == example_text_output + +def test_zone_base_layer(zone): + with zone.writer() as txn: + # Get a set from the zone layer + rdataset = txn.get(dns.name.empty, dns.rdataclass.IN, + dns.rdatatype.NS, dns.rdatatype.NONE) + expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2') + assert rdataset == expected + +def test_zone_transaction_layer(zone): + with zone.writer() as txn: + # Make a change + rd = dns.rdata.from_text('in', 'ns', 'ns3') + txn.add(dns.name.empty, 3600, rd) + # Get a set from the transaction layer + expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3') + rdataset = txn.get(dns.name.empty, dns.rdataclass.IN, + dns.rdatatype.NS, dns.rdatatype.NONE) + assert rdataset == expected + assert txn.name_exists(dns.name.empty) + ns1 = dns.name.from_text('ns1', None) + assert txn.name_exists(ns1) + ns99 = dns.name.from_text('ns99', None) + assert not txn.name_exists(ns99) + +def test_zone_add_and_delete(zone): + with zone.writer() as txn: + a99 = dns.name.from_text('a99', None) + a100 = dns.name.from_text('a100', None) + a101 = dns.name.from_text('a101', None) + rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + txn.add(a99, rds) + txn.delete(a99, dns.rdataclass.IN, dns.rdatatype.A) + txn.delete(a100, dns.rdataclass.IN, dns.rdatatype.A) + txn.delete(a101) + assert not txn.name_exists(a99) + assert not txn.name_exists(a100) + assert not txn.name_exists(a101) + ns1 = dns.name.from_text('ns1', None) + txn.delete(ns1, dns.rdataclass.IN, dns.rdatatype.A) + assert not txn.name_exists(ns1) + with zone.writer() as txn: + txn.add(a99, rds) + txn.delete(a99) + assert not txn.name_exists(a99) + with zone.writer() as txn: + txn.add(a100, rds) + txn.delete(a99) + assert not txn.name_exists(a99) + assert txn.name_exists(a100) + +def test_zone_get_deleted(zone): + with zone.writer() as txn: + print(zone.to_text()) + ns1 = dns.name.from_text('ns1', None) + assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is not None + txn.delete(ns1) + assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is None + ns2 = dns.name.from_text('ns2', None) + txn.delete(ns2, dns.rdataclass.IN, dns.rdatatype.A) + assert txn.get(ns2, dns.rdataclass.IN, dns.rdatatype.A) is None + +def test_zone_bad_class(zone): + with zone.writer() as txn: + with pytest.raises(ValueError): + txn.get(dns.name.empty, dns.rdataclass.CH, + dns.rdatatype.NS, dns.rdatatype.NONE) + rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2') + with pytest.raises(ValueError): + txn.add(dns.name.empty, rds) + with pytest.raises(ValueError): + txn.replace(dns.name.empty, rds) + with pytest.raises(ValueError): + txn.delete(dns.name.empty, rds) + with pytest.raises(ValueError): + txn.delete(dns.name.empty, dns.rdataclass.CH, + dns.rdatatype.NS, dns.rdatatype.NONE) + +def test_set_serial(zone): + # basic + with zone.writer() as txn: + txn.set_serial() + rdataset = zone.find_rdataset('@', 'soa') + assert rdataset[0].serial == 2 + # max + with zone.writer() as txn: + txn.set_serial(0, 0xffffffff) + rdataset = zone.find_rdataset('@', 'soa') + assert rdataset[0].serial == 0xffffffff + # wraparound to 1 + with zone.writer() as txn: + txn.set_serial() + rdataset = zone.find_rdataset('@', 'soa') + assert rdataset[0].serial == 1 + # trying to set to zero sets to 1 + with zone.writer() as txn: + txn.set_serial(0, 0) + rdataset = zone.find_rdataset('@', 'soa') + assert rdataset[0].serial == 1 + with pytest.raises(KeyError): + with zone.writer() as txn: + txn.set_serial(name=dns.name.from_text('unknown', None)) + +class ExpectedException(Exception): + pass + +def test_zone_rollback(zone): + try: + with zone.writer() as txn: + a99 = dns.name.from_text('a99.example.') + rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + txn.add(a99, rds) + assert txn.name_exists(a99) + raise ExpectedException + except ExpectedException: + pass + assert not zone.get_node(a99) + +def test_zone_ooz_name(zone): + with zone.writer() as txn: + with pytest.raises(KeyError): + a99 = dns.name.from_text('a99.not-example.') + assert txn.name_exists(a99) + +def test_zone_iteration(zone): + expected = {} + for (name, rdataset) in zone.iterate_rdatasets(): + expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset + with zone.writer() as txn: + actual = {} + for (name, rdataset) in txn: + actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset + assert actual == expected + +@pytest.fixture +def vzone(): + return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone) + +def test_vzone_read_only(vzone): + with vzone.reader() as txn: + rdataset = txn.get(dns.name.empty, dns.rdataclass.IN, + dns.rdatatype.NS, dns.rdatatype.NONE) + expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2') + assert rdataset == expected + with pytest.raises(dns.transaction.ReadOnly): + txn.replace(dns.name.empty, expected) + +def test_vzone_multiple_versions(vzone): + assert len(vzone.versions) == 1 + vzone.set_max_versions(None) # unlimited! + with vzone.writer() as txn: + txn.set_serial() + with vzone.writer() as txn: + txn.set_serial() + with vzone.writer() as txn: + txn.set_serial() + rdataset = vzone.find_rdataset('@', 'soa') + assert rdataset[0].serial == 4 + assert len(vzone.versions) == 4 + vzone.set_max_versions(2) + assert len(vzone.versions) == 2 + # The ones that survived should be 3 and 4 + rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA, + dns.rdatatype.NONE) + assert rdataset[0].serial == 3 + rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA, + dns.rdatatype.NONE) + assert rdataset[0].serial == 4 + with pytest.raises(ValueError): + vzone.set_max_versions(0) + +try: + import threading + + one_got_lock = threading.Event() + + def run_one(zone): + with zone.writer() as txn: + one_got_lock.set() + # wait until two blocks + while len(zone._write_waiters) == 0: + time.sleep(0.01) + rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.98') + txn.add('a98', rds) + + def run_two(zone): + # wait until one has the lock so we know we will block if we + # get the call done before the sleep in one completes + one_got_lock.wait() + with zone.writer() as txn: + rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99') + txn.add('a99', rds) + + def test_vzone_concurrency(vzone): + t1 = threading.Thread(target=run_one, args=(vzone,)) + t1.start() + t2 = threading.Thread(target=run_two, args=(vzone,)) + t2.start() + t1.join() + t2.join() + with vzone.reader() as txn: + assert txn.name_exists('a98') + assert txn.name_exists('a99') + +except ImportError: # pragma: no cover + pass -- cgit v1.2.1