summaryrefslogtreecommitdiff
path: root/dns/zone.py
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2021-12-01 06:48:58 -0800
committerBob Halley <halley@dnspython.org>2021-12-01 06:48:58 -0800
commit9a16076cb3b7d36efaabff030688693dd56f0ee6 (patch)
tree173052cae1ce98b0147d84943aaa26cd47e713df /dns/zone.py
parentc706e26e990856311e35f46cb58eaf333e80ed2f (diff)
downloaddnspython-zone-refactor.tar.gz
Refactor zone transactions to always use versioned CoW code.zone-refactor
Diffstat (limited to 'dns/zone.py')
-rw-r--r--dns/zone.py277
1 files changed, 206 insertions, 71 deletions
diff --git a/dns/zone.py b/dns/zone.py
index 2f99b1b..510be2d 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -24,6 +24,7 @@ import os
import struct
import dns.exception
+import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
@@ -772,10 +773,13 @@ class Zone(dns.transaction.TransactionManager):
# TransactionManager methods
def reader(self):
- return Transaction(self, False, True)
+ return Transaction(self, False,
+ Version(self, 1, self.nodes, self.origin))
def writer(self, replacement=False):
- return Transaction(self, replacement, False)
+ txn = Transaction(self, replacement)
+ txn._setup_version()
+ return txn
def origin_information(self):
if self.relativize:
@@ -787,107 +791,238 @@ class Zone(dns.transaction.TransactionManager):
def get_class(self):
return self.rdclass
+ # Transaction methods
-class Transaction(dns.transaction.Transaction):
+ def _end_read(self, txn):
+ pass
+
+ def _end_write(self, txn):
+ pass
+
+ def _commit_version(self, txn, version, origin):
+ self.nodes = version.nodes
+ if self.origin is None:
+ self.origin = origin
+
+ def _get_next_version_id(self):
+ # Versions are ephemeral and all have id 1
+ return 1
+
+
+# These classes used to be in dns.versioned, but have moved here so we can use
+# the copy-on-write transaction mechanism for both kinds of zones. In a
+# regular zone, the version only exists during the transaction, and the nodes
+# are regular dns.node.Nodes.
+
+# A node with a version id.
+
+class VersionedNode(dns.node.Node):
+ __slots__ = ['id']
+
+ def __init__(self):
+ super().__init__()
+ # A proper id will get set by the Version
+ self.id = 0
+
+
+@dns.immutable.immutable
+class ImmutableVersionedNode(VersionedNode):
+ __slots__ = ['id']
+
+ def __init__(self, node):
+ super().__init__()
+ self.id = node.id
+ self.rdatasets = tuple(
+ [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+ )
+
+ def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise TypeError("immutable")
+ return super().find_rdataset(rdclass, rdtype, covers, False)
+
+ def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise TypeError("immutable")
+ return super().get_rdataset(rdclass, rdtype, covers, False)
- _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY,
- dns.rdatatype.ANY)
+ def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+ raise TypeError("immutable")
+
+ def replace_rdataset(self, replacement):
+ raise TypeError("immutable")
+
+
+class Version:
+ def __init__(self, zone, id, nodes=None, origin=None):
+ self.zone = zone
+ self.id = id
+ if nodes is not None:
+ self.nodes = nodes
+ else:
+ self.nodes = {}
+ self.origin = origin
+
+ def _validate_name(self, name):
+ if name.is_absolute():
+ if not name.is_subdomain(self.zone.origin):
+ raise KeyError("name is not a subdomain of the zone origin")
+ if self.zone.relativize:
+ # XXXRTH should it be an error if self.origin is still None?
+ name = name.relativize(self.origin)
+ return name
+
+ def get_node(self, name):
+ name = self._validate_name(name)
+ return self.nodes.get(name)
+
+ def get_rdataset(self, name, rdtype, covers):
+ node = self.get_node(name)
+ if node is None:
+ return None
+ return node.get_rdataset(self.zone.rdclass, rdtype, covers)
- def __init__(self, zone, replacement, read_only):
+ def items(self):
+ return self.nodes.items() # pylint: disable=dict-items-not-iterating
+
+
+class WritableVersion(Version):
+ def __init__(self, zone, replacement=False):
+ # The zone._versions_lock must be held by our caller in a versioned
+ # zone.
+ id = zone._get_next_version_id()
+ super().__init__(zone, id)
+ if not replacement:
+ # We copy the map, because that gives us a simple and thread-safe
+ # way of doing versions, and we have a garbage collector to help
+ # us. We only make new node objects if we actually change the
+ # node.
+ self.nodes.update(zone.nodes)
+ # We have to copy the zone origin as it may be None in the first
+ # version, and we don't want to mutate the zone until we commit.
+ self.origin = zone.origin
+ self.changed = set()
+
+ def _maybe_cow(self, name):
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is None or name not in self.changed:
+ new_node = self.zone.node_factory()
+ if hasattr(new_node, 'id'):
+ # We keep doing this for backwards compatibility, as earlier
+ # code used new_node.id != self.id for the "do we need to CoW?"
+ # test. Now we use the changed set as this works with both
+ # regular zones and versioned zones.
+ new_node.id = self.id
+ if node is not None:
+ # moo! copy on write!
+ new_node.rdatasets.extend(node.rdatasets)
+ self.nodes[name] = new_node
+ self.changed.add(name)
+ return new_node
+ else:
+ return node
+
+ def delete_node(self, name):
+ name = self._validate_name(name)
+ if name in self.nodes:
+ del self.nodes[name]
+ self.changed.add(name)
+
+ def put_rdataset(self, name, rdataset):
+ node = self._maybe_cow(name)
+ node.replace_rdataset(rdataset)
+
+ def delete_rdataset(self, name, rdtype, covers):
+ node = self._maybe_cow(name)
+ node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+ if len(node) == 0:
+ del self.nodes[name]
+
+
+@dns.immutable.immutable
+class ImmutableVersion(Version):
+ def __init__(self, version):
+ # We tell super() that it's a replacement as we don't want it
+ # to copy the nodes, as we're about to do that with an
+ # immutable Dict.
+ super().__init__(version.zone, True)
+ # set the right id!
+ self.id = version.id
+ # keep the origin
+ self.origin = version.origin
+ # Make changed nodes immutable
+ for name in version.changed:
+ node = version.nodes.get(name)
+ # it might not exist if we deleted it in the version
+ if node:
+ version.nodes[name] = ImmutableVersionedNode(node)
+ self.nodes = dns.immutable.Dict(version.nodes, True)
+
+
+class Transaction(dns.transaction.Transaction):
+
+ def __init__(self, zone, replacement, version=None, make_immutable=False):
+ read_only = version is not None
super().__init__(zone, replacement, read_only)
- self.rdatasets = {}
+ self.version = version
+ self.make_immutable = make_immutable
@property
def zone(self):
return self.manager
+ def _setup_version(self):
+ assert self.version is None
+ self.version = WritableVersion(self.zone, self.replacement)
+
def _get_rdataset(self, name, rdtype, covers):
- rdataset = self.rdatasets.get((name, rdtype, covers))
- if rdataset is self._deleted_rdataset:
- return None
- elif rdataset is None and not self.replacement:
- rdataset = self.zone.get_rdataset(name, rdtype, covers)
- return rdataset
+ return self.version.get_rdataset(name, rdtype, covers)
def _put_rdataset(self, name, rdataset):
assert not self.read_only
- self.zone._validate_name(name)
- self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ self.version.put_rdataset(name, rdataset)
def _delete_name(self, name):
assert not self.read_only
- # First remove any changes involving the name
- remove = []
- for key in self.rdatasets:
- if key[0] == name:
- remove.append(key)
- if len(remove) > 0:
- for key in remove:
- del self.rdatasets[key]
- # Next add deletion records for any rdatasets matching the
- # name in the zone
- node = self.zone.get_node(name)
- if node is not None:
- for rdataset in node.rdatasets:
- self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
- self._deleted_rdataset
+ self.version.delete_node(name)
def _delete_rdataset(self, name, rdtype, covers):
assert not self.read_only
- try:
- del self.rdatasets[(name, rdtype, covers)]
- except KeyError:
- pass
- rdataset = self.zone.get_rdataset(name, rdtype, covers)
- if rdataset is not None:
- self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
- self._deleted_rdataset
+ self.version.delete_rdataset(name, rdtype, covers)
def _name_exists(self, name):
- for key, rdataset in self.rdatasets.items():
- if key[0] == name:
- if rdataset != self._deleted_rdataset:
- return True
- else:
- return None
- self.zone._validate_name(name)
- if self.zone.get_node(name):
- return True
- return False
+ return self.version.get_node(name) is not None
def _changed(self):
if self.read_only:
return False
else:
- return len(self.rdatasets) > 0
+ return len(self.version.changed) > 0
def _end_transaction(self, commit):
- if commit and self._changed():
- if self.replacement:
- self.zone.nodes = {}
- for (name, rdtype, covers), rdataset in \
- self.rdatasets.items():
- if rdataset is self._deleted_rdataset:
- self.zone.delete_rdataset(name, rdtype, covers)
- else:
- self.zone.replace_rdataset(name, rdataset)
+ if self.read_only:
+ self.zone._end_read(self)
+ elif commit and len(self.version.changed) > 0:
+ if self.make_immutable:
+ version = ImmutableVersion(self.version)
+ else:
+ version = self.version
+ self.zone._commit_version(self, version, self.version.origin)
+ else:
+ # rollback
+ self.zone._end_write(self)
def _set_origin(self, origin):
- if self.zone.origin is None:
- self.zone.origin = origin
+ if self.version.origin is None:
+ self.version.origin = origin
def _iterate_rdatasets(self):
- # Expensive but simple! Use a versioned zone for efficient txn
- # iteration.
- if self.replacement:
- rdatasets = self.rdatasets
- else:
- rdatasets = {}
- for (name, rdataset) in self.zone.iterate_rdatasets():
- rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
- rdatasets.update(self.rdatasets)
- for (name, _, _), rdataset in rdatasets.items():
- yield (name, rdataset)
+ for (name, node) in self.version.items():
+ for rdataset in node:
+ yield (name, rdataset)
def from_text(text, origin=None, rdclass=dns.rdataclass.IN,