summaryrefslogtreecommitdiff
path: root/dns/versioned.py
diff options
context:
space:
mode:
Diffstat (limited to 'dns/versioned.py')
-rw-r--r--dns/versioned.py392
1 files changed, 392 insertions, 0 deletions
diff --git a/dns/versioned.py b/dns/versioned.py
new file mode 100644
index 0000000..6f911e1
--- /dev/null
+++ b/dns/versioned.py
@@ -0,0 +1,392 @@
+# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
+
+"""DNS Versioned Zones."""
+
+import collections
+try:
+ import threading as _threading
+except ImportError: # pragma: no cover
+ import dummy_threading as _threading # type: ignore
+
+import dns.exception
+import dns.immutable
+import dns.name
+import dns.node
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdata
+import dns.rdtypes.ANY.SOA
+import dns.transaction
+import dns.zone
+
+
+class UseTransaction(dns.exception.DNSException):
+ """To alter a versioned zone, use a transaction."""
+
+
+class Version:
+ def __init__(self, zone, id):
+ self.zone = zone
+ self.id = id
+ self.nodes = {}
+
+ def _validate_name(self, name):
+ if name.is_absolute():
+ if not name.is_subdomain(self.zone.origin):
+ raise KeyError("name is not a subdomain of the zone origin")
+ if self.zone.relativize:
+ name = name.relativize(self.origin)
+ return name
+
+ def get_node(self, name):
+ name = self._validate_name(name)
+ return self.nodes.get(name)
+
+ def get_rdataset(self, name, rdtype, covers):
+ node = self.get_node(name)
+ if node is None:
+ return None
+ return node.get_rdataset(self.zone.rdclass, rdtype, covers)
+
+ def items(self):
+ return self.nodes.items() # pylint: disable=dict-items-not-iterating
+
+ def _print(self): # pragma: no cover
+ # XXXRTH This is for debugging
+ print('VERSION', self.id)
+ for (name, node) in self.nodes.items():
+ for rdataset in node:
+ print(rdataset.to_text(name))
+
+
+class WritableVersion(Version):
+ def __init__(self, zone, replacement=False):
+ if len(zone.versions) > 0:
+ id = zone.versions[-1].id + 1
+ else:
+ id = 1
+ super().__init__(zone, id)
+ if not replacement:
+ # We copy the map, because that gives us a simple and thread-safe
+ # way of doing versions, and we have a garbage collector to help
+ # us. We only make new node objects if we actually change the
+ # node.
+ self.nodes.update(zone.nodes)
+ # We have to copy the zone origin as it may be None in the first
+ # version, and we don't want to mutate the zone until we commit.
+ self.origin = zone.origin
+ self.changed = set()
+
+ def _validate_name(self, name):
+ if name.is_absolute():
+ if not name.is_subdomain(self.origin):
+ raise KeyError("name is not a subdomain of the zone origin")
+ if self.zone.relativize:
+ name = name.relativize(self.origin)
+ return name
+
+ def _maybe_cow(self, name):
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is None or node.id != self.id:
+ new_node = self.zone.node_factory()
+ new_node.id = self.id
+ if node is not None:
+ # moo! copy on write!
+ new_node.rdatasets.extend(node.rdatasets)
+ self.nodes[name] = new_node
+ self.changed.add(name)
+ return new_node
+ else:
+ return node
+
+ def delete_node(self, name):
+ name = self._validate_name(name)
+ if name in self.nodes:
+ del self.nodes[name]
+ return True
+ return False
+
+ def put_rdataset(self, name, rdataset):
+ node = self._maybe_cow(name)
+ node.replace_rdataset(rdataset)
+
+ def delete_rdataset(self, name, rdtype, covers):
+ node = self._maybe_cow(name)
+ if not node.get_rdataset(self.zone.rdclass, rdtype, covers):
+ return False
+ node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+ if len(node) == 0:
+ del self.nodes[name]
+ return True
+
+
+@dns.immutable.immutable
+class ImmutableVersion(Version):
+ def __init__(self, version):
+ # We tell super() that it's a replacement as we don't want it
+ # to copy the nodes, as we're about to do that with an
+ # immutable Dict.
+ super().__init__(version.zone, True)
+ # set the right id!
+ self.id = version.id
+ # Make changed nodes immutable
+ for name in version.changed:
+ node = version.nodes.get(name)
+ # it might not exist if we deleted it in the version
+ if node:
+ version.nodes[name] = ImmutableNode(node)
+ self.nodes = dns.immutable.Dict(version.nodes, True)
+
+
+# A node with a version id.
+
+class Node(dns.node.Node):
+ __slots__ = ['id']
+
+ def __init__(self):
+ super().__init__()
+ # A proper id will get set by the Version
+ self.id = 0
+
+
+# It would be nice if this were a subclass of Node (just above) but it's
+# less code duplication this way as we inherit all of the method disabling
+# code.
+
+@dns.immutable.immutable
+class ImmutableNode(dns.node.ImmutableNode):
+ __slots__ = ['id']
+
+ def __init__(self, node):
+ super().__init__(node)
+ self.id = node.id
+
+
+class Zone(dns.zone.Zone):
+
+ __slots__ = ['versions', '_write_txn', '_write_waiters', '_write_event',
+ '_pruning_policy']
+
+ node_factory = Node
+
+ def __init__(self, origin, rdclass=dns.rdataclass.IN, relativize=True,
+ pruning_policy=None):
+ """Initialize a versioned zone object.
+
+ *origin* is the origin of the zone. It may be a ``dns.name.Name``,
+ a ``str``, or ``None``. If ``None``, then the zone's origin will
+ be set by the first ``$ORIGIN`` line in a masterfile.
+
+ *rdclass*, an ``int``, the zone's rdata class; the default is class IN.
+
+ *relativize*, a ``bool``, determine's whether domain names are
+ relativized to the zone's origin. The default is ``True``.
+
+ *pruning policy*, a function taking a `Version` and returning
+ a `bool`, or `None`. Should the version be pruned? If `None`,
+ the default policy, which retains one version is used.
+ """
+ super().__init__(origin, rdclass, relativize)
+ self.versions = collections.deque()
+ self.version_lock = _threading.Lock()
+ if pruning_policy is None:
+ self._pruning_policy = self._default_pruning_policy
+ else:
+ self._pruning_policy = pruning_policy
+ self._write_txn = None
+ self._write_event = None
+ self._write_waiters = collections.deque()
+ self._commit_version_unlocked(WritableVersion(self), origin)
+
+ def reader(self):
+ with self.version_lock:
+ return Transaction(False, self, self.versions[-1])
+
+ def writer(self, replacement=False):
+ event = None
+ while True:
+ with self.version_lock:
+ # Checking event == self._write_event ensures that either
+ # no one was waiting before we got lucky and found no write
+ # txn, or we were the one who was waiting and got woken up.
+ # This prevents "taking cuts" when creating a write txn.
+ if self._write_txn is None and event == self._write_event:
+ # Creating the transaction defers version setup
+ # (i.e. copying the nodes dictionary) until we
+ # give up the lock, so that we hold the lock as
+ # short a time as possible. This is why we call
+ # _setup_version() below.
+ self._write_txn = Transaction(replacement, self)
+ # give up our exclusive right to make a Transaction
+ self._write_event = None
+ break
+ # Someone else is writing already, so we will have to
+ # wait, but we want to do the actual wait outside the
+ # lock.
+ event = _threading.Event()
+ self._write_waiters.append(event)
+ # wait (note we gave up the lock!)
+ #
+ # We only wake one sleeper at a time, so it's important
+ # that no event waiter can exit this method (e.g. via
+ # cancelation) without returning a transaction or waking
+ # someone else up.
+ #
+ # This is not a problem with Threading module threads as
+ # they cannot be canceled, but could be an issue with trio
+ # or curio tasks when we do the async version of writer().
+ # I.e. we'd need to do something like:
+ #
+ # try:
+ # event.wait()
+ # except trio.Cancelled:
+ # with self.version_lock:
+ # self._maybe_wakeup_one_waiter_unlocked()
+ # raise
+ #
+ event.wait()
+ # Do the deferred version setup.
+ self._write_txn._setup_version()
+ return self._write_txn
+
+ def _maybe_wakeup_one_waiter_unlocked(self):
+ if len(self._write_waiters) > 0:
+ self._write_event = self._write_waiters.popleft()
+ self._write_event.set()
+
+ # pylint: disable=unused-argument
+ def _default_pruning_policy(self, zone, version):
+ return True
+ # pylint: enable=unused-argument
+
+ def _prune_versions_unlocked(self):
+ while len(self.versions) > 1 and \
+ self._pruning_policy(self, self.versions[0]):
+ self.versions.popleft()
+
+ def set_max_versions(self, max_versions):
+ """Set a pruning policy that retains up to the specified number
+ of versions
+ """
+ if max_versions is not None and max_versions < 1:
+ raise ValueError('max versions must be at least 1')
+ if max_versions is None:
+ def policy(*_):
+ return False
+ else:
+ def policy(zone, _):
+ return len(zone.versions) > max_versions
+ self.set_pruning_policy(policy)
+
+ def set_pruning_policy(self, policy):
+ """Set the pruning policy for the zone.
+
+ The *policy* function takes a `Version` and returns `True` if
+ the version should be pruned, and `False` otherwise. `None`
+ may also be specified for policy, in which case the default policy
+ is used.
+
+ Pruning checking proceeds from the least version and the first
+ time the function returns `False`, the checking stops. I.e. the
+ retained versions are always a consecutive sequence.
+ """
+ if policy is None:
+ policy = self._default_pruning_policy
+ with self.version_lock:
+ self._pruning_policy = policy
+ self._prune_versions_unlocked()
+
+ def _commit_version_unlocked(self, version, origin):
+ self.versions.append(version)
+ self._prune_versions_unlocked()
+ self.nodes = version.nodes
+ if self.origin is None:
+ self.origin = origin
+ self._write_txn = None
+ self._maybe_wakeup_one_waiter_unlocked()
+
+ def _commit_version(self, version, origin):
+ with self.version_lock:
+ self._commit_version_unlocked(version, origin)
+
+ def find_node(self, name, create=False):
+ if create:
+ raise UseTransaction
+ return super().find_node(name)
+
+ def delete_node(self, name):
+ raise UseTransaction
+
+ def find_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise UseTransaction
+ rdataset = super().find_rdataset(name, rdtype, covers)
+ return dns.rdataset.ImmutableRdataset(rdataset)
+
+ def get_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise UseTransaction
+ rdataset = super().get_rdataset(name, rdtype, covers)
+ return dns.rdataset.ImmutableRdataset(rdataset)
+
+ def delete_rdataset(self, name, rdtype, covers=dns.rdatatype.NONE):
+ raise UseTransaction
+
+ def replace_rdataset(self, name, replacement):
+ raise UseTransaction
+
+
+class Transaction(dns.transaction.Transaction):
+
+ def __init__(self, replacement, zone, version=None):
+ read_only = version is not None
+ super().__init__(replacement, read_only)
+ self.zone = zone
+ self.version = version
+
+ def _setup_version(self):
+ assert self.version is None
+ self.version = WritableVersion(self.zone, self.replacement)
+
+ def _get_rdataset(self, name, rdclass, rdtype, covers):
+ if rdclass != self.zone.rdclass:
+ raise ValueError(f'class {rdclass} != ' +
+ f'zone class {self.zone.rdclass}')
+ return self.version.get_rdataset(name, rdtype, covers)
+
+ def _put_rdataset(self, name, rdataset):
+ assert not self.read_only
+ if rdataset.rdclass != self.zone.rdclass:
+ raise ValueError(f'rdataset class {rdataset.rdclass} != ' +
+ f'zone class {self.zone.rdclass}')
+ self.version.put_rdataset(name, rdataset)
+
+ def _delete_name(self, name):
+ assert not self.read_only
+ self.version.delete_node(name)
+
+ def _delete_rdataset(self, name, rdclass, rdtype, covers):
+ assert not self.read_only
+ self.version.delete_rdataset(name, rdtype, covers)
+
+ def _name_exists(self, name):
+ return self.version.get_node(name) is not None
+
+ def _end_transaction(self, commit):
+ if self.read_only:
+ return
+ if commit and len(self.version.changed) > 0:
+ self.zone._commit_version(ImmutableVersion(self.version),
+ self.version.origin)
+
+ def _set_origin(self, origin):
+ if self.version.origin is None:
+ self.version.origin = origin
+
+ def _iterate_rdatasets(self):
+ for (name, node) in self.version.items():
+ for rdataset in node:
+ yield (name, rdataset)