summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2021-12-01 06:48:58 -0800
committerBob Halley <halley@dnspython.org>2021-12-01 06:48:58 -0800
commit9a16076cb3b7d36efaabff030688693dd56f0ee6 (patch)
tree173052cae1ce98b0147d84943aaa26cd47e713df
parentc706e26e990856311e35f46cb58eaf333e80ed2f (diff)
downloaddnspython-zone-refactor.tar.gz
Refactor zone transactions to always use versioned CoW code.zone-refactor
-rw-r--r--dns/versioned.py219
-rw-r--r--dns/zone.py277
2 files changed, 225 insertions, 271 deletions
diff --git a/dns/versioned.py b/dns/versioned.py
index 686a83b..42f2c81 100644
--- a/dns/versioned.py
+++ b/dns/versioned.py
@@ -11,12 +11,9 @@ except ImportError: # pragma: no cover
import dns.exception
import dns.immutable
import dns.name
-import dns.node
import dns.rdataclass
import dns.rdatatype
-import dns.rdata
import dns.rdtypes.ANY.SOA
-import dns.transaction
import dns.zone
@@ -24,142 +21,13 @@ class UseTransaction(dns.exception.DNSException):
"""To alter a versioned zone, use a transaction."""
-class Version:
- def __init__(self, zone, id):
- self.zone = zone
- self.id = id
- self.nodes = {}
-
- def _validate_name(self, name):
- if name.is_absolute():
- if not name.is_subdomain(self.zone.origin):
- raise KeyError("name is not a subdomain of the zone origin")
- if self.zone.relativize:
- name = name.relativize(self.origin)
- return name
-
- def get_node(self, name):
- name = self._validate_name(name)
- return self.nodes.get(name)
-
- def get_rdataset(self, name, rdtype, covers):
- node = self.get_node(name)
- if node is None:
- return None
- return node.get_rdataset(self.zone.rdclass, rdtype, covers)
-
- def items(self):
- return self.nodes.items() # pylint: disable=dict-items-not-iterating
-
-
-class WritableVersion(Version):
- def __init__(self, zone, replacement=False):
- # The zone._versions_lock must be held by our caller.
- if len(zone._versions) > 0:
- id = zone._versions[-1].id + 1
- else:
- id = 1
- super().__init__(zone, id)
- if not replacement:
- # We copy the map, because that gives us a simple and thread-safe
- # way of doing versions, and we have a garbage collector to help
- # us. We only make new node objects if we actually change the
- # node.
- self.nodes.update(zone.nodes)
- # We have to copy the zone origin as it may be None in the first
- # version, and we don't want to mutate the zone until we commit.
- self.origin = zone.origin
- self.changed = set()
-
- def _maybe_cow(self, name):
- name = self._validate_name(name)
- node = self.nodes.get(name)
- if node is None or node.id != self.id:
- new_node = self.zone.node_factory()
- new_node.id = self.id
- if node is not None:
- # moo! copy on write!
- new_node.rdatasets.extend(node.rdatasets)
- self.nodes[name] = new_node
- self.changed.add(name)
- return new_node
- else:
- return node
-
- def delete_node(self, name):
- name = self._validate_name(name)
- if name in self.nodes:
- del self.nodes[name]
- self.changed.add(name)
-
- def put_rdataset(self, name, rdataset):
- node = self._maybe_cow(name)
- node.replace_rdataset(rdataset)
-
- def delete_rdataset(self, name, rdtype, covers):
- node = self._maybe_cow(name)
- node.delete_rdataset(self.zone.rdclass, rdtype, covers)
- if len(node) == 0:
- del self.nodes[name]
-
-
-@dns.immutable.immutable
-class ImmutableVersion(Version):
- def __init__(self, version):
- # We tell super() that it's a replacement as we don't want it
- # to copy the nodes, as we're about to do that with an
- # immutable Dict.
- super().__init__(version.zone, True)
- # set the right id!
- self.id = version.id
- # Make changed nodes immutable
- for name in version.changed:
- node = version.nodes.get(name)
- # it might not exist if we deleted it in the version
- if node:
- version.nodes[name] = ImmutableNode(node)
- self.nodes = dns.immutable.Dict(version.nodes, True)
-
-
-# A node with a version id.
-
-class Node(dns.node.Node):
- __slots__ = ['id']
-
- def __init__(self):
- super().__init__()
- # A proper id will get set by the Version
- self.id = 0
-
-
-@dns.immutable.immutable
-class ImmutableNode(Node):
- __slots__ = ['id']
-
- def __init__(self, node):
- super().__init__()
- self.id = node.id
- self.rdatasets = tuple(
- [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
- )
-
- def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
- create=False):
- if create:
- raise TypeError("immutable")
- return super().find_rdataset(rdclass, rdtype, covers, False)
-
- def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
- create=False):
- if create:
- raise TypeError("immutable")
- return super().get_rdataset(rdclass, rdtype, covers, False)
-
- def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
- raise TypeError("immutable")
-
- def replace_rdataset(self, replacement):
- raise TypeError("immutable")
+# Backwards compatibility
+Node = dns.zone.VersionedNode
+ImmutableNode = dns.zone.ImmutableVersionedNode
+Version = dns.zone.Version
+WritableVersion = dns.zone.WritableVersion
+ImmutableVersion = dns.zone.ImmutableVersion
+Transaction = dns.zone.Transaction
class Zone(dns.zone.Zone):
@@ -198,7 +66,9 @@ class Zone(dns.zone.Zone):
self._write_event = None
self._write_waiters = collections.deque()
self._readers = set()
- self._commit_version_unlocked(None, WritableVersion(self), origin)
+ self._commit_version_unlocked(None,
+ WritableVersion(self, replacement=True),
+ origin)
def reader(self, id=None, serial=None): # pylint: disable=arguments-differ
if id is not None and serial is not None:
@@ -247,7 +117,8 @@ class Zone(dns.zone.Zone):
# give up the lock, so that we hold the lock as
# short a time as possible. This is why we call
# _setup_version() below.
- self._write_txn = Transaction(self, replacement)
+ self._write_txn = Transaction(self, replacement,
+ make_immutable=True)
# give up our exclusive right to make a Transaction
self._write_event = None
break
@@ -367,6 +238,13 @@ class Zone(dns.zone.Zone):
with self._version_lock:
self._commit_version_unlocked(txn, version, origin)
+ def _get_next_version_id(self):
+ if len(self._versions) > 0:
+ id = self._versions[-1].id + 1
+ else:
+ id = 1
+ return id
+
def find_node(self, name, create=False):
if create:
raise UseTransaction
@@ -394,62 +272,3 @@ class Zone(dns.zone.Zone):
def replace_rdataset(self, name, replacement):
raise UseTransaction
-
-
-class Transaction(dns.transaction.Transaction):
-
- def __init__(self, zone, replacement, version=None):
- read_only = version is not None
- super().__init__(zone, replacement, read_only)
- self.version = version
-
- @property
- def zone(self):
- return self.manager
-
- def _setup_version(self):
- assert self.version is None
- self.version = WritableVersion(self.zone, self.replacement)
-
- def _get_rdataset(self, name, rdtype, covers):
- return self.version.get_rdataset(name, rdtype, covers)
-
- def _put_rdataset(self, name, rdataset):
- assert not self.read_only
- self.version.put_rdataset(name, rdataset)
-
- def _delete_name(self, name):
- assert not self.read_only
- self.version.delete_node(name)
-
- def _delete_rdataset(self, name, rdtype, covers):
- assert not self.read_only
- self.version.delete_rdataset(name, rdtype, covers)
-
- def _name_exists(self, name):
- return self.version.get_node(name) is not None
-
- def _changed(self):
- if self.read_only:
- return False
- else:
- return len(self.version.changed) > 0
-
- def _end_transaction(self, commit):
- if self.read_only:
- self.zone._end_read(self)
- elif commit and len(self.version.changed) > 0:
- self.zone._commit_version(self, ImmutableVersion(self.version),
- self.version.origin)
- else:
- # rollback
- self.zone._end_write(self)
-
- def _set_origin(self, origin):
- if self.version.origin is None:
- self.version.origin = origin
-
- def _iterate_rdatasets(self):
- for (name, node) in self.version.items():
- for rdataset in node:
- yield (name, rdataset)
diff --git a/dns/zone.py b/dns/zone.py
index 2f99b1b..510be2d 100644
--- a/dns/zone.py
+++ b/dns/zone.py
@@ -24,6 +24,7 @@ import os
import struct
import dns.exception
+import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
@@ -772,10 +773,13 @@ class Zone(dns.transaction.TransactionManager):
# TransactionManager methods
def reader(self):
- return Transaction(self, False, True)
+ return Transaction(self, False,
+ Version(self, 1, self.nodes, self.origin))
def writer(self, replacement=False):
- return Transaction(self, replacement, False)
+ txn = Transaction(self, replacement)
+ txn._setup_version()
+ return txn
def origin_information(self):
if self.relativize:
@@ -787,107 +791,238 @@ class Zone(dns.transaction.TransactionManager):
def get_class(self):
return self.rdclass
+ # Transaction methods
-class Transaction(dns.transaction.Transaction):
+ def _end_read(self, txn):
+ pass
+
+ def _end_write(self, txn):
+ pass
+
+ def _commit_version(self, txn, version, origin):
+ self.nodes = version.nodes
+ if self.origin is None:
+ self.origin = origin
+
+ def _get_next_version_id(self):
+ # Versions are ephemeral and all have id 1
+ return 1
+
+
+# These classes used to be in dns.versioned, but have moved here so we can use
+# the copy-on-write transaction mechanism for both kinds of zones. In a
+# regular zone, the version only exists during the transaction, and the nodes
+# are regular dns.node.Nodes.
+
+# A node with a version id.
+
+class VersionedNode(dns.node.Node):
+ __slots__ = ['id']
+
+ def __init__(self):
+ super().__init__()
+ # A proper id will get set by the Version
+ self.id = 0
+
+
+@dns.immutable.immutable
+class ImmutableVersionedNode(VersionedNode):
+ __slots__ = ['id']
+
+ def __init__(self, node):
+ super().__init__()
+ self.id = node.id
+ self.rdatasets = tuple(
+ [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
+ )
+
+ def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise TypeError("immutable")
+ return super().find_rdataset(rdclass, rdtype, covers, False)
+
+ def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
+ create=False):
+ if create:
+ raise TypeError("immutable")
+ return super().get_rdataset(rdclass, rdtype, covers, False)
- _deleted_rdataset = dns.rdataset.Rdataset(dns.rdataclass.ANY,
- dns.rdatatype.ANY)
+ def delete_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE):
+ raise TypeError("immutable")
+
+ def replace_rdataset(self, replacement):
+ raise TypeError("immutable")
+
+
+class Version:
+ def __init__(self, zone, id, nodes=None, origin=None):
+ self.zone = zone
+ self.id = id
+ if nodes is not None:
+ self.nodes = nodes
+ else:
+ self.nodes = {}
+ self.origin = origin
+
+ def _validate_name(self, name):
+ if name.is_absolute():
+ if not name.is_subdomain(self.zone.origin):
+ raise KeyError("name is not a subdomain of the zone origin")
+ if self.zone.relativize:
+ # XXXRTH should it be an error if self.origin is still None?
+ name = name.relativize(self.origin)
+ return name
+
+ def get_node(self, name):
+ name = self._validate_name(name)
+ return self.nodes.get(name)
+
+ def get_rdataset(self, name, rdtype, covers):
+ node = self.get_node(name)
+ if node is None:
+ return None
+ return node.get_rdataset(self.zone.rdclass, rdtype, covers)
- def __init__(self, zone, replacement, read_only):
+ def items(self):
+ return self.nodes.items() # pylint: disable=dict-items-not-iterating
+
+
+class WritableVersion(Version):
+ def __init__(self, zone, replacement=False):
+ # The zone._versions_lock must be held by our caller in a versioned
+ # zone.
+ id = zone._get_next_version_id()
+ super().__init__(zone, id)
+ if not replacement:
+ # We copy the map, because that gives us a simple and thread-safe
+ # way of doing versions, and we have a garbage collector to help
+ # us. We only make new node objects if we actually change the
+ # node.
+ self.nodes.update(zone.nodes)
+ # We have to copy the zone origin as it may be None in the first
+ # version, and we don't want to mutate the zone until we commit.
+ self.origin = zone.origin
+ self.changed = set()
+
+ def _maybe_cow(self, name):
+ name = self._validate_name(name)
+ node = self.nodes.get(name)
+ if node is None or name not in self.changed:
+ new_node = self.zone.node_factory()
+ if hasattr(new_node, 'id'):
+ # We keep doing this for backwards compatibility, as earlier
+ # code used new_node.id != self.id for the "do we need to CoW?"
+ # test. Now we use the changed set as this works with both
+ # regular zones and versioned zones.
+ new_node.id = self.id
+ if node is not None:
+ # moo! copy on write!
+ new_node.rdatasets.extend(node.rdatasets)
+ self.nodes[name] = new_node
+ self.changed.add(name)
+ return new_node
+ else:
+ return node
+
+ def delete_node(self, name):
+ name = self._validate_name(name)
+ if name in self.nodes:
+ del self.nodes[name]
+ self.changed.add(name)
+
+ def put_rdataset(self, name, rdataset):
+ node = self._maybe_cow(name)
+ node.replace_rdataset(rdataset)
+
+ def delete_rdataset(self, name, rdtype, covers):
+ node = self._maybe_cow(name)
+ node.delete_rdataset(self.zone.rdclass, rdtype, covers)
+ if len(node) == 0:
+ del self.nodes[name]
+
+
+@dns.immutable.immutable
+class ImmutableVersion(Version):
+ def __init__(self, version):
+ # We tell super() that it's a replacement as we don't want it
+ # to copy the nodes, as we're about to do that with an
+ # immutable Dict.
+ super().__init__(version.zone, True)
+ # set the right id!
+ self.id = version.id
+ # keep the origin
+ self.origin = version.origin
+ # Make changed nodes immutable
+ for name in version.changed:
+ node = version.nodes.get(name)
+ # it might not exist if we deleted it in the version
+ if node:
+ version.nodes[name] = ImmutableVersionedNode(node)
+ self.nodes = dns.immutable.Dict(version.nodes, True)
+
+
+class Transaction(dns.transaction.Transaction):
+
+ def __init__(self, zone, replacement, version=None, make_immutable=False):
+ read_only = version is not None
super().__init__(zone, replacement, read_only)
- self.rdatasets = {}
+ self.version = version
+ self.make_immutable = make_immutable
@property
def zone(self):
return self.manager
+ def _setup_version(self):
+ assert self.version is None
+ self.version = WritableVersion(self.zone, self.replacement)
+
def _get_rdataset(self, name, rdtype, covers):
- rdataset = self.rdatasets.get((name, rdtype, covers))
- if rdataset is self._deleted_rdataset:
- return None
- elif rdataset is None and not self.replacement:
- rdataset = self.zone.get_rdataset(name, rdtype, covers)
- return rdataset
+ return self.version.get_rdataset(name, rdtype, covers)
def _put_rdataset(self, name, rdataset):
assert not self.read_only
- self.zone._validate_name(name)
- self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ self.version.put_rdataset(name, rdataset)
def _delete_name(self, name):
assert not self.read_only
- # First remove any changes involving the name
- remove = []
- for key in self.rdatasets:
- if key[0] == name:
- remove.append(key)
- if len(remove) > 0:
- for key in remove:
- del self.rdatasets[key]
- # Next add deletion records for any rdatasets matching the
- # name in the zone
- node = self.zone.get_node(name)
- if node is not None:
- for rdataset in node.rdatasets:
- self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
- self._deleted_rdataset
+ self.version.delete_node(name)
def _delete_rdataset(self, name, rdtype, covers):
assert not self.read_only
- try:
- del self.rdatasets[(name, rdtype, covers)]
- except KeyError:
- pass
- rdataset = self.zone.get_rdataset(name, rdtype, covers)
- if rdataset is not None:
- self.rdatasets[(name, rdataset.rdtype, rdataset.covers)] = \
- self._deleted_rdataset
+ self.version.delete_rdataset(name, rdtype, covers)
def _name_exists(self, name):
- for key, rdataset in self.rdatasets.items():
- if key[0] == name:
- if rdataset != self._deleted_rdataset:
- return True
- else:
- return None
- self.zone._validate_name(name)
- if self.zone.get_node(name):
- return True
- return False
+ return self.version.get_node(name) is not None
def _changed(self):
if self.read_only:
return False
else:
- return len(self.rdatasets) > 0
+ return len(self.version.changed) > 0
def _end_transaction(self, commit):
- if commit and self._changed():
- if self.replacement:
- self.zone.nodes = {}
- for (name, rdtype, covers), rdataset in \
- self.rdatasets.items():
- if rdataset is self._deleted_rdataset:
- self.zone.delete_rdataset(name, rdtype, covers)
- else:
- self.zone.replace_rdataset(name, rdataset)
+ if self.read_only:
+ self.zone._end_read(self)
+ elif commit and len(self.version.changed) > 0:
+ if self.make_immutable:
+ version = ImmutableVersion(self.version)
+ else:
+ version = self.version
+ self.zone._commit_version(self, version, self.version.origin)
+ else:
+ # rollback
+ self.zone._end_write(self)
def _set_origin(self, origin):
- if self.zone.origin is None:
- self.zone.origin = origin
+ if self.version.origin is None:
+ self.version.origin = origin
def _iterate_rdatasets(self):
- # Expensive but simple! Use a versioned zone for efficient txn
- # iteration.
- if self.replacement:
- rdatasets = self.rdatasets
- else:
- rdatasets = {}
- for (name, rdataset) in self.zone.iterate_rdatasets():
- rdatasets[(name, rdataset.rdtype, rdataset.covers)] = rdataset
- rdatasets.update(self.rdatasets)
- for (name, _, _), rdataset in rdatasets.items():
- yield (name, rdataset)
+ for (name, node) in self.version.items():
+ for rdataset in node:
+ yield (name, rdataset)
def from_text(text, origin=None, rdclass=dns.rdataclass.IN,