summaryrefslogtreecommitdiff
path: root/tests/test_transaction.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_transaction.py')
-rw-r--r--tests/test_transaction.py451
1 files changed, 451 insertions, 0 deletions
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
new file mode 100644
index 0000000..ed154fc
--- /dev/null
+++ b/tests/test_transaction.py
@@ -0,0 +1,451 @@
+import time
+
+import pytest
+
+import dns.name
+import dns.rdataclass
+import dns.rdatatype
+import dns.rdataset
+import dns.rrset
+import dns.transaction
+import dns.versioned
+import dns.zone
+
+class DB(dns.transaction.TransactionManager):
+ def __init__(self):
+ self.rdatasets = {}
+
+ def reader(self):
+ return Transaction(False, True, self)
+
+ def writer(self, replacement=False):
+ return Transaction(replacement, False, self)
+
+
+class Transaction(dns.transaction.Transaction):
+ def __init__(self, replacement, read_only, db):
+ super().__init__(replacement)
+ self.db = db
+ self.rdatasets = {}
+ self.read_only = read_only
+ if not replacement:
+ self.rdatasets.update(db.rdatasets)
+
+ def _get_rdataset(self, name, rdclass, rdtype, covers):
+ return self.rdatasets.get((name, rdclass, rdtype, covers))
+
+ def _put_rdataset(self, name, rdataset):
+ self.rdatasets[(name, rdataset.rdclass, rdataset.rdtype,
+ rdataset.covers)] = rdataset
+
+ def _delete_name(self, name):
+ remove = []
+ for key in self.rdatasets.keys():
+ if key[0] == name:
+ remove.append(key)
+ if len(remove) > 0:
+ for key in remove:
+ del self.rdatasets[key]
+
+ def _delete_rdataset(self, name, rdclass, rdtype, covers):
+ del self.rdatasets[(name, rdclass, rdtype, covers)]
+
+ def _name_exists(self, name):
+ for key in self.rdatasets.keys():
+ if key[0] == name:
+ return True
+ return False
+
+ def _end_transaction(self, commit):
+ if commit:
+ self.db.rdatasets = self.rdatasets
+
+ def _set_origin(self, origin):
+ pass
+
+@pytest.fixture
+def db():
+ db = DB()
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
+ db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] = rrset
+ return db
+
+def test_basic(db):
+ # successful txn
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2')
+ txn.add(rrset)
+ assert txn.name_exists(rrset.name)
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ rrset
+ # rollback
+ with pytest.raises(Exception):
+ with db.writer() as txn:
+ rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.3', '10.0.0.4')
+ txn.add(rrset2)
+ raise Exception()
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ rrset
+ with db.writer() as txn:
+ txn.delete(rrset.name)
+ assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+ is None
+
+def test_get(db):
+ with db.writer() as txn:
+ content = dns.name.from_text('content', None)
+ rdataset = txn.get(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+ assert rdataset is not None
+ assert rdataset[0].strings == (b'content',)
+ assert isinstance(rdataset, dns.rdataset.ImmutableRdataset)
+
+def test_add(db):
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2')
+ txn.add(rrset)
+ rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.3', '10.0.0.4')
+ txn.add(rrset2)
+ expected = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2',
+ '10.0.0.3', '10.0.0.4')
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ expected
+
+def test_replacement(db):
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.1', '10.0.0.2')
+ txn.add(rrset)
+ rrset2 = dns.rrset.from_text('foo', 300, 'in', 'a',
+ '10.0.0.3', '10.0.0.4')
+ txn.replace(rrset2)
+ assert db.rdatasets[(rrset.name, rrset.rdclass, rrset.rdtype, 0)] == \
+ rrset2
+
+def test_delete(db):
+ with db.writer() as txn:
+ txn.delete(dns.name.from_text('nonexistent', None))
+ content = dns.name.from_text('content', None)
+ content2 = dns.name.from_text('content2', None)
+ txn.delete(content)
+ assert not txn.name_exists(content)
+ txn.delete(content2, dns.rdataclass.IN, dns.rdatatype.TXT)
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'new-content')
+ txn.add(rrset)
+ assert txn.name_exists(content)
+ txn.delete(content, dns.rdataclass.IN, dns.rdatatype.TXT)
+ assert not txn.name_exists(content)
+ rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'new-content')
+ txn.delete(rrset)
+ content_keys = [k for k in db.rdatasets if k[0] == content]
+ assert len(content_keys) == 0
+
+def test_delete_exact(db):
+ with db.writer() as txn:
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'bad-content')
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset)
+ rrset = dns.rrset.from_text('content2', 300, 'in', 'txt', 'bad-content')
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset)
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset.name)
+ with pytest.raises(dns.transaction.DeleteNotExact):
+ txn.delete_exact(rrset.name, dns.rdataclass.IN, dns.rdatatype.TXT)
+ rrset = dns.rrset.from_text('content', 300, 'in', 'txt', 'content')
+ txn.delete_exact(rrset)
+ assert db.rdatasets.get((rrset.name, rrset.rdclass, rrset.rdtype, 0)) \
+ is None
+
+def test_parameter_forms(db):
+ with db.writer() as txn:
+ foo = dns.name.from_text('foo', None)
+ rdataset = dns.rdataset.from_text('in', 'a', 300,
+ '10.0.0.1', '10.0.0.2')
+ rdata1 = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ rdata2 = dns.rdata.from_text('in', 'a', '10.0.0.4')
+ txn.add(foo, rdataset)
+ txn.add(foo, 100, rdata1)
+ txn.add(foo, 30, rdata2)
+ expected = dns.rrset.from_text('foo', 30, 'in', 'a',
+ '10.0.0.1', '10.0.0.2',
+ '10.0.0.3', '10.0.0.4')
+ assert db.rdatasets[(foo, rdataset.rdclass, rdataset.rdtype, 0)] == \
+ expected
+ with db.writer() as txn:
+ txn.delete(foo, rdataset)
+ txn.delete(foo, rdata1)
+ txn.delete(foo, rdata2)
+ assert db.rdatasets.get((foo, rdataset.rdclass, rdataset.rdtype, 0)) \
+ is None
+
+def test_bad_parameters(db):
+ with db.writer() as txn:
+ with pytest.raises(TypeError):
+ txn.add(1)
+ with pytest.raises(TypeError):
+ rrset = dns.rrset.from_text('bar', 300, 'in', 'txt', 'bar')
+ txn.add(rrset, 1)
+ with pytest.raises(ValueError):
+ foo = dns.name.from_text('foo', None)
+ rdata = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ txn.add(foo, 0x80000000, rdata)
+ with pytest.raises(TypeError):
+ txn.add(foo)
+ with pytest.raises(TypeError):
+ txn.add()
+ with pytest.raises(TypeError):
+ txn.add(foo, 300)
+ with pytest.raises(TypeError):
+ txn.add(foo, 300, 'hi')
+ with pytest.raises(TypeError):
+ txn.add(foo, 'hi')
+ with pytest.raises(TypeError):
+ txn.delete()
+ with pytest.raises(TypeError):
+ txn.delete(1)
+
+example_text = """$TTL 3600
+$ORIGIN example.
+@ soa foo bar 1 2 3 4 5
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+$TTL 300
+$ORIGIN foo.example.
+bar mx 0 blaz
+"""
+
+example_text_output = """@ 3600 IN SOA foo bar 1 2 3 4 5
+@ 3600 IN NS ns1
+@ 3600 IN NS ns2
+@ 3600 IN NS ns3
+ns1 3600 IN A 10.0.0.1
+ns2 3600 IN A 10.0.0.2
+ns3 3600 IN A 10.0.0.3
+"""
+
+@pytest.fixture(params=[dns.zone.Zone, dns.versioned.Zone])
+def zone(request):
+ return dns.zone.from_text(example_text, zone_factory=request.param)
+
+def test_zone_basic(zone):
+ with zone.writer() as txn:
+ txn.delete(dns.name.from_text('bar.foo', None))
+ rd = dns.rdata.from_text('in', 'ns', 'ns3')
+ txn.add(dns.name.empty, 3600, rd)
+ rd = dns.rdata.from_text('in', 'a', '10.0.0.3')
+ txn.add(dns.name.from_text('ns3', None), 3600, rd)
+ output = zone.to_text()
+ assert output == example_text_output
+
+def test_zone_base_layer(zone):
+ with zone.writer() as txn:
+ # Get a set from the zone layer
+ rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
+ assert rdataset == expected
+
+def test_zone_transaction_layer(zone):
+ with zone.writer() as txn:
+ # Make a change
+ rd = dns.rdata.from_text('in', 'ns', 'ns3')
+ txn.add(dns.name.empty, 3600, rd)
+ # Get a set from the transaction layer
+ expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2', 'ns3')
+ rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ assert rdataset == expected
+ assert txn.name_exists(dns.name.empty)
+ ns1 = dns.name.from_text('ns1', None)
+ assert txn.name_exists(ns1)
+ ns99 = dns.name.from_text('ns99', None)
+ assert not txn.name_exists(ns99)
+
+def test_zone_add_and_delete(zone):
+ with zone.writer() as txn:
+ a99 = dns.name.from_text('a99', None)
+ a100 = dns.name.from_text('a100', None)
+ a101 = dns.name.from_text('a101', None)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add(a99, rds)
+ txn.delete(a99, dns.rdataclass.IN, dns.rdatatype.A)
+ txn.delete(a100, dns.rdataclass.IN, dns.rdatatype.A)
+ txn.delete(a101)
+ assert not txn.name_exists(a99)
+ assert not txn.name_exists(a100)
+ assert not txn.name_exists(a101)
+ ns1 = dns.name.from_text('ns1', None)
+ txn.delete(ns1, dns.rdataclass.IN, dns.rdatatype.A)
+ assert not txn.name_exists(ns1)
+ with zone.writer() as txn:
+ txn.add(a99, rds)
+ txn.delete(a99)
+ assert not txn.name_exists(a99)
+ with zone.writer() as txn:
+ txn.add(a100, rds)
+ txn.delete(a99)
+ assert not txn.name_exists(a99)
+ assert txn.name_exists(a100)
+
+def test_zone_get_deleted(zone):
+ with zone.writer() as txn:
+ print(zone.to_text())
+ ns1 = dns.name.from_text('ns1', None)
+ assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is not None
+ txn.delete(ns1)
+ assert txn.get(ns1, dns.rdataclass.IN, dns.rdatatype.A) is None
+ ns2 = dns.name.from_text('ns2', None)
+ txn.delete(ns2, dns.rdataclass.IN, dns.rdatatype.A)
+ assert txn.get(ns2, dns.rdataclass.IN, dns.rdatatype.A) is None
+
+def test_zone_bad_class(zone):
+ with zone.writer() as txn:
+ with pytest.raises(ValueError):
+ txn.get(dns.name.empty, dns.rdataclass.CH,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ rds = dns.rdataset.from_text('ch', 'ns', 300, 'ns1', 'ns2')
+ with pytest.raises(ValueError):
+ txn.add(dns.name.empty, rds)
+ with pytest.raises(ValueError):
+ txn.replace(dns.name.empty, rds)
+ with pytest.raises(ValueError):
+ txn.delete(dns.name.empty, rds)
+ with pytest.raises(ValueError):
+ txn.delete(dns.name.empty, dns.rdataclass.CH,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+
+def test_set_serial(zone):
+ # basic
+ with zone.writer() as txn:
+ txn.set_serial()
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 2
+ # max
+ with zone.writer() as txn:
+ txn.set_serial(0, 0xffffffff)
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 0xffffffff
+ # wraparound to 1
+ with zone.writer() as txn:
+ txn.set_serial()
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 1
+ # trying to set to zero sets to 1
+ with zone.writer() as txn:
+ txn.set_serial(0, 0)
+ rdataset = zone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 1
+ with pytest.raises(KeyError):
+ with zone.writer() as txn:
+ txn.set_serial(name=dns.name.from_text('unknown', None))
+
+class ExpectedException(Exception):
+ pass
+
+def test_zone_rollback(zone):
+ try:
+ with zone.writer() as txn:
+ a99 = dns.name.from_text('a99.example.')
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add(a99, rds)
+ assert txn.name_exists(a99)
+ raise ExpectedException
+ except ExpectedException:
+ pass
+ assert not zone.get_node(a99)
+
+def test_zone_ooz_name(zone):
+ with zone.writer() as txn:
+ with pytest.raises(KeyError):
+ a99 = dns.name.from_text('a99.not-example.')
+ assert txn.name_exists(a99)
+
+def test_zone_iteration(zone):
+ expected = {}
+ for (name, rdataset) in zone.iterate_rdatasets():
+ expected[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ with zone.writer() as txn:
+ actual = {}
+ for (name, rdataset) in txn:
+ actual[(name, rdataset.rdtype, rdataset.covers)] = rdataset
+ assert actual == expected
+
+@pytest.fixture
+def vzone():
+ return dns.zone.from_text(example_text, zone_factory=dns.versioned.Zone)
+
+def test_vzone_read_only(vzone):
+ with vzone.reader() as txn:
+ rdataset = txn.get(dns.name.empty, dns.rdataclass.IN,
+ dns.rdatatype.NS, dns.rdatatype.NONE)
+ expected = dns.rdataset.from_text('in', 'ns', 300, 'ns1', 'ns2')
+ assert rdataset == expected
+ with pytest.raises(dns.transaction.ReadOnly):
+ txn.replace(dns.name.empty, expected)
+
+def test_vzone_multiple_versions(vzone):
+ assert len(vzone.versions) == 1
+ vzone.set_max_versions(None) # unlimited!
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.writer() as txn:
+ txn.set_serial()
+ with vzone.writer() as txn:
+ txn.set_serial()
+ rdataset = vzone.find_rdataset('@', 'soa')
+ assert rdataset[0].serial == 4
+ assert len(vzone.versions) == 4
+ vzone.set_max_versions(2)
+ assert len(vzone.versions) == 2
+ # The ones that survived should be 3 and 4
+ rdataset = vzone.versions[0].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+ dns.rdatatype.NONE)
+ assert rdataset[0].serial == 3
+ rdataset = vzone.versions[1].get_rdataset(dns.name.empty, dns.rdatatype.SOA,
+ dns.rdatatype.NONE)
+ assert rdataset[0].serial == 4
+ with pytest.raises(ValueError):
+ vzone.set_max_versions(0)
+
+try:
+ import threading
+
+ one_got_lock = threading.Event()
+
+ def run_one(zone):
+ with zone.writer() as txn:
+ one_got_lock.set()
+ # wait until two blocks
+ while len(zone._write_waiters) == 0:
+ time.sleep(0.01)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.98')
+ txn.add('a98', rds)
+
+ def run_two(zone):
+ # wait until one has the lock so we know we will block if we
+ # get the call done before the sleep in one completes
+ one_got_lock.wait()
+ with zone.writer() as txn:
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.99')
+ txn.add('a99', rds)
+
+ def test_vzone_concurrency(vzone):
+ t1 = threading.Thread(target=run_one, args=(vzone,))
+ t1.start()
+ t2 = threading.Thread(target=run_two, args=(vzone,))
+ t2.start()
+ t1.join()
+ t2.join()
+ with vzone.reader() as txn:
+ assert txn.name_exists('a98')
+ assert txn.name_exists('a99')
+
+except ImportError: # pragma: no cover
+ pass