diff options
-rw-r--r-- | dns/versioned.py | 39 | ||||
-rw-r--r-- | tests/test_transaction.py | 14 |
2 files changed, 28 insertions, 25 deletions
diff --git a/dns/versioned.py b/dns/versioned.py index cc2f714..ae921f1 100644 --- a/dns/versioned.py +++ b/dns/versioned.py @@ -54,8 +54,9 @@ class Version: class WritableVersion(Version): def __init__(self, zone, replacement=False): - if len(zone.versions) > 0: - id = zone.versions[-1].id + 1 + # The zone._versions_lock must be held by our caller. + if len(zone._versions) > 0: + id = zone._versions[-1].id + 1 else: id = 1 super().__init__(zone, id) @@ -167,8 +168,8 @@ class ImmutableNode(Node): class Zone(dns.zone.Zone): - __slots__ = ['versions', '_write_txn', '_write_waiters', '_write_event', - '_pruning_policy'] + __slots__ = ['_versions', '_versions_lock', '_write_txn', + '_write_waiters', '_write_event', '_pruning_policy'] node_factory = Node @@ -190,8 +191,8 @@ class Zone(dns.zone.Zone): the default policy, which retains one version is used. """ super().__init__(origin, rdclass, relativize) - self.versions = collections.deque() - self.version_lock = _threading.Lock() + self._versions = collections.deque() + self._version_lock = _threading.Lock() if pruning_policy is None: self._pruning_policy = self._default_pruning_policy else: @@ -204,10 +205,10 @@ class Zone(dns.zone.Zone): def reader(self, id=None, serial=None): # pylint: disable=arguments-differ if id is not None and serial is not None: raise ValueError('cannot specify both id and serial') - with self.version_lock: + with self._version_lock: if id is not None: version = None - for v in reversed(self.versions): + for v in reversed(self._versions): if v.id == id: version = v break @@ -219,7 +220,7 @@ class Zone(dns.zone.Zone): else: oname = self.origin version = None - for v in reversed(self.versions): + for v in reversed(self._versions): n = v.nodes.get(oname) if n: rds = n.get_rdataset(self.rdclass, dns.rdatatype.SOA) @@ -229,13 +230,13 @@ class Zone(dns.zone.Zone): if version is None: raise KeyError('serial not found') else: - version = self.versions[-1] + version = self._versions[-1] return Transaction(False, self, version) def writer(self, replacement=False): event = None while True: - with self.version_lock: + with self._version_lock: # Checking event == self._write_event ensures that either # no one was waiting before we got lucky and found no write # txn, or we were the one who was waiting and got woken up. @@ -270,7 +271,7 @@ class Zone(dns.zone.Zone): # try: # event.wait() # except trio.Cancelled: - # with self.version_lock: + # with self._version_lock: # self._maybe_wakeup_one_waiter_unlocked() # raise # @@ -290,9 +291,9 @@ class Zone(dns.zone.Zone): # pylint: enable=unused-argument def _prune_versions_unlocked(self): - while len(self.versions) > 1 and \ - self._pruning_policy(self, self.versions[0]): - self.versions.popleft() + while len(self._versions) > 1 and \ + self._pruning_policy(self, self._versions[0]): + self._versions.popleft() def set_max_versions(self, max_versions): """Set a pruning policy that retains up to the specified number @@ -305,7 +306,7 @@ class Zone(dns.zone.Zone): return False else: def policy(zone, _): - return len(zone.versions) > max_versions + return len(zone._versions) > max_versions self.set_pruning_policy(policy) def set_pruning_policy(self, policy): @@ -322,12 +323,12 @@ class Zone(dns.zone.Zone): """ if policy is None: policy = self._default_pruning_policy - with self.version_lock: + with self._version_lock: self._pruning_policy = policy self._prune_versions_unlocked() def _commit_version_unlocked(self, version, origin): - self.versions.append(version) + self._versions.append(version) self._prune_versions_unlocked() self.nodes = version.nodes if self.origin is None: @@ -336,7 +337,7 @@ class Zone(dns.zone.Zone): self._maybe_wakeup_one_waiter_unlocked() def _commit_version(self, version, origin): - with self.version_lock: + with self._version_lock: self._commit_version_unlocked(version, origin) def find_node(self, name, create=False): diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 64705ed..888fbd5 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -391,7 +391,7 @@ def test_vzone_read_only(vzone): txn.replace(dns.name.empty, expected) def test_vzone_multiple_versions(vzone): - assert len(vzone.versions) == 1 + assert len(vzone._versions) == 1 vzone.set_max_versions(None) # unlimited! with vzone.writer() as txn: txn.set_serial() @@ -401,7 +401,7 @@ def test_vzone_multiple_versions(vzone): txn.set_serial(increment=0, value=1000) rdataset = vzone.find_rdataset('@', 'soa') assert rdataset[0].serial == 1000 - assert len(vzone.versions) == 4 + assert len(vzone._versions) == 4 with vzone.reader(id=5) as txn: assert txn.version.id == 5 rdataset = txn.get('@', 'in', 'soa') @@ -411,13 +411,15 @@ def test_vzone_multiple_versions(vzone): rdataset = txn.get('@', 'in', 'soa') assert rdataset[0].serial == 1000 vzone.set_max_versions(2) - assert len(vzone.versions) == 2 + assert len(vzone._versions) == 2 # The ones that survived should be 3 and 1000 - rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA, + rdataset = vzone._versions[0].get_rdataset(dns.name.empty, + dns.rdatatype.SOA, dns.rdatatype.NONE) assert rdataset[0].serial == 3 - rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA, - dns.rdatatype.NONE) + rdataset = vzone._versions[1].get_rdataset(dns.name.empty, + dns.rdatatype.SOA, + dns.rdatatype.NONE) assert rdataset[0].serial == 1000 with pytest.raises(ValueError): vzone.set_max_versions(0) |