summaryrefslogtreecommitdiff
path: root/dns
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-08-12 06:59:59 -0700
committerBob Halley <halley@dnspython.org>2020-08-12 06:59:59 -0700
commit9c91b29cb4505650b02cc45cb7389e232ecb9af1 (patch)
tree020540c6548c63343f09c89813eed04d4490c12a /dns
parentf64a5af8b4a6e7ef6c7985b281236a229195e556 (diff)
downloaddnspython-9c91b29cb4505650b02cc45cb7389e232ecb9af1.tar.gz
If we rollback a write, release the write txn and wake someone up.
Don't allow pruning to prune any version >= the version of an active reader. (This isn't a bug fix as the reader was safe before, but this ensures that the reader can open a successor version if needed.)
Diffstat (limited to 'dns')
-rw-r--r--dns/versioned.py58
1 files changed, 46 insertions, 12 deletions
diff --git a/dns/versioned.py b/dns/versioned.py
index ae921f1..e753438 100644
--- a/dns/versioned.py
+++ b/dns/versioned.py
@@ -169,7 +169,8 @@ class ImmutableNode(Node):
class Zone(dns.zone.Zone):
__slots__ = ['_versions', '_versions_lock', '_write_txn',
- '_write_waiters', '_write_event', '_pruning_policy']
+ '_write_waiters', '_write_event', '_pruning_policy',
+ '_readers']
node_factory = Node
@@ -200,7 +201,8 @@ class Zone(dns.zone.Zone):
self._write_txn = None
self._write_event = None
self._write_waiters = collections.deque()
- self._commit_version_unlocked(WritableVersion(self), origin)
+ self._readers = set()
+ self._commit_version_unlocked(None, WritableVersion(self), origin)
def reader(self, id=None, serial=None): # pylint: disable=arguments-differ
if id is not None and serial is not None:
@@ -231,7 +233,9 @@ class Zone(dns.zone.Zone):
raise KeyError('serial not found')
else:
version = self._versions[-1]
- return Transaction(False, self, version)
+ txn = Transaction(False, self, version)
+ self._readers.add(txn)
+ return txn
def writer(self, replacement=False):
event = None
@@ -291,7 +295,19 @@ class Zone(dns.zone.Zone):
# pylint: enable=unused-argument
def _prune_versions_unlocked(self):
- while len(self._versions) > 1 and \
+ assert len(self._versions) > 0
+ # Don't ever prune a version greater than or equal to one that
+ # a reader has open. This pins versions in memory while the
+ # reader is open, and importantly lets the reader open a txn on
+ # a successor version (e.g. if generating an IXFR).
+ #
+ # Note our definition of least_kept also ensures we do not try to
+ # delete the greatest version.
+ if len(self._readers) > 0:
+ least_kept = min(txn.version.id for txn in self._readers)
+ else:
+ least_kept = self._versions[-1].id
+ while self._versions[0].id < least_kept and \
self._pruning_policy(self, self._versions[0]):
self._versions.popleft()
@@ -327,18 +343,33 @@ class Zone(dns.zone.Zone):
self._pruning_policy = policy
self._prune_versions_unlocked()
- def _commit_version_unlocked(self, version, origin):
+ def _end_read(self, txn):
+ with self._version_lock:
+ self._readers.remove(txn)
+ self._prune_versions_unlocked()
+
+ def _end_write_unlocked(self, txn):
+ assert self._write_txn == txn
+ self._write_txn = None
+ self._maybe_wakeup_one_waiter_unlocked()
+
+ def _end_write(self, txn):
+ with self._version_lock:
+ self._end_write_unlocked(txn)
+
+ def _commit_version_unlocked(self, txn, version, origin):
self._versions.append(version)
self._prune_versions_unlocked()
self.nodes = version.nodes
if self.origin is None:
self.origin = origin
- self._write_txn = None
- self._maybe_wakeup_one_waiter_unlocked()
+ # txn can be None in __init__ when we make the empty version.
+ if txn is not None:
+ self._end_write_unlocked(txn)
- def _commit_version(self, version, origin):
+ def _commit_version(self, txn, version, origin):
with self._version_lock:
- self._commit_version_unlocked(version, origin)
+ self._commit_version_unlocked(txn, version, origin)
def find_node(self, name, create=False):
if create:
@@ -407,10 +438,13 @@ class Transaction(dns.transaction.Transaction):
def _end_transaction(self, commit):
if self.read_only:
- return
- if commit and len(self.version.changed) > 0:
- self.zone._commit_version(ImmutableVersion(self.version),
+ self.zone._end_read(self)
+ elif commit and len(self.version.changed) > 0:
+ self.zone._commit_version(self, ImmutableVersion(self.version),
self.version.origin)
+ else:
+ # rollback
+ self.zone._end_write(self)
def _set_origin(self, origin):
if self.version.origin is None: