summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-07-30 09:21:03 -0700
committerBob Halley <halley@dnspython.org>2020-08-10 06:43:32 -0700
commit54deb97c2a5331fe99a12d720f24fb481ec31576 (patch)
tree0954caa36095efcc667aee6a3732759d4c9001c6
parent26fd19690c44a01c84c27a2d4244d2e5dc7b7a19 (diff)
downloaddnspython-54deb97c2a5331fe99a12d720f24fb481ec31576.tar.gz
txn checkpointtransaction
-rw-r--r--dns/__init__.py4
-rw-r--r--dns/_immutable_attr.py64
-rw-r--r--dns/_immutable_ctx.py58
-rw-r--r--dns/exception.py2
-rw-r--r--dns/immutable.py17
-rw-r--r--dns/masterfile.py404
-rw-r--r--dns/node.py30
-rw-r--r--dns/rdataset.py47
-rw-r--r--dns/transaction.py383
-rw-r--r--dns/versioned.py392
-rw-r--r--dns/zone.py525
-rw-r--r--tests/test_immutable.py134
-rw-r--r--tests/test_rdataset.py29
-rw-r--r--tests/test_transaction.py451
14 files changed, 2115 insertions, 425 deletions
diff --git a/dns/__init__.py b/dns/__init__.py
index b944701..eafdcc4 100644
--- a/dns/__init__.py
+++ b/dns/__init__.py
@@ -27,9 +27,11 @@ __all__ = [
'entropy',
'exception',
'flags',
+ 'immutable',
'inet',
'ipv4',
'ipv6',
+ 'masterfile',
'message',
'name',
'namedict',
@@ -48,12 +50,14 @@ __all__ = [
'serial',
'set',
'tokenizer',
+ 'transaction',
'tsig',
'tsigkeyring',
'ttl',
'rdtypes',
'update',
'version',
+ 'versioned',
'wire',
'zone',
]
diff --git a/dns/_immutable_attr.py b/dns/_immutable_attr.py
new file mode 100644
index 0000000..4221967
--- /dev/null
+++ b/dns/_immutable_attr.py
@@ -0,0 +1,64 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# This implementation of the immutable decorator is for python 3.6,
+# which doesn't have Context Variables. This implementation is somewhat
+# costly for classes with slots, as it adds a __dict__ to them.
+
+class _Immutable:
+ """Immutable mixin class"""
+
+ # Note we MUST NOT have __slots__ as that causes
+ #
+ # TypeError: multiple bases have instance lay-out conflict
+ #
+ # when we get mixed in with another class with slots. When we
+ # get mixed into something with slots, it effectively adds __dict__ to
+ # the slots of the other class, which allows attribute setting to work,
+ # albeit at the cost of the dictionary.
+
+ def __setattr__(self, name, value):
+ if not hasattr(self, '_immutable_init') or \
+ self._immutable_init is not self:
+ raise TypeError("object doesn't support attribute assignment")
+ else:
+ super().__setattr__(name, value)
+
+ def __delattr__(self, name):
+ if not hasattr(self, '_immutable_init') or \
+ self._immutable_init is not self:
+ raise TypeError("object doesn't support attribute assignment")
+ else:
+ super().__delattr__(name)
+
+
+def _immutable_init(f):
+ def nf(*args, **kwargs):
+ try:
+ # Are we already initializing an immutable class?
+ previous = args[0]._immutable_init
+ except AttributeError:
+ # We are the first!
+ previous = None
+ object.__setattr__(args[0], '_immutable_init', args[0])
+ # call the actual __init__
+ f(*args, **kwargs)
+ if not previous:
+ # If we started the initialzation, establish immutability
+ # by removing the attribute that allows mutation
+ object.__delattr__(args[0], '_immutable_init')
+ return nf
+
+
+def immutable(cls):
+ if _Immutable in cls.__mro__:
+ # Some ancestor already has the mixin, so just make sure we keep
+ # following the __init__ protocol.
+ cls.__init__ = _immutable_init(cls.__init__)
+ ncls = cls
+ else:
+ # Mixin the Immutable class and follow the __init__ protocol.
+ class ncls(_Immutable, cls):
+ @_immutable_init
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ return ncls
diff --git a/dns/_immutable_ctx.py b/dns/_immutable_ctx.py
new file mode 100644
index 0000000..017310d
--- /dev/null
+++ b/dns/_immutable_ctx.py
@@ -0,0 +1,58 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# This implementation of the immutable decorator requires python >=
+# 3.7, and is significantly more storage efficient when making classes
+# with slots immutable. It's also faster.
+
+import contextvars
+
+_in__init__ = contextvars.ContextVar('_immutable_in__init__', default=False)
+
+
+class _Immutable:
+ """Immutable mixin class"""
+
+ # We set slots to the empty list to say "we don't have any attributes".
+ # We do this so that if we're mixed in with a class with __slots__, we
+ # don't cause a __dict__ to be added which would waste space.
+
+ __slots__ = ()
+
+ def __setattr__(self, name, value):
+ if _in__init__.get() is not self:
+ raise TypeError("object doesn't support attribute assignment")
+ else:
+ super().__setattr__(name, value)
+
+ def __delattr__(self, name):
+ if _in__init__.get() is not self:
+ raise TypeError("object doesn't support attribute assignment")
+ else:
+ super().__delattr__(name)
+
+
+def _immutable_init(f):
+ def nf(*args, **kwargs):
+ previous = _in__init__.set(args[0])
+ # call the actual __init__
+ f(*args, **kwargs)
+ _in__init__.reset(previous)
+ return nf
+
+
+def immutable(cls):
+ if _Immutable in cls.__mro__:
+ # Some ancestor already has the mixin, so just make sure we keep
+ # following the __init__ protocol.
+ cls.__init__ = _immutable_init(cls.__init__)
+ ncls = cls
+ else:
+ # Mixin the Immutable class and follow the __init__ protocol.
+ class ncls(_Immutable, cls):
+ # We have to do the __slots__ declaration here too!
+ __slots__ = ()
+
+ @_immutable_init
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ return ncls
diff --git a/dns/exception.py b/dns/exception.py
index 9486f45..9392373 100644
--- a/dns/exception.py
+++ b/dns/exception.py
@@ -138,5 +138,5 @@ class ExceptionWrapper:
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None and not isinstance(exc_val,
self.exception_class):
- raise self.exception_class() from exc_val
+ raise self.exception_class(str(exc_val)) from exc_val
return False
diff --git a/dns/immutable.py b/dns/immutable.py
index dc48fe8..7cc39dd 100644
--- a/dns/immutable.py
+++ b/dns/immutable.py
@@ -3,13 +3,19 @@
import collections.abc
import sys
+# pylint: disable=unused-import
if sys.version_info >= (3, 7):
odict = dict
+ from dns._immutable_ctx import immutable
else:
- from collections import OrderedDict as odict # pragma: no cover
+ # pragma: no cover
+ from collections import OrderedDict as odict
+ from dns._immutable_attr import immutable # noqa
+# pylint: enable=unused-import
-class ImmutableDict(collections.abc.Mapping):
+@immutable
+class Dict(collections.abc.Mapping):
def __init__(self, dictionary, no_copy=False):
"""Make an immutable dictionary from the specified dictionary.
@@ -28,9 +34,10 @@ class ImmutableDict(collections.abc.Mapping):
def __hash__(self):
if self._hash is None:
- self._hash = 0
+ h = 0
for key in sorted(self._odict.keys()):
- self._hash ^= hash(key)
+ h ^= hash(key)
+ object.__setattr__(self, '_hash', h)
return self._hash
def __len__(self):
@@ -58,5 +65,5 @@ def constify(o):
cdict = odict()
for k, v in o.items():
cdict[k] = constify(v)
- return ImmutableDict(cdict, True)
+ return Dict(cdict, True)
return o
diff --git a/dns/masterfile.py b/dns/masterfile.py
new file mode 100644
index 0000000..30553b5
--- /dev/null
+++ b/dns/masterfile.py
@@ -0,0 +1,404 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+# Copyright (C) 2003-2007, 2009-2011 Nominum, Inc.
+#
+# Permission to use, copy, modify, and distribute this software and its
+# documentation for any purpose with or without fee is hereby granted,
+# provided that the above copyright notice and this permission notice
+# appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
+# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+"""DNS Zones."""
+
+import re
+import sys
+
+import dns.exception
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdata
+import dns.rdtypes.ANY.SOA
+import dns.rrset
+import dns.tokenizer
+import dns.transaction
+import dns.ttl
+import dns.grange
+
+
+class UnknownOrigin(dns.exception.DNSException):
+ """Unknown origin"""
+
+
+class Reader:
+
+ """Read a DNS master file into a transaction."""
+
+ def __init__(self, tok, origin, rdclass, relativize, txn,
+ allow_include=False):
+ if isinstance(origin, str):
+ origin = dns.name.from_text(origin)
+ self.tok = tok
+ self.current_origin = origin
+ self.relativize = relativize
+ self.last_ttl = 0
+ self.last_ttl_known = False
+ self.default_ttl = 0
+ self.default_ttl_known = False
+ self.last_name = self.current_origin
+ self.zone_origin = origin
+ self.zone_rdclass = rdclass
+ self.txn = txn
+ self.saved_state = []
+ self.current_file = None
+ self.allow_include = allow_include
+
+ def _eat_line(self):
+ while 1:
+ token = self.tok.get()
+ if token.is_eol_or_eof():
+ break
+
+ def _rr_line(self):
+ """Process one line from a DNS master file."""
+ # Name
+ if self.current_origin is None:
+ raise UnknownOrigin
+ token = self.tok.get(want_leading=True)
+ if not token.is_whitespace():
+ self.last_name = self.tok.as_name(token, self.current_origin)
+ else:
+ token = self.tok.get()
+ if token.is_eol_or_eof():
+ # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
+ return
+ self.tok.unget(token)
+ name = self.last_name
+ if not name.is_subdomain(self.zone_origin):
+ self._eat_line()
+ return
+ if self.relativize:
+ name = name.relativize(self.zone_origin)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+
+ # TTL
+ ttl = None
+ try:
+ ttl = dns.ttl.from_text(token.value)
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.ttl.BadTTL:
+ if self.default_ttl_known:
+ ttl = self.default_ttl
+ elif self.last_ttl_known:
+ ttl = self.last_ttl
+
+ # Class
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.exception.SyntaxError:
+ raise
+ except Exception:
+ rdclass = self.zone_rdclass
+ if rdclass != self.zone_rdclass:
+ raise dns.exception.SyntaxError("RR class is not zone's class")
+ # Type
+ try:
+ rdtype = dns.rdatatype.from_text(token.value)
+ except Exception:
+ raise dns.exception.SyntaxError(
+ "unknown rdatatype '%s'" % token.value)
+ try:
+ rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
+ self.current_origin, self.relativize,
+ self.zone_origin)
+ except dns.exception.SyntaxError:
+ # Catch and reraise.
+ raise
+ except Exception:
+ # All exceptions that occur in the processing of rdata
+ # are treated as syntax errors. This is not strictly
+ # correct, but it is correct almost all of the time.
+ # We convert them to syntax errors so that we can emit
+ # helpful filename:line info.
+ (ty, va) = sys.exc_info()[:2]
+ raise dns.exception.SyntaxError(
+ "caught exception {}: {}".format(str(ty), str(va)))
+
+ if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
+ # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
+ # TTL from the SOA minttl if no $TTL statement is present before the
+ # SOA is parsed.
+ self.default_ttl = rd.minimum
+ self.default_ttl_known = True
+ if ttl is None:
+ # if we didn't have a TTL on the SOA, set it!
+ ttl = rd.minimum
+
+ # TTL check. We had to wait until now to do this as the SOA RR's
+ # own TTL can be inferred from its minimum.
+ if ttl is None:
+ raise dns.exception.SyntaxError("Missing default TTL value")
+
+ self.txn.add(name, ttl, rd)
+
+ def _parse_modify(self, side):
+ # Here we catch everything in '{' '}' in a group so we can replace it
+ # with ''.
+ is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$")
+ is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$")
+ is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$")
+ # Sometimes there are modifiers in the hostname. These come after
+ # the dollar sign. They are in the form: ${offset[,width[,base]]}.
+ # Make names
+ g1 = is_generate1.match(side)
+ if g1:
+ mod, sign, offset, width, base = g1.groups()
+ if sign == '':
+ sign = '+'
+ g2 = is_generate2.match(side)
+ if g2:
+ mod, sign, offset = g2.groups()
+ if sign == '':
+ sign = '+'
+ width = 0
+ base = 'd'
+ g3 = is_generate3.match(side)
+ if g3:
+ mod, sign, offset, width = g3.groups()
+ if sign == '':
+ sign = '+'
+ base = 'd'
+
+ if not (g1 or g2 or g3):
+ mod = ''
+ sign = '+'
+ offset = 0
+ width = 0
+ base = 'd'
+
+ if base != 'd':
+ raise NotImplementedError()
+
+ return mod, sign, offset, width, base
+
+ def _generate_line(self):
+ # range lhs [ttl] [class] type rhs [ comment ]
+ """Process one line containing the GENERATE statement from a DNS
+ master file."""
+ if self.current_origin is None:
+ raise UnknownOrigin
+
+ token = self.tok.get()
+ # Range (required)
+ try:
+ start, stop, step = dns.grange.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except Exception:
+ raise dns.exception.SyntaxError
+
+ # lhs (required)
+ try:
+ lhs = token.value
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except Exception:
+ raise dns.exception.SyntaxError
+
+ # TTL
+ try:
+ ttl = dns.ttl.from_text(token.value)
+ self.last_ttl = ttl
+ self.last_ttl_known = True
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.ttl.BadTTL:
+ if not (self.last_ttl_known or self.default_ttl_known):
+ raise dns.exception.SyntaxError("Missing default TTL value")
+ if self.default_ttl_known:
+ ttl = self.default_ttl
+ elif self.last_ttl_known:
+ ttl = self.last_ttl
+ # Class
+ try:
+ rdclass = dns.rdataclass.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except dns.exception.SyntaxError:
+ raise dns.exception.SyntaxError
+ except Exception:
+ rdclass = self.zone_rdclass
+ if rdclass != self.zone_rdclass:
+ raise dns.exception.SyntaxError("RR class is not zone's class")
+ # Type
+ try:
+ rdtype = dns.rdatatype.from_text(token.value)
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError
+ except Exception:
+ raise dns.exception.SyntaxError("unknown rdatatype '%s'" %
+ token.value)
+
+ # rhs (required)
+ rhs = token.value
+
+ # The code currently only supports base 'd', so the last value
+ # in the tuple _parse_modify returns is ignored
+ lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs)
+ rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs)
+ for i in range(start, stop + 1, step):
+ # +1 because bind is inclusive and python is exclusive
+
+ if lsign == '+':
+ lindex = i + int(loffset)
+ elif lsign == '-':
+ lindex = i - int(loffset)
+
+ if rsign == '-':
+ rindex = i - int(roffset)
+ elif rsign == '+':
+ rindex = i + int(roffset)
+
+ lzfindex = str(lindex).zfill(int(lwidth))
+ rzfindex = str(rindex).zfill(int(rwidth))
+
+ name = lhs.replace('$%s' % (lmod), lzfindex)
+ rdata = rhs.replace('$%s' % (rmod), rzfindex)
+
+ self.last_name = dns.name.from_text(name, self.current_origin,
+ self.tok.idna_codec)
+ name = self.last_name
+ if not name.is_subdomain(self.zone_origin):
+ self._eat_line()
+ return
+ if self.relativize:
+ name = name.relativize(self.zone_origin)
+
+ try:
+ rd = dns.rdata.from_text(rdclass, rdtype, rdata,
+ self.current_origin, self.relativize,
+ self.zone_origin)
+ except dns.exception.SyntaxError:
+ # Catch and reraise.
+ raise
+ except Exception:
+ # All exceptions that occur in the processing of rdata
+ # are treated as syntax errors. This is not strictly
+ # correct, but it is correct almost all of the time.
+ # We convert them to syntax errors so that we can emit
+ # helpful filename:line info.
+ (ty, va) = sys.exc_info()[:2]
+ raise dns.exception.SyntaxError("caught exception %s: %s" %
+ (str(ty), str(va)))
+
+ self.txn.add(name, ttl, rd)
+
+ def read(self):
+ """Read a DNS master file and build a zone object.
+
+ @raises dns.zone.NoSOA: No SOA RR was found at the zone origin
+ @raises dns.zone.NoNS: No NS RRset was found at the zone origin
+ """
+
+ try:
+ while 1:
+ token = self.tok.get(True, True)
+ if token.is_eof():
+ if self.current_file is not None:
+ self.current_file.close()
+ if len(self.saved_state) > 0:
+ (self.tok,
+ self.current_origin,
+ self.last_name,
+ self.current_file,
+ self.last_ttl,
+ self.last_ttl_known,
+ self.default_ttl,
+ self.default_ttl_known) = self.saved_state.pop(-1)
+ continue
+ break
+ elif token.is_eol():
+ continue
+ elif token.is_comment():
+ self.tok.get_eol()
+ continue
+ elif token.value[0] == '$':
+ c = token.value.upper()
+ if c == '$TTL':
+ token = self.tok.get()
+ if not token.is_identifier():
+ raise dns.exception.SyntaxError("bad $TTL")
+ self.default_ttl = dns.ttl.from_text(token.value)
+ self.default_ttl_known = True
+ self.tok.get_eol()
+ elif c == '$ORIGIN':
+ self.current_origin = self.tok.get_name()
+ self.tok.get_eol()
+ if self.zone_origin is None:
+ self.zone_origin = self.current_origin
+ self.txn._set_origin(self.current_origin)
+ elif c == '$INCLUDE' and self.allow_include:
+ token = self.tok.get()
+ filename = token.value
+ token = self.tok.get()
+ if token.is_identifier():
+ new_origin =\
+ dns.name.from_text(token.value,
+ self.current_origin,
+ self.tok.idna_codec)
+ self.tok.get_eol()
+ elif not token.is_eol_or_eof():
+ raise dns.exception.SyntaxError(
+ "bad origin in $INCLUDE")
+ else:
+ new_origin = self.current_origin
+ self.saved_state.append((self.tok,
+ self.current_origin,
+ self.last_name,
+ self.current_file,
+ self.last_ttl,
+ self.last_ttl_known,
+ self.default_ttl,
+ self.default_ttl_known))
+ self.current_file = open(filename, 'r')
+ self.tok = dns.tokenizer.Tokenizer(self.current_file,
+ filename)
+ self.current_origin = new_origin
+ elif c == '$GENERATE':
+ self._generate_line()
+ else:
+ raise dns.exception.SyntaxError(
+ "Unknown master file directive '" + c + "'")
+ continue
+ self.tok.unget(token)
+ self._rr_line()
+ except dns.exception.SyntaxError as detail:
+ (filename, line_number) = self.tok.where()
+ if detail is None:
+ detail = "syntax error"
+ ex = dns.exception.SyntaxError(
+ "%s:%d: %s" % (filename, line_number, detail))
+ tb = sys.exc_info()[2]
+ raise ex.with_traceback(tb) from None
diff --git a/dns/node.py b/dns/node.py
index b7e21b5..8e1451f 100644
--- a/dns/node.py
+++ b/dns/node.py
@@ -183,3 +183,33 @@ class Node:
self.delete_rdataset(replacement.rdclass, replacement.rdtype,
replacement.covers)
self.rdatasets.append(replacement)
+
+
+@dns.immutable.immutable
+class ImmutableNode(Node):
+
+ """An ImmutableNode is an immutable set of rdatasets."""
+
+ def __init__(self, node):
+ super().__init__()
+ self.rdatasets = tuple(
+ [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+ )
+
+ def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise TypeError("immutable")
+ return super().find_rdataset(rdclass, rdtype, covers, False)
+
+ def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise TypeError("immutable")
+ return super().get_rdataset(rdclass, rdtype, covers, False)
+
+ def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+ raise TypeError("immutable")
+
+ def replace_rdataset(self, replacement):
+ raise TypeError("immutable")
diff --git a/dns/rdataset.py b/dns/rdataset.py
index ba93ab4..1f372cd 100644
--- a/dns/rdataset.py
+++ b/dns/rdataset.py
@@ -22,6 +22,7 @@ import random
import struct
import dns.exception
+import dns.immutable
import dns.rdatatype
import dns.rdataclass
import dns.rdata
@@ -306,6 +307,52 @@ class Rdataset(dns.set.Set):
return False
+@dns.immutable.immutable
+class ImmutableRdataset(Rdataset):
+
+ """An immutable DNS rdataset."""
+
+ def __init__(self, rdataset):
+ """Create an immutable rdataset from the specified rdataset."""
+
+ super().__init__(rdataset.rdclass, rdataset.rdtype, rdataset.covers,
+ rdataset.ttl)
+ self.items = dns.immutable.Dict(rdataset.items)
+
+ def update_ttl(self, ttl):
+ raise TypeError('immutable')
+
+ def add(self, rd, ttl=None):
+ raise TypeError('immutable')
+
+ def union_update(self, other):
+ raise TypeError('immutable')
+
+ def intersection_update(self, other):
+ raise TypeError('immutable')
+
+ def update(self, other):
+ raise TypeError('immutable')
+
+ def __delitem__(self, i):
+ raise TypeError('immutable')
+
+ def __ior__(self, other):
+ raise TypeError('immutable')
+
+ def __iand__(self, other):
+ raise TypeError('immutable')
+
+ def __iadd__(self, other):
+ raise TypeError('immutable')
+
+ def __isub__(self, other):
+ raise TypeError('immutable')
+
+ def clear(self):
+ raise TypeError('immutable')
+
+
def from_text_list(rdclass, rdtype, ttl, text_rdatas, idna_codec=None,
origin=None, relativize=True, relativize_to=None):
"""Create an rdataset with the specified class, type, and TTL, and with
diff --git a/dns/transaction.py b/dns/transaction.py
new file mode 100644
index 0000000..20d6939
--- /dev/null
+++ b/dns/transaction.py
@@ -0,0 +1,383 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+import collections
+
+import dns.exception
+import dns.name
+import dns.rdataclass
+import dns.rdataset
+import dns.rdatatype
+import dns.rrset
+import dns.ttl
+
+
+class TransactionManager:
+ def reader(self):
+ """Begin a read-only transaction."""
+ raise NotImplementedError # pragma: no cover
+
+ def writer(self, replacement=False):
+ """Begin a writable transaction.
+
+ *replacement*, a `bool`. If `True`, the content of the
+ transaction completely replaces any prior content. If False,
+ the default, then the content of the transaction updates the
+ existing content.
+ """
+ raise NotImplementedError # pragma: no cover
+
+
+class DeleteNotExact(dns.exception.DNSException):
+ """Existing data did not match data specified by an exact delete."""
+
+
+class ReadOnly(dns.exception.DNSException):
+ """Tried to write to a read-only transaction."""
+
+
+class Transaction:
+
+ def __init__(self, replacement=False, read_only=False):
+ self.replacement = replacement
+ self.read_only = read_only
+
+ #
+ # This is the high level API
+ #
+
+ def get(self, name, rdclass, rdtype, covers=dns.rdatatype.NONE):
+ """Return the rdataset associated with *name*, *rdclass*, *rdtype*,
+ and *covers*, or `None` if not found.
+
+ Note that the returned rdataset is immutable.
+ """
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ rdclass = dns.rdataclass.RdataClass.make(rdclass)
+ rdtype = dns.rdatatype.RdataType.make(rdtype)
+ rdataset = self._get_rdataset(name, rdclass, rdtype, covers)
+ if rdataset is not None and \
+ not isinstance(rdataset, dns.rdataset.ImmutableRdataset):
+ rdataset = dns.rdataset.ImmutableRdataset(rdataset)
+ return rdataset
+
+ def _check_read_only(self):
+ if self.read_only:
+ raise ReadOnly
+
+ def add(self, *args):
+ """Add records.
+
+ The arguments may be:
+
+ - rrset
+
+ - name, rdataset...
+
+ - name, ttl, rdata...
+ """
+ self._check_read_only()
+ return self._add(False, args)
+
+ def replace(self, *args):
+ """Replace the existing rdataset at the name with the specified
+ rdataset, or add the specified rdataset if there was no existing
+ rdataset.
+
+ The arguments may be:
+
+ - rrset
+
+ - name, rdataset...
+
+ - name, ttl, rdata...
+
+ Note that if you want to replace the entire node, you should do
+ a delete of the name followed by one or more calls to add() or
+ replace().
+ """
+ self._check_read_only()
+ return self._add(True, args)
+
+ def delete(self, *args):
+ """Delete records.
+
+ It is not an error if some of the records are not in the existing
+ set.
+
+ The arguments may be:
+
+ - rrset
+
+ - name
+
+ - name, rdataclass, rdatatype, [covers]
+
+ - name, rdataset...
+
+ - name, rdata...
+ """
+ self._check_read_only()
+ return self._delete(False, args)
+
+ def delete_exact(self, *args):
+ """Delete records.
+
+ The arguments may be:
+
+ - rrset
+
+ - name
+
+ - name, rdataclass, rdatatype, [covers]
+
+ - name, rdataset...
+
+ - name, rdata...
+
+ Raises dns.transaction.DeleteNotExact if some of the records
+ are not in the existing set.
+
+ """
+ self._check_read_only()
+ return self._delete(True, args)
+
+ def name_exists(self, name):
+ """Does the specified name exist?"""
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ return self._name_exists(name)
+
+ def set_serial(self, increment=1, value=None, name=dns.name.empty,
+ rdclass=dns.rdataclass.IN):
+ if isinstance(name, str):
+ name = dns.name.from_text(name, None)
+ rdataset = self._get_rdataset(name, rdclass, dns.rdatatype.SOA,
+ dns.rdatatype.NONE)
+ if rdataset is None or len(rdataset) == 0:
+ raise KeyError
+ if value is not None:
+ serial = value
+ else:
+ serial = rdataset[0].serial
+ serial += increment
+ if serial > 0xffffffff or serial < 1:
+ serial = 1
+ rdata = rdataset[0].replace(serial=serial)
+ new_rdataset = dns.rdataset.from_rdata(rdataset.ttl, rdata)
+ self.replace(name, new_rdataset)
+
+ def __iter__(self):
+ return self._iterate_rdatasets()
+
+ #
+ # Helper methods
+ #
+
+ def _raise_if_not_empty(self, method, args):
+ if len(args) != 0:
+ raise TypeError(f'extra parameters to {method}')
+
+ def _rdataset_from_args(self, method, deleting, args):
+ try:
+ arg = args.popleft()
+ if isinstance(arg, dns.rdataset.Rdataset):
+ rdataset = arg
+ else:
+ if deleting:
+ ttl = 0
+ else:
+ if isinstance(arg, int):
+ ttl = arg
+ if ttl > dns.ttl.MAX_TTL:
+ raise ValueError(f'{method}: TTL value too big')
+ else:
+ raise TypeError(f'{method}: expected a TTL')
+ arg = args.popleft()
+ if isinstance(arg, dns.rdata.Rdata):
+ rdataset = dns.rdataset.from_rdata(ttl, arg)
+ else:
+ raise TypeError(f'{method}: expected an Rdata')
+ return rdataset
+ except IndexError:
+ if deleting:
+ return None
+ else:
+ # reraise
+ raise TypeError(f'{method}: expected more arguments')
+
+ def _add(self, replace, args):
+ try:
+ args = collections.deque(args)
+ if replace:
+ method = 'replace()'
+ else:
+ method = 'add()'
+ arg = args.popleft()
+ if isinstance(arg, str):
+ arg = dns.name.from_text(arg, None)
+ if isinstance(arg, dns.name.Name):
+ name = arg
+ rdataset = self._rdataset_from_args(method, False, args)
+ elif isinstance(arg, dns.rrset.RRset):
+ rrset = arg
+ name = rrset.name
+ # rrsets are also rdatasets, but they don't print the
+ # same, so convert.
+ rdataset = dns.rdataset.Rdataset(rrset.rdclass, rrset.rdtype,
+ rrset.covers, rrset.ttl)
+ rdataset.union_update(rrset)
+ else:
+ raise TypeError(f'{method} requires a name or RRset ' +
+ 'as the first argument')
+ self._raise_if_not_empty(method, args)
+ if not replace:
+ existing = self._get_rdataset(name, rdataset.rdclass,
+ rdataset.rdtype, rdataset.covers)
+ if existing is not None:
+ if isinstance(existing, dns.rdataset.ImmutableRdataset):
+ trds = dns.rdataset.Rdataset(existing.rdclass,
+ existing.rdtype,
+ existing.covers)
+ trds.update(existing)
+ existing = trds
+ rdataset = existing.union(rdataset)
+ self._put_rdataset(name, rdataset)
+ except IndexError:
+ raise TypeError(f'not enough parameters to {method}')
+
+ def _delete(self, exact, args):
+ try:
+ args = collections.deque(args)
+ if exact:
+ method = 'delete_exact()'
+ else:
+ method = 'delete()'
+ arg = args.popleft()
+ if isinstance(arg, str):
+ arg = dns.name.from_text(arg, None)
+ if isinstance(arg, dns.name.Name):
+ name = arg
+ if len(args) > 0 and isinstance(args[0], int):
+ # deleting by type and class
+ rdclass = dns.rdataclass.RdataClass.make(args.popleft())
+ rdtype = dns.rdatatype.RdataType.make(args.popleft())
+ if len(args) > 0:
+ covers = dns.rdatatype.RdataType.make(args.popleft())
+ else:
+ covers = dns.rdatatype.NONE
+ self._raise_if_not_empty(method, args)
+ existing = self._get_rdataset(name, rdclass, rdtype, covers)
+ if existing is None:
+ if exact:
+ raise DeleteNotExact(f'{method}: missing rdataset')
+ else:
+ self._delete_rdataset(name, rdclass, rdtype, covers)
+ return
+ else:
+ rdataset = self._rdataset_from_args(method, True, args)
+ elif isinstance(arg, dns.rrset.RRset):
+ rdataset = arg # rrsets are also rdatasets
+ name = rdataset.name
+ else:
+ raise TypeError(f'{method} requires a name or RRset ' +
+ 'as the first argument')
+ self._raise_if_not_empty(method, args)
+ if rdataset:
+ existing = self._get_rdataset(name, rdataset.rdclass,
+ rdataset.rdtype, rdataset.covers)
+ if existing is not None:
+ if exact:
+ intersection = existing.intersection(rdataset)
+ if intersection != rdataset:
+ raise DeleteNotExact(f'{method}: missing rdatas')
+ rdataset = existing.difference(rdataset)
+ if len(rdataset) == 0:
+ self._delete_rdataset(name, rdataset.rdclass,
+ rdataset.rdtype, rdataset.covers)
+ else:
+ self._put_rdataset(name, rdataset)
+ elif exact:
+ raise DeleteNotExact(f'{method}: missing rdataset')
+ else:
+ if exact and not self._name_exists(name):
+ raise DeleteNotExact(f'{method}: name not known')
+ self._delete_name(name)
+ except IndexError:
+ raise TypeError(f'not enough parameters to {method}')
+
+ #
+ # Transactions are context managers.
+ #
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is None:
+ self._end_transaction(True)
+ else:
+ self._end_transaction(False)
+ return False
+
+ #
+ # This is the low level API, which must be implemented by subclasses
+ # of Transaction.
+ #
+
+ def _get_rdataset(self, name, rdclass, rdtype, covers):
+ """Return the rdataset associated with *name*, *rdclass*, *rdtype*,
+ and *covers*, or `None` if not found."""
+ raise NotImplementedError # pragma: no cover
+
+ def _put_rdataset(self, name, rdataset):
+ """Store the rdataset."""
+ raise NotImplementedError # pragma: no cover
+
+ def _delete_name(self, name):
+ """Delete all data associated with *name*.
+
+ It is not an error if the rdataset does not exist.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def _delete_rdataset(self, name, rdclass, rdtype, covers):
+ """Delete all data associated with *name*, *rdclass*, *rdtype*, and
+ *covers*.
+
+ It is not an error if the rdataset does not exist.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def _name_exists(self, name):
+ """Does name exist?
+
+ Returns a bool.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def _end_transaction(self, commit):
+ """End the transaction.
+
+ *commit*, a bool. If ``True``, commit the transaction, otherwise
+ roll it back.
+
+ Raises an exception if committing failed.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def _set_origin(self, origin):
+ """Set the origin.
+
+ This method is called when reading a possibly relativized
+ source, and an origin setting operation occurs (e.g. $ORIGIN
+ in a masterfile).
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def _iterate_rdatasets(self):
+ """Return an iterator that yields (name, rdataset) tuples.
+
+ Not all Transaction subclasses implement this.
+ """
+ raise NotImplementedError # pragma: no cover
diff --git a/dns/versioned.py b/dns/versioned.py
new file mode 100644
index 0000000..6f911e1
--- /dev/null
+++ b/dns/versioned.py
@@ -0,0 +1,392 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""DNS Versioned Zones."""
+
+import collections
+try:
+ import threading as _threading
+except ImportError: # pragma: no cover
+ import dummy_threading as _threading # type: ignore
+
+import dns.exception
+import dns.immutable
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdata
+import dns.rdtypes.ANY.SOA
+import dns.transaction
+import dns.zone
+
+
+class UseTransaction(dns.exception.DNSException):
+ """To alter a versioned zone, use a transaction."""
+
+
+class Version:
+ def __init__(self, zone, id):
+ self.zone = zone
+ self.id = id
+ self.nodes = {}
+
+ def _validate_name(self, name):
+ if name.is_absolute():
+ if not name.is_subdomain(self.zone.origin):
+ raise KeyError("name is not a subdomain of the zone origin")
+ if self.zone.relativize:
+ name = name.relativize(self.origin)
+ return name
+
+ def get_node(self, name):
+ name = self._validate_name(name)
+ return self.nodes.get(name)
+
+ def get_rdataset(self, name, rdtype, covers):
+ node = self.get_node(name)
+ if node is None:
+ return None
+ return node.get_rdataset(self.zone.rdclass, rdtype, covers)
+
+ def items(self):
+ return self.nodes.items() # pylint: disable=dict-items-not-iterating
+
+ def _print(self): # pragma: no cover
+ # XXXRTH This is for debugging
+ print('VERSION', self.id)
+ for (name, node) in self.nodes.items():
+ for rdataset in node:
+ print(rdataset.to_text(name))
+
+
+class WritableVersion(Version):
+ def __init__(self, zone, replacement=False):
+ if len(zone.versions) > 0:
+ id = zone.versions[-1].id + 1
+ else:
+ id = 1
+ super().__init__(zone, id)
+ if not replacement:
+ # We copy the map, because that gives us a simple and thread-safe
+ # way of doing versions, and we have a garbage collector to help
+ # us. We only make new node objects if we actually change the
+ # node.
+ self.nodes.update(zone.nodes)
+ # We have to copy the zone origin as it may be None in the first
+ # version, and we don't want to mutate the zone until we commit.
+ self.origin = zone.origin
+ self.changed = set()
+
+ def _validate_name(self, name):
+ if name.is_absolute():
+ if not name.is_subdomain(self.origin):
+ raise KeyError("name is not a subdomain of the zone origin")
+ if self.zone.relativize:
+ name = name.relativize(self.origin)
+ return name
+
+ def _maybe_cow(self, name):
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is None or node.id != self.id:
+ new_node = self.zone.node_factory()
+ new_node.id = self.id
+ if node is not None:
+ # moo! copy on write!
+ new_node.rdatasets.extend(node.rdatasets)
+ self.nodes[name] = new_node
+ self.changed.add(name)
+ return new_node
+ else:
+ return node
+
+ def delete_node(self, name):
+ name = self._validate_name(name)
+ if name in self.nodes:
+ del self.nodes[name]
+ return True
+ return False
+
+ def put_rdataset(self, name, rdataset):
+ node = self._maybe_cow(name)
+ node.replace_rdataset(rdataset)
+
+ def delete_rdataset(self, name, rdtype, covers):
+ node = self._maybe_cow(name)
+ if not node.get_rdataset(self.zone.rdclass, rdtype, covers):
+ return False
+ node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+ if len(node) == 0:
+ del self.nodes[name]
+ return True
+
+
+@dns.immutable.immutable
+class ImmutableVersion(Version):
+ def __init__(self, version):
+ # We tell super() that it's a replacement as we don't want it
+ # to copy the nodes, as we're about to do that with an
+ # immutable Dict.
+ super().__init__(version.zone, True)
+ # set the right id!
+ self.id = version.id
+ # Make changed nodes immutable
+ for name in version.changed:
+ node = version.nodes.get(name)
+ # it might not exist if we deleted it in the version
+ if node:
+ version.nodes[name] = ImmutableNode(node)
+ self.nodes = dns.immutable.Dict(version.nodes, True)
+
+
+# A node with a version id.
+
+class Node(dns.node.Node):
+ __slots__ = ['id']
+
+ def __init__(self):
+ super().__init__()
+ # A proper id will get set by the Version
+ self.id = 0
+
+
+# It would be nice if this were a subclass of Node (just above) but it's
+# less code duplication this way as we inherit all of the method disabling
+# code.
+
+@dns.immutable.immutable
+class ImmutableNode(dns.node.ImmutableNode):
+ __slots__ = ['id']
+
+ def __init__(self, node):
+ super().__init__(node)
+ self.id = node.id
+
+
+class Zone(dns.zone.Zone):
+
+ __slots__ = ['versions', '_write_txn', '_write_waiters', '_write_event',
+ '_pruning_policy']
+
+ node_factory = Node
+
+ def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
+ pruning_policy=None):
+ """Initialize a versioned zone object.
+
+ *origin* is the origin of the zone. It may be a ``dns.name.Name``,
+ a ``str``, or ``None``. If ``None``, then the zone's origin will
+ be set by the first ``$ORIGIN`` line in a masterfile.
+
+ *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
+
+ *relativize*, a ``bool``, determine's whether domain names are
+ relativized to the zone's origin. The default is ``True``.
+
+ *pruning policy*, a function taking a `Version` and returning
+ a `bool`, or `None`. Should the version be pruned? If `None`,
+ the default policy, which retains one version is used.
+ """
+ super().__init__(origin, rdclass, relativize)
+ self.versions = collections.deque()
+ self.version_lock = _threading.Lock()
+ if pruning_policy is None:
+ self._pruning_policy = self._default_pruning_policy
+ else:
+ self._pruning_policy = pruning_policy
+ self._write_txn = None
+ self._write_event = None
+ self._write_waiters = collections.deque()
+ self._commit_version_unlocked(WritableVersion(self), origin)
+
+ def reader(self):
+ with self.version_lock:
+ return Transaction(False, self, self.versions[-1])
+
+ def writer(self, replacement=False):
+ event = None
+ while True:
+ with self.version_lock:
+ # Checking event == self._write_event ensures that either
+ # no one was waiting before we got lucky and found no write
+ # txn, or we were the one who was waiting and got woken up.
+ # This prevents "taking cuts" when creating a write txn.
+ if self._write_txn is None and event == self._write_event:
+ # Creating the transaction defers version setup
+ # (i.e. copying the nodes dictionary) until we
+ # give up the lock, so that we hold the lock as
+ # short a time as possible. This is why we call
+ # _setup_version() below.
+ self._write_txn = Transaction(replacement, self)
+ # give up our exclusive right to make a Transaction
+ self._write_event = None
+ break
+ # Someone else is writing already, so we will have to
+ # wait, but we want to do the actual wait outside the
+ # lock.
+ event = _threading.Event()
+ self._write_waiters.append(event)
+ # wait (note we gave up the lock!)
+ #
+ # We only wake one sleeper at a time, so it's important
+ # that no event waiter can exit this method (e.g. via
+ # cancelation) without returning a transaction or waking
+ # someone else up.
+ #
+ # This is not a problem with Threading module threads as
+ # they cannot be canceled, but could be an issue with trio
+ # or curio tasks when we do the async version of writer().
+ # I.e. we'd need to do something like:
+ #
+ # try:
+ # event.wait()
+ # except trio.Cancelled:
+ # with self.version_lock:
+ # self._maybe_wakeup_one_waiter_unlocked()
+ # raise
+ #
+ event.wait()
+ # Do the deferred version setup.
+ self._write_txn._setup_version()
+ return self._write_txn
+
+ def _maybe_wakeup_one_waiter_unlocked(self):
+ if len(self._write_waiters) > 0:
+ self._write_event = self._write_waiters.popleft()
+ self._write_event.set()
+
+ # pylint: disable=unused-argument
+ def _default_pruning_policy(self, zone, version):
+ return True
+ # pylint: enable=unused-argument
+
+ def _prune_versions_unlocked(self):
+ while len(self.versions) > 1 and \
+ self._pruning_policy(self, self.versions[0]):
+ self.versions.popleft()
+
+ def set_max_versions(self, max_versions):
+ """Set a pruning policy that retains up to the specified number
+ of versions
+ """
+ if max_versions is not None and max_versions < 1:
+ raise ValueError('max versions must be at least 1')
+ if max_versions is None:
+ def policy(*_):
+ return False
+ else:
+ def policy(zone, _):
+ return len(zone.versions) > max_versions
+ self.set_pruning_policy(policy)
+
+ def set_pruning_policy(self, policy):
+ """Set the pruning policy for the zone.
+
+ The *policy* function takes a `Version` and returns `True` if
+ the version should be pruned, and `False` otherwise. `None`
+ may also be specified for policy, in which case the default policy
+ is used.
+
+ Pruning checking proceeds from the least version and the first
+ time the function returns `False`, the checking stops. I.e. the
+ retained versions are always a consecutive sequence.
+ """
+ if policy is None:
+ policy = self._default_pruning_policy
+ with self.version_lock:
+ self._pruning_policy = policy
+ self._prune_versions_unlocked()
+
+ def _commit_version_unlocked(self, version, origin):
+ self.versions.append(version)
+ self._prune_versions_unlocked()
+ self.nodes = version.nodes
+ if self.origin is None:
+ self.origin = origin
+ self._write_txn = None
+ self._maybe_wakeup_one_waiter_unlocked()
+
+ def _commit_version(self, version, origin):
+ with self.version_lock:
+ self._commit_version_unlocked(version, origin)
+
+ def find_node(self, name, create=False):
+ if create:
+ raise UseTransaction
+ return super().find_node(name)
+
+ def delete_node(self, name):
+ raise UseTransaction
+
+ def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise UseTransaction
+ rdataset = super().find_rdataset(name, rdtype, covers)
+ return dns.rdataset.ImmutableRdataset(rdataset)
+
+ def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise UseTransaction
+ rdataset = super().get_rdataset(name, rdtype, covers)
+ return dns.rdataset.ImmutableRdataset(rdataset)
+
+ def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
+ raise UseTransaction
+
+ def replace_rdataset(self, name, replacement):
+ raise UseTransaction
+
+
+class Transaction(dns.transaction.Transaction):
+
+ def __init__(self, replacement, zone, version=None):
+ read_only = version is not None
+ super().__init__(replacement, read_only)
+ self.zone = zone
+ self.version = version
+
+ def _setup_version(self):
+ assert self.version is None
+ self.version = WritableVersion(self.zone, self.replacement)
+
+ def _get_rdataset(self, name, rdclass, rdtype, covers):
+ if rdclass != self.zone.rdclass:
+ raise ValueError(f'class {rdclass} != ' +
+ f'zone class {self.zone.rdclass}')
+ return self.version.get_rdataset(name, rdtype, covers)
+
+ def _put_rdataset(self, name, rdataset):
+ assert not self.read_only
+ if rdataset.rdclass != self.zone.rdclass:
+ raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
+ f'zone class {self.zone.rdclass}')
+ self.version.put_rdataset(name, rdataset)
+
+ def _delete_name(self, name):
+ assert not self.read_only
+ self.version.delete_node(name)
+
+ def _delete_rdataset(self, name, rdclass, rdtype, covers):
+ assert not self.read_only
+ self.version.delete_rdataset(name, rdtype, covers)
+
+ def _name_exists(self, name):
+ return self.version.get_node(name) is not None
+
+ def _end_transaction(self, commit):
+ if self.read_only:
+ return
+ if commit and len(self.version.changed) > 0:
+ self.zone._commit_version(ImmutableVersion(self.version),
+ self.version.origin)
+
+ def _set_origin(self, origin):
+ if self.version.origin is None:
+ self.version.origin = origin
+
+ def _iterate_rdatasets(self):
+ for (name, node) in self.version.items():
+ for rdataset in node:
+ yield (name, rdataset)
diff --git a/dns/zone.py b/dns/zone.py
index d5bb305..2ca9bc2 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -20,10 +20,9 @@
import contextlib
import io
import os
-import re
-import sys
import dns.exception
+import dns.masterfile
import dns.name
import dns.node
import dns.rdataclass
@@ -32,6 +31,7 @@ import dns.rdata
import dns.rdtypes.ANY.SOA
import dns.rrset
import dns.tokenizer
+import dns.transaction
import dns.ttl
import dns.grange
@@ -56,7 +56,7 @@ class UnknownOrigin(BadZone):
"""The DNS zone's origin is unknown."""
-class Zone:
+class Zone(dns.transaction.TransactionManager):
"""A DNS zone.
@@ -642,415 +642,108 @@ class Zone:
if self.get_rdataset(name, dns.rdatatype.NS) is None:
raise NoNS
+ def reader(self):
+ return Transaction(False, True, self)
-class _MasterReader:
-
- """Read a DNS master file
-
- @ivar tok: The tokenizer
- @type tok: dns.tokenizer.Tokenizer object
- @ivar last_ttl: The last seen explicit TTL for an RR
- @type last_ttl: int
- @ivar last_ttl_known: Has last TTL been detected
- @type last_ttl_known: bool
- @ivar default_ttl: The default TTL from a $TTL directive or SOA RR
- @type default_ttl: int
- @ivar default_ttl_known: Has default TTL been detected
- @type default_ttl_known: bool
- @ivar last_name: The last name read
- @type last_name: dns.name.Name object
- @ivar current_origin: The current origin
- @type current_origin: dns.name.Name object
- @ivar relativize: should names in the zone be relativized?
- @type relativize: bool
- @ivar zone: the zone
- @type zone: dns.zone.Zone object
- @ivar saved_state: saved reader state (used when processing $INCLUDE)
- @type saved_state: list of (tokenizer, current_origin, last_name, file,
- last_ttl, last_ttl_known, default_ttl, default_ttl_known) tuples.
- @ivar current_file: the file object of the $INCLUDed file being parsed
- (None if no $INCLUDE is active).
- @ivar allow_include: is $INCLUDE allowed?
- @type allow_include: bool
- @ivar check_origin: should sanity checks of the origin node be done?
- The default is True.
- @type check_origin: bool
- """
+ def writer(self, replacement=False):
+ return Transaction(replacement, False, self)
- def __init__(self, tok, origin, rdclass, relativize, zone_factory=Zone,
- allow_include=False, check_origin=True):
- if isinstance(origin, str):
- origin = dns.name.from_text(origin)
- self.tok = tok
- self.current_origin = origin
- self.relativize = relativize
- self.last_ttl = 0
- self.last_ttl_known = False
- self.default_ttl = 0
- self.default_ttl_known = False
- self.last_name = self.current_origin
- self.zone = zone_factory(origin, rdclass, relativize=relativize)
- self.saved_state = []
- self.current_file = None
- self.allow_include = allow_include
- self.check_origin = check_origin
-
- def _eat_line(self):
- while 1:
- token = self.tok.get()
- if token.is_eol_or_eof():
- break
-
- def _rr_line(self):
- """Process one line from a DNS master file."""
- # Name
- if self.current_origin is None:
- raise UnknownOrigin
- token = self.tok.get(want_leading=True)
- if not token.is_whitespace():
- self.last_name = self.tok.as_name(token, self.current_origin)
- else:
- token = self.tok.get()
- if token.is_eol_or_eof():
- # treat leading WS followed by EOL/EOF as if they were EOL/EOF.
- return
- self.tok.unget(token)
- name = self.last_name
- if not name.is_subdomain(self.zone.origin):
- self._eat_line()
- return
- if self.relativize:
- name = name.relativize(self.zone.origin)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- # TTL
- ttl = None
- try:
- ttl = dns.ttl.from_text(token.value)
- self.last_ttl = ttl
- self.last_ttl_known = True
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except dns.ttl.BadTTL:
- if self.default_ttl_known:
- ttl = self.default_ttl
- elif self.last_ttl_known:
- ttl = self.last_ttl
-
- # Class
- try:
- rdclass = dns.rdataclass.from_text(token.value)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except dns.exception.SyntaxError:
- raise
- except Exception:
- rdclass = self.zone.rdclass
- if rdclass != self.zone.rdclass:
- raise dns.exception.SyntaxError("RR class is not zone's class")
- # Type
- try:
- rdtype = dns.rdatatype.from_text(token.value)
- except Exception:
- raise dns.exception.SyntaxError(
- "unknown rdatatype '%s'" % token.value)
- n = self.zone.nodes.get(name)
- if n is None:
- n = self.zone.node_factory()
- self.zone.nodes[name] = n
- try:
- rd = dns.rdata.from_text(rdclass, rdtype, self.tok,
- self.current_origin, self.relativize,
- self.zone.origin)
- except dns.exception.SyntaxError:
- # Catch and reraise.
- raise
- except Exception:
- # All exceptions that occur in the processing of rdata
- # are treated as syntax errors. This is not strictly
- # correct, but it is correct almost all of the time.
- # We convert them to syntax errors so that we can emit
- # helpful filename:line info.
- (ty, va) = sys.exc_info()[:2]
- raise dns.exception.SyntaxError(
- "caught exception {}: {}".format(str(ty), str(va)))
-
- if not self.default_ttl_known and rdtype == dns.rdatatype.SOA:
- # The pre-RFC2308 and pre-BIND9 behavior inherits the zone default
- # TTL from the SOA minttl if no $TTL statement is present before the
- # SOA is parsed.
- self.default_ttl = rd.minimum
- self.default_ttl_known = True
- if ttl is None:
- # if we didn't have a TTL on the SOA, set it!
- ttl = rd.minimum
-
- # TTL check. We had to wait until now to do this as the SOA RR's
- # own TTL can be inferred from its minimum.
- if ttl is None:
- raise dns.exception.SyntaxError("Missing default TTL value")
-
- covers = rd.covers()
- rds = n.find_rdataset(rdclass, rdtype, covers, True)
- rds.add(rd, ttl)
-
- def _parse_modify(self, side):
- # Here we catch everything in '{' '}' in a group so we can replace it
- # with ''.
- is_generate1 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+),(.)}).*$")
- is_generate2 = re.compile(r"^.*\$({(\+|-?)(\d+)}).*$")
- is_generate3 = re.compile(r"^.*\$({(\+|-?)(\d+),(\d+)}).*$")
- # Sometimes there are modifiers in the hostname. These come after
- # the dollar sign. They are in the form: ${offset[,width[,base]]}.
- # Make names
- g1 = is_generate1.match(side)
- if g1:
- mod, sign, offset, width, base = g1.groups()
- if sign == '':
- sign = '+'
- g2 = is_generate2.match(side)
- if g2:
- mod, sign, offset = g2.groups()
- if sign == '':
- sign = '+'
- width = 0
- base = 'd'
- g3 = is_generate3.match(side)
- if g3:
- mod, sign, offset, width = g3.groups()
- if sign == '':
- sign = '+'
- base = 'd'
-
- if not (g1 or g2 or g3):
- mod = ''
- sign = '+'
- offset = 0
- width = 0
- base = 'd'
-
- if base != 'd':
- raise NotImplementedError()
-
- return mod, sign, offset, width, base
-
- def _generate_line(self):
- # range lhs [ttl] [class] type rhs [ comment ]
- """Process one line containing the GENERATE statement from a DNS
- master file."""
- if self.current_origin is None:
- raise UnknownOrigin
-
- token = self.tok.get()
- # Range (required)
- try:
- start, stop, step = dns.grange.from_text(token.value)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except Exception:
- raise dns.exception.SyntaxError
-
- # lhs (required)
- try:
- lhs = token.value
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except Exception:
- raise dns.exception.SyntaxError
-
- # TTL
- try:
- ttl = dns.ttl.from_text(token.value)
- self.last_ttl = ttl
- self.last_ttl_known = True
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except dns.ttl.BadTTL:
- if not (self.last_ttl_known or self.default_ttl_known):
- raise dns.exception.SyntaxError("Missing default TTL value")
- if self.default_ttl_known:
- ttl = self.default_ttl
- elif self.last_ttl_known:
- ttl = self.last_ttl
- # Class
- try:
- rdclass = dns.rdataclass.from_text(token.value)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except dns.exception.SyntaxError:
- raise dns.exception.SyntaxError
- except Exception:
- rdclass = self.zone.rdclass
+class Transaction(dns.transaction.Transaction):
+
+ _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY,
+ dns.rdatatype.ANY)
+
+ def __init__(self, replacement, read_only, zone):
+ super().__init__(replacement, read_only)
+ self.zone = zone
+ self.rdatasets = {}
+
+ def _get_rdataset(self, name, rdclass, rdtype, covers):
if rdclass != self.zone.rdclass:
- raise dns.exception.SyntaxError("RR class is not zone's class")
- # Type
- try:
- rdtype = dns.rdatatype.from_text(token.value)
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError
- except Exception:
- raise dns.exception.SyntaxError("unknown rdatatype '%s'" %
- token.value)
-
- # rhs (required)
- rhs = token.value
-
- # The code currently only supports base 'd', so the last value
- # in the tuple _parse_modify returns is ignored
- lmod, lsign, loffset, lwidth, _ = self._parse_modify(lhs)
- rmod, rsign, roffset, rwidth, _ = self._parse_modify(rhs)
- for i in range(start, stop + 1, step):
- # +1 because bind is inclusive and python is exclusive
-
- if lsign == '+':
- lindex = i + int(loffset)
- elif lsign == '-':
- lindex = i - int(loffset)
-
- if rsign == '-':
- rindex = i - int(roffset)
- elif rsign == '+':
- rindex = i + int(roffset)
-
- lzfindex = str(lindex).zfill(int(lwidth))
- rzfindex = str(rindex).zfill(int(rwidth))
-
- name = lhs.replace('$%s' % (lmod), lzfindex)
- rdata = rhs.replace('$%s' % (rmod), rzfindex)
-
- self.last_name = dns.name.from_text(name, self.current_origin,
- self.tok.idna_codec)
- name = self.last_name
- if not name.is_subdomain(self.zone.origin):
- self._eat_line()
- return
- if self.relativize:
- name = name.relativize(self.zone.origin)
-
- n = self.zone.nodes.get(name)
- if n is None:
- n = self.zone.node_factory()
- self.zone.nodes[name] = n
- try:
- rd = dns.rdata.from_text(rdclass, rdtype, rdata,
- self.current_origin, self.relativize,
- self.zone.origin)
- except dns.exception.SyntaxError:
- # Catch and reraise.
- raise
- except Exception:
- # All exceptions that occur in the processing of rdata
- # are treated as syntax errors. This is not strictly
- # correct, but it is correct almost all of the time.
- # We convert them to syntax errors so that we can emit
- # helpful filename:line info.
- (ty, va) = sys.exc_info()[:2]
- raise dns.exception.SyntaxError("caught exception %s: %s" %
- (str(ty), str(va)))
-
- covers = rd.covers()
- rds = n.find_rdataset(rdclass, rdtype, covers, True)
- rds.add(rd, ttl)
-
- def read(self):
- """Read a DNS master file and build a zone object.
-
- @raises dns.zone.NoSOA: No SOA RR was found at the zone origin
- @raises dns.zone.NoNS: No NS RRset was found at the zone origin
- """
+ raise ValueError(f'class {rdclass} != ' +
+ f'zone class {self.zone.rdclass}')
+ rdataset = self.rdatasets.get((name, rdtype, covers))
+ if rdataset is self._deleted_rdataset:
+ return None
+ elif rdataset is None:
+ rdataset = self.zone.get_rdataset(name, rdtype, covers)
+ return rdataset
+ def _put_rdataset(self, name, rdataset):
+ assert not self.read_only
+ self.zone._validate_name(name)
+ if rdataset.rdclass != self.zone.rdclass:
+ raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
+ f'zone class {self.zone.rdclass}')
+ self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+
+ def _delete_name(self, name):
+ assert not self.read_only
+ # First remove any changes involving the name
+ remove = []
+ for key in self.rdatasets:
+ if key[0] == name:
+ remove.append(key)
+ if len(remove) > 0:
+ for key in remove:
+ del self.rdatasets[key]
+ # Next add deletion records for any rdatasets matching the
+ # name in the zone
+ node = self.zone.get_node(name)
+ if node is not None:
+ for rdataset in node.rdatasets:
+ self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
+ self._deleted_rdataset
+
+ def _delete_rdataset(self, name, rdclass, rdtype, covers):
+ assert not self.read_only
+ # The high-level code always does a _get_rdataset() before any
+ # situation where it would call _delete_rdataset(), so we don't
+ # need to check if rdclass != self.zone.rdclass.
try:
- while 1:
- token = self.tok.get(True, True)
- if token.is_eof():
- if self.current_file is not None:
- self.current_file.close()
- if len(self.saved_state) > 0:
- (self.tok,
- self.current_origin,
- self.last_name,
- self.current_file,
- self.last_ttl,
- self.last_ttl_known,
- self.default_ttl,
- self.default_ttl_known) = self.saved_state.pop(-1)
- continue
- break
- elif token.is_eol():
- continue
- elif token.is_comment():
- self.tok.get_eol()
- continue
- elif token.value[0] == '$':
- c = token.value.upper()
- if c == '$TTL':
- token = self.tok.get()
- if not token.is_identifier():
- raise dns.exception.SyntaxError("bad $TTL")
- self.default_ttl = dns.ttl.from_text(token.value)
- self.default_ttl_known = True
- self.tok.get_eol()
- elif c == '$ORIGIN':
- self.current_origin = self.tok.get_name()
- self.tok.get_eol()
- if self.zone.origin is None:
- self.zone.origin = self.current_origin
- elif c == '$INCLUDE' and self.allow_include:
- token = self.tok.get()
- filename = token.value
- token = self.tok.get()
- if token.is_identifier():
- new_origin =\
- dns.name.from_text(token.value,
- self.current_origin,
- self.tok.idna_codec)
- self.tok.get_eol()
- elif not token.is_eol_or_eof():
- raise dns.exception.SyntaxError(
- "bad origin in $INCLUDE")
- else:
- new_origin = self.current_origin
- self.saved_state.append((self.tok,
- self.current_origin,
- self.last_name,
- self.current_file,
- self.last_ttl,
- self.last_ttl_known,
- self.default_ttl,
- self.default_ttl_known))
- self.current_file = open(filename, 'r')
- self.tok = dns.tokenizer.Tokenizer(self.current_file,
- filename)
- self.current_origin = new_origin
- elif c == '$GENERATE':
- self._generate_line()
- else:
- raise dns.exception.SyntaxError(
- "Unknown master file directive '" + c + "'")
- continue
- self.tok.unget(token)
- self._rr_line()
- except dns.exception.SyntaxError as detail:
- (filename, line_number) = self.tok.where()
- if detail is None:
- detail = "syntax error"
- ex = dns.exception.SyntaxError(
- "%s:%d: %s" % (filename, line_number, detail))
- tb = sys.exc_info()[2]
- raise ex.with_traceback(tb) from None
-
- # Now that we're done reading, do some basic checking of the zone.
- if self.check_origin:
- self.zone.check_origin()
+ del self.rdatasets[(name, rdtype, covers)]
+ except KeyError:
+ pass
+ rdataset = self.zone.get_rdataset(name, rdtype, covers)
+ if rdataset is not None:
+ self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
+ self._deleted_rdataset
+
+ def _name_exists(self, name):
+ for key, rdataset in self.rdatasets.items():
+ if key[0] == name:
+ if rdataset != self._deleted_rdataset:
+ return True
+ else:
+ return None
+ self.zone._validate_name(name)
+ if self.zone.get_node(name):
+ return True
+ return False
+
+ def _end_transaction(self, commit):
+ if commit and not self.read_only:
+ for (name, rdtype, covers), rdataset in \
+ self.rdatasets.items():
+ if rdataset is self._deleted_rdataset:
+ self.zone.delete_rdataset(name, rdtype, covers)
+ else:
+ self.zone.replace_rdataset(name, rdataset)
+
+ def _set_origin(self, origin):
+ if self.zone.origin is None:
+ self.zone.origin = origin
+
+ def _iterate_rdatasets(self):
+ # Expensive but simple! Use a versioned zone for efficient txn
+ # iteration.
+ rdatasets = {}
+ for (name, rdataset) in self.zone.iterate_rdatasets():
+ rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ rdatasets.update(self.rdatasets)
+ for (name, _, _), rdataset in rdatasets.items():
+ yield (name, rdataset)
def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
@@ -1103,12 +796,20 @@ def from_text(text, origin=None, rdclass=dns.rdataclass.IN,
if filename is None:
filename = '<string>'
- tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
- reader = _MasterReader(tok, origin, rdclass, relativize, zone_factory,
- allow_include=allow_include,
- check_origin=check_origin)
- reader.read()
- return reader.zone
+ zone = zone_factory(origin, rdclass, relativize=relativize)
+ with zone.writer(True) as txn:
+ tok = dns.tokenizer.Tokenizer(text, filename, idna_codec=idna_codec)
+ reader = dns.masterfile.Reader(tok, origin, rdclass, relativize, txn,
+ allow_include=allow_include)
+ try:
+ reader.read()
+ except dns.masterfile.UnknownOrigin:
+ # for backwards compatibility
+ raise dns.zone.UnknownOrigin
+ # Now that we're done reading, do some basic checking of the zone.
+ if check_origin:
+ zone.check_origin()
+ return zone
def from_file(f, origin=None, rdclass=dns.rdataclass.IN,
diff --git a/tests/test_immutable.py b/tests/test_immutable.py
index 0385fc9..1a70e3d 100644
--- a/tests/test_immutable.py
+++ b/tests/test_immutable.py
@@ -3,20 +3,30 @@
import unittest
import dns.immutable
+import dns._immutable_attr
+
+try:
+ import dns._immutable_ctx as immutable_ctx
+ _have_contextvars = True
+except ImportError:
+ _have_contextvars = False
+
+ class immutable_ctx:
+ pass
class ImmutableTestCase(unittest.TestCase):
- def test_ImmutableDict_hash(self):
- d1 = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
- d2 = dns.immutable.ImmutableDict({'b': 2, 'a': 1})
+ def test_immutable_dict_hash(self):
+ d1 = dns.immutable.Dict({'a': 1, 'b': 2})
+ d2 = dns.immutable.Dict({'b': 2, 'a': 1})
d3 = {'b': 2, 'a': 1}
self.assertEqual(d1, d2)
self.assertEqual(d2, d3)
self.assertEqual(hash(d1), hash(d2))
- def test_ImmutableDict_hash_cache(self):
- d = dns.immutable.ImmutableDict({'a': 1, 'b': 2})
+ def test_immutable_dict_hash_cache(self):
+ d = dns.immutable.Dict({'a': 1, 'b': 2})
self.assertEqual(d._hash, None)
h1 = hash(d)
self.assertEqual(d._hash, h1)
@@ -30,11 +40,121 @@ class ImmutableTestCase(unittest.TestCase):
((1, [2], 3), (1, (2,), 3)),
([1, 2, 3], (1, 2, 3)),
([1, {'a': [1, 2]}],
- (1, dns.immutable.ImmutableDict({'a': (1, 2)}))),
+ (1, dns.immutable.Dict({'a': (1, 2)}))),
('hi', 'hi'),
(b'hi', b'hi'),
)
for input, expected in items:
self.assertEqual(dns.immutable.constify(input), expected)
self.assertIsInstance(dns.immutable.constify({'a': 1}),
- dns.immutable.ImmutableDict)
+ dns.immutable.Dict)
+
+
+class DecoratorTestCase(unittest.TestCase):
+
+ immutable_module = dns._immutable_attr
+
+ def make_classes(self):
+ class A:
+ def __init__(self, a, akw=10):
+ self.a = a
+ self.akw = akw
+
+ class B(A):
+ def __init__(self, a, b):
+ super().__init__(a, akw=20)
+ self.b = b
+ B = self.immutable_module.immutable(B)
+
+ # note C is immutable by inheritance
+ class C(B):
+ def __init__(self, a, b, c):
+ super().__init__(a, b)
+ self.c = c
+ C = self.immutable_module.immutable(C)
+
+ class SA:
+ __slots__ = ('a', 'akw')
+ def __init__(self, a, akw=10):
+ self.a = a
+ self.akw = akw
+
+ class SB(A):
+ __slots__ = ('b')
+ def __init__(self, a, b):
+ super().__init__(a, akw=20)
+ self.b = b
+ SB = self.immutable_module.immutable(SB)
+
+ # note SC is immutable by inheritance and has no slots of its own
+ class SC(SB):
+ def __init__(self, a, b, c):
+ super().__init__(a, b)
+ self.c = c
+ SC = self.immutable_module.immutable(SC)
+
+ return ((A, B, C), (SA, SB, SC))
+
+ def test_basic(self):
+ for A, B, C in self.make_classes():
+ a = A(1)
+ self.assertEqual(a.a, 1)
+ self.assertEqual(a.akw, 10)
+ b = B(11, 21)
+ self.assertEqual(b.a, 11)
+ self.assertEqual(b.akw, 20)
+ self.assertEqual(b.b, 21)
+ c = C(111, 211, 311)
+ self.assertEqual(c.a, 111)
+ self.assertEqual(c.akw, 20)
+ self.assertEqual(c.b, 211)
+ self.assertEqual(c.c, 311)
+ # changing A is ok!
+ a.a = 11
+ self.assertEqual(a.a, 11)
+ # changing B is not!
+ with self.assertRaises(TypeError):
+ b.a = 11
+ with self.assertRaises(TypeError):
+ del b.a
+
+ def test_constructor_deletes_attribute(self):
+ class A:
+ def __init__(self, a):
+ self.a = a
+ self.b = a
+ del self.b
+ A = self.immutable_module.immutable(A)
+ a = A(10)
+ self.assertEqual(a.a, 10)
+ self.assertFalse(hasattr(a, 'b'))
+
+ def test_no_collateral_damage(self):
+
+ # A and B are immutable but not related. The magic that lets
+ # us write to immutable things while initializing B should not let
+ # B mess with A.
+
+ class A:
+ def __init__(self, a):
+ self.a = a
+ A = self.immutable_module.immutable(A)
+
+ class B:
+ def __init__(self, a, b):
+ self.b = a.a + b
+ # rudely attempt to mutate innocent immutable bystander 'a'
+ a.a = 1000
+ B = self.immutable_module.immutable(B)
+
+ a = A(10)
+ self.assertEqual(a.a, 10)
+ with self.assertRaises(TypeError):
+ B(a, 20)
+ self.assertEqual(a.a, 10)
+
+
+@unittest.skipIf(not _have_contextvars, "contextvars not available")
+class CtxDecoratorTestCase(DecoratorTestCase):
+
+ immutable_module = immutable_ctx
diff --git a/tests/test_rdataset.py b/tests/test_rdataset.py
index a80d650..88b4840 100644
--- a/tests/test_rdataset.py
+++ b/tests/test_rdataset.py
@@ -122,5 +122,34 @@ class RdatasetTestCase(unittest.TestCase):
'<DNS IN RRSIG(NSEC) rdataset:'))
+class ImmutableRdatasetTestCase(unittest.TestCase):
+
+ def test_basic(self):
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1', '10.0.0.2')
+ rd = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ irds = dns.rdataset.ImmutableRdataset(rds)
+ with self.assertRaises(TypeError):
+ irds.update_ttl(100)
+ with self.assertRaises(TypeError):
+ irds.add(rd, 300)
+ with self.assertRaises(TypeError):
+ irds.union_update(rds)
+ with self.assertRaises(TypeError):
+ irds.intersection_update(rds)
+ with self.assertRaises(TypeError):
+ irds.update(rds)
+ with self.assertRaises(TypeError):
+ irds += rds
+ with self.assertRaises(TypeError):
+ irds -= rds
+ with self.assertRaises(TypeError):
+ irds &= rds
+ with self.assertRaises(TypeError):
+ irds |= rds
+ with self.assertRaises(TypeError):
+ del irds[0]
+ with self.assertRaises(TypeError):
+ irds.clear()
+
if __name__ == '__main__':
unittest.main()
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
new file mode 100644
index 0000000..ed154fc
--- /dev/null
+++ b/tests/test_transaction.py
@@ -0,0 +1,451 @@
+import time
+
+import pytest
+
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdataset
+import dns.rrset
+import dns.transaction
+import dns.versioned
+import dns.zone
+
+class DB(dns.transaction.TransactionManager):
+ def __init__(self):
+ self.rdatasets = {}
+
+ def reader(self):
+ return Transaction(False, True, self)
+
+ def writer(self, replacement=False):
+ return Transaction(replacement, False, self)
+
+
+class Transaction(dns.transaction.Transaction):
+ def __init__(self, replacement, read_only, db):
+ super().__init__(replacement)
+ self.db = db
+ self.rdatasets = {}
+ self.read_only = read_only
+ if not replacement:
+ self.rdatasets.update(db.rdatasets)
+
+ def _get_rdataset(self, name, rdclass, rdtype, covers):
+ return self.rdatasets.get((name, rdclass, rdtype, covers))
+
+ def _put_rdataset(self, name, rdataset):
+ self.rdatasets[(name, rdataset.rdclass, rdataset.rdtype,
+ rdataset.covers)] = rdataset
+
+ def _delete_name(self, name):
+ remove = []
+ for key in self.rdatasets.keys():
+ if key[0] == name:
+ remove.append(key)
+ if len(remove) > 0:
+ for key in remove:
+ del self.rdatasets[key]
+
+ def _delete_rdataset(self, name, rdclass, rdtype, covers):
+ del self.rdatasets[(name, rdclass, rdtype, covers)]
+
+ def _name_exists(self, name):
+ for key in self.rdatasets.keys():
+ if key[0] == name:
+ return True
+ return False
+
+ def _end_transaction(self, commit):
+ if commit:
+ self.db.rdatasets = self.rdatasets
+
+ def _set_origin(self, origin):
+ pass
+
+@pytest.fixture
+def db():
+ db = DB()
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
+ db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] = rrset
+ return db
+
+def test_basic(db):
+ # successful txn
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2')
+ txn.add(rrset)
+ assert txn.name_exists(rrset.name)
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ rrset
+ # rollback
+ with pytest.raises(Exception):
+ with db.writer() as txn:
+ rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.3', '10.0.0.4')
+ txn.add(rrset2)
+ raise Exception()
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ rrset
+ with db.writer() as txn:
+ txn.delete(rrset.name)
+ assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+ is None
+
+def test_get(db):
+ with db.writer() as txn:
+ content = dns.name.from_text('content', None)
+ rdataset = txn.get(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+ assert rdataset is not None
+ assert rdataset[0].strings == (b'content',)
+ assert isinstance(rdataset, dns.rdataset.ImmutableRdataset)
+
+def test_add(db):
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2')
+ txn.add(rrset)
+ rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.3', '10.0.0.4')
+ txn.add(rrset2)
+ expected = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2',
+ '10.0.0.3', '10.0.0.4')
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ expected
+
+def test_replacement(db):
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2')
+ txn.add(rrset)
+ rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.3', '10.0.0.4')
+ txn.replace(rrset2)
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ rrset2
+
+def test_delete(db):
+ with db.writer() as txn:
+ txn.delete(dns.name.from_text('nonexistent', None))
+ content = dns.name.from_text('content', None)
+ content2 = dns.name.from_text('content2', None)
+ txn.delete(content)
+ assert not txn.name_exists(content)
+ txn.delete(content2, dns.rdataclass.IN, dns.rdatatype.TXT)
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content')
+ txn.add(rrset)
+ assert txn.name_exists(content)
+ txn.delete(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+ assert not txn.name_exists(content)
+ rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content')
+ txn.delete(rrset)
+ content_keys = [k for k in db.rdatasets if k[0] == content]
+ assert len(content_keys) == 0
+
+def test_delete_exact(db):
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'bad-content')
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset)
+ rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'bad-content')
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset)
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset.name)
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset.name, dns.rdataclass.IN, dns.rdatatype.TXT)
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
+ txn.delete_exact(rrset)
+ assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+ is None
+
+def test_parameter_forms(db):
+ with db.writer() as txn:
+ foo = dns.name.from_text('foo', None)
+ rdataset = dns.rdataset.from_text('in', 'a', 300,
+ '10.0.0.1', '10.0.0.2')
+ rdata1 = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ rdata2 = dns.rdata.from_text('in', 'a', '10.0.0.4')
+ txn.add(foo, rdataset)
+ txn.add(foo, 100, rdata1)
+ txn.add(foo, 30, rdata2)
+ expected = dns.rrset.from_text('foo', 30, 'in', 'a',
+ '10.0.0.1', '10.0.0.2',
+ '10.0.0.3', '10.0.0.4')
+ assert db.rdatasets[(foo, rdataset.rdclass, rdataset.rdtype, 0)] == \
+ expected
+ with db.writer() as txn:
+ txn.delete(foo, rdataset)
+ txn.delete(foo, rdata1)
+ txn.delete(foo, rdata2)
+ assert db.rdatasets.get((foo, rdataset.rdclass, rdataset.rdtype, 0)) \
+ is None
+
+def test_bad_parameters(db):
+ with db.writer() as txn:
+ with pytest.raises(TypeError):
+ txn.add(1)
+ with pytest.raises(TypeError):
+ rrset = dns.rrset.from_text('bar', 300, 'in', 'txt', 'bar')
+ txn.add(rrset, 1)
+ with pytest.raises(ValueError):
+ foo = dns.name.from_text('foo', None)
+ rdata = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ txn.add(foo, 0x80000000, rdata)
+ with pytest.raises(TypeError):
+ txn.add(foo)
+ with pytest.raises(TypeError):
+ txn.add()
+ with pytest.raises(TypeError):
+ txn.add(foo, 300)
+ with pytest.raises(TypeError):
+ txn.add(foo, 300, 'hi')
+ with pytest.raises(TypeError):
+ txn.add(foo, 'hi')
+ with pytest.raises(TypeError):
+ txn.delete()
+ with pytest.raises(TypeError):
+ txn.delete(1)
+
+example_text = """$TTL 3600
+$ORIGIN example.
+@ soa foo bar 1 2 3 4 5
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+$TTL 300
+$ORIGIN foo.example.
+bar mx 0 blaz
+"""
+
+example_text_output = """@ 3600 IN SOA foo bar 1 2 3 4 5
+@ 3600 IN NS ns1
+@ 3600 IN NS ns2
+@ 3600 IN NS ns3
+ns1 3600 IN A 10.0.0.1
+ns2 3600 IN A 10.0.0.2
+ns3 3600 IN A 10.0.0.3
+"""
+
+@pytest.fixture(params=[dns.zone.Zone, dns.versioned.Zone])
+def zone(request):
+ return dns.zone.from_text(example_text, zone_factory=request.param)
+
+def test_zone_basic(zone):
+ with zone.writer() as txn:
+ txn.delete(dns.name.from_text('bar.foo', None))
+ rd = dns.rdata.from_text('in', 'ns', 'ns3')
+ txn.add(dns.name.empty, 3600, rd)
+ rd = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ txn.add(dns.name.from_text('ns3', None), 3600, rd)
+ output = zone.to_text()
+ assert output == example_text_output
+
+def test_zone_base_layer(zone):
+ with zone.writer() as txn:
+ # Get a set from the zone layer
+ rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
+ assert rdataset == expected
+
+def test_zone_transaction_layer(zone):
+ with zone.writer() as txn:
+ # Make a change
+ rd = dns.rdata.from_text('in', 'ns', 'ns3')
+ txn.add(dns.name.empty, 3600, rd)
+ # Get a set from the transaction layer
+ expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3')
+ rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ assert rdataset == expected
+ assert txn.name_exists(dns.name.empty)
+ ns1 = dns.name.from_text('ns1', None)
+ assert txn.name_exists(ns1)
+ ns99 = dns.name.from_text('ns99', None)
+ assert not txn.name_exists(ns99)
+
+def test_zone_add_and_delete(zone):
+ with zone.writer() as txn:
+ a99 = dns.name.from_text('a99', None)
+ a100 = dns.name.from_text('a100', None)
+ a101 = dns.name.from_text('a101', None)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add(a99, rds)
+ txn.delete(a99, dns.rdataclass.IN, dns.rdatatype.A)
+ txn.delete(a100, dns.rdataclass.IN, dns.rdatatype.A)
+ txn.delete(a101)
+ assert not txn.name_exists(a99)
+ assert not txn.name_exists(a100)
+ assert not txn.name_exists(a101)
+ ns1 = dns.name.from_text('ns1', None)
+ txn.delete(ns1, dns.rdataclass.IN, dns.rdatatype.A)
+ assert not txn.name_exists(ns1)
+ with zone.writer() as txn:
+ txn.add(a99, rds)
+ txn.delete(a99)
+ assert not txn.name_exists(a99)
+ with zone.writer() as txn:
+ txn.add(a100, rds)
+ txn.delete(a99)
+ assert not txn.name_exists(a99)
+ assert txn.name_exists(a100)
+
+def test_zone_get_deleted(zone):
+ with zone.writer() as txn:
+ print(zone.to_text())
+ ns1 = dns.name.from_text('ns1', None)
+ assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is not None
+ txn.delete(ns1)
+ assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is None
+ ns2 = dns.name.from_text('ns2', None)
+ txn.delete(ns2, dns.rdataclass.IN, dns.rdatatype.A)
+ assert txn.get(ns2, dns.rdataclass.IN, dns.rdatatype.A) is None
+
+def test_zone_bad_class(zone):
+ with zone.writer() as txn:
+ with pytest.raises(ValueError):
+ txn.get(dns.name.empty, dns.rdataclass.CH,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2')
+ with pytest.raises(ValueError):
+ txn.add(dns.name.empty, rds)
+ with pytest.raises(ValueError):
+ txn.replace(dns.name.empty, rds)
+ with pytest.raises(ValueError):
+ txn.delete(dns.name.empty, rds)
+ with pytest.raises(ValueError):
+ txn.delete(dns.name.empty, dns.rdataclass.CH,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+
+def test_set_serial(zone):
+ # basic
+ with zone.writer() as txn:
+ txn.set_serial()
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 2
+ # max
+ with zone.writer() as txn:
+ txn.set_serial(0, 0xffffffff)
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 0xffffffff
+ # wraparound to 1
+ with zone.writer() as txn:
+ txn.set_serial()
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 1
+ # trying to set to zero sets to 1
+ with zone.writer() as txn:
+ txn.set_serial(0, 0)
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 1
+ with pytest.raises(KeyError):
+ with zone.writer() as txn:
+ txn.set_serial(name=dns.name.from_text('unknown', None))
+
+class ExpectedException(Exception):
+ pass
+
+def test_zone_rollback(zone):
+ try:
+ with zone.writer() as txn:
+ a99 = dns.name.from_text('a99.example.')
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add(a99, rds)
+ assert txn.name_exists(a99)
+ raise ExpectedException
+ except ExpectedException:
+ pass
+ assert not zone.get_node(a99)
+
+def test_zone_ooz_name(zone):
+ with zone.writer() as txn:
+ with pytest.raises(KeyError):
+ a99 = dns.name.from_text('a99.not-example.')
+ assert txn.name_exists(a99)
+
+def test_zone_iteration(zone):
+ expected = {}
+ for (name, rdataset) in zone.iterate_rdatasets():
+ expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ with zone.writer() as txn:
+ actual = {}
+ for (name, rdataset) in txn:
+ actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ assert actual == expected
+
+@pytest.fixture
+def vzone():
+ return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone)
+
+def test_vzone_read_only(vzone):
+ with vzone.reader() as txn:
+ rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
+ assert rdataset == expected
+ with pytest.raises(dns.transaction.ReadOnly):
+ txn.replace(dns.name.empty, expected)
+
+def test_vzone_multiple_versions(vzone):
+ assert len(vzone.versions) == 1
+ vzone.set_max_versions(None) # unlimited!
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.writer() as txn:
+ txn.set_serial()
+ rdataset = vzone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 4
+ assert len(vzone.versions) == 4
+ vzone.set_max_versions(2)
+ assert len(vzone.versions) == 2
+ # The ones that survived should be 3 and 4
+ rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+ dns.rdatatype.NONE)
+ assert rdataset[0].serial == 3
+ rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+ dns.rdatatype.NONE)
+ assert rdataset[0].serial == 4
+ with pytest.raises(ValueError):
+ vzone.set_max_versions(0)
+
+try:
+ import threading
+
+ one_got_lock = threading.Event()
+
+ def run_one(zone):
+ with zone.writer() as txn:
+ one_got_lock.set()
+ # wait until two blocks
+ while len(zone._write_waiters) == 0:
+ time.sleep(0.01)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.98')
+ txn.add('a98', rds)
+
+ def run_two(zone):
+ # wait until one has the lock so we know we will block if we
+ # get the call done before the sleep in one completes
+ one_got_lock.wait()
+ with zone.writer() as txn:
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add('a99', rds)
+
+ def test_vzone_concurrency(vzone):
+ t1 = threading.Thread(target=run_one, args=(vzone,))
+ t1.start()
+ t2 = threading.Thread(target=run_two, args=(vzone,))
+ t2.start()
+ t1.join()
+ t2.join()
+ with vzone.reader() as txn:
+ assert txn.name_exists('a98')
+ assert txn.name_exists('a99')
+
+except ImportError: # pragma: no cover
+ pass