diff options
author | Bob Halley <halley@dnspython.org> | 2020-08-16 17:58:29 -0700 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2020-08-21 07:40:45 -0700 |
commit | a7de0230bcbd9eb1a92cebe988394231cd6437da (patch) | |
tree | 80eaac1c15eda312309c0d87f904a19a55fafc1c | |
parent | e2888f116e0c98748f63044e9801acd0d18defd5 (diff) | |
download | dnspython-xfr.tar.gz |
Implement new inbound xfr design.xfr
-rw-r--r-- | dns/__init__.py | 1 | ||||
-rw-r--r-- | dns/asyncquery.py | 91 | ||||
-rw-r--r-- | dns/query.py | 137 | ||||
-rw-r--r-- | dns/transaction.py | 23 | ||||
-rw-r--r-- | dns/xfr.py | 291 | ||||
-rw-r--r-- | doc/async-query.rst | 9 | ||||
-rw-r--r-- | doc/inbound-xfr-class.rst | 14 | ||||
-rw-r--r-- | doc/query.rst | 7 | ||||
-rw-r--r-- | doc/zone-class.rst | 9 | ||||
-rw-r--r-- | doc/zone.rst | 1 | ||||
-rw-r--r-- | tests/test_transaction.py | 10 | ||||
-rw-r--r-- | tests/test_xfr.py | 714 |
12 files changed, 1288 insertions, 19 deletions
diff --git a/dns/__init__.py b/dns/__init__.py index 3a51a53..0473ca1 100644 --- a/dns/__init__.py +++ b/dns/__init__.py @@ -58,6 +58,7 @@ __all__ = [ 'version', 'versioned', 'wire', + 'xfr', 'zone', 'zonefile', ] diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 8a10dae..3787c07 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -30,7 +30,8 @@ import dns.rcode import dns.rdataclass import dns.rdatatype -from dns.query import _compute_times, _matches_destination, BadResponse, ssl +from dns.query import _compute_times, _matches_destination, BadResponse, ssl, \ + UDPMode # for brevity @@ -498,3 +499,91 @@ async def tls(q, where, timeout=None, port=853, source=None, source_port=0, finally: if not sock and s: await s.close() + +async def inbound_xfr(where, txn_manager, query=None, + port=53, timeout=None, lifetime=None, source=None, + source_port=0, udp_mode=UDPMode.NEVER, + keyring=None, keyname=None, + keyalgorithm=dns.tsig.default_algorithm, + backend=None): + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + For a description of most of the parameters to this method, see + the documentation of :py:func:`dns.query.inbound_xfr()`. + + *backend*, a ``dns.asyncbackend.Backend``, or ``None``. If ``None``, + the default, then dnspython will use the default backend. + + Raises on errors. + + """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + af = dns.inet.af_for_address(where) + stuple = _source_tuple(af, source, source_port) + dtuple = (where, port) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + if not backend: + backend = dns.asyncbackend.get_default_backend() + s = await backend.make_socket(af, sock_type, 0, stuple, dtuple, + _timeout(expiration)) + async with s: + if is_udp: + await s.sendto(wire, dtuple, _timeout(expiration)) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + await s.sendall(tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or \ + (expiration is not None and mexpiration > expiration): + mexpiration = expiration + if is_udp: + destination = _lltuple((where, port), af) + while True: + timeout = _timeout(mexpiration) + (rwire, from_address) = await s.recvfrom(65535, + timeout) + if _matches_destination(af, from_address, + destination, True): + break + else: + ldata = await _read_exactly(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = await _read_exactly(s, l, mexpiration) + is_ixfr = (rdtype == dns.rdatatype.IXFR) + r = dns.message.from_wire(rwire, keyring=query.keyring, + request_mac=query.mac, xfr=True, + origin=origin, tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr) + try: + done = inbound.process_message(r, is_udp) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") diff --git a/dns/query.py b/dns/query.py index d4a3afa..01452ee 100644 --- a/dns/query.py +++ b/dns/query.py @@ -18,6 +18,7 @@ """Talk to a DNS server.""" import contextlib +import enum import errno import os import selectors @@ -35,6 +36,7 @@ import dns.rcode import dns.rdataclass import dns.rdatatype import dns.serial +import dns.xfr try: import requests @@ -73,20 +75,15 @@ class BadResponse(dns.exception.FormError): """A DNS query response does not respond to the question asked.""" -class TransferError(dns.exception.DNSException): - """A zone transfer response got a non-zero rcode.""" - - def __init__(self, rcode): - message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) - super().__init__(message) - self.rcode = rcode - - class NoDOH(dns.exception.DNSException): """DNS over HTTPS (DOH) was requested but the requests module is not available.""" +# for backwards compatibility +TransferError = dns.xfr.TransferError + + def _compute_times(timeout): now = time.time() if timeout is None: @@ -917,7 +914,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, (expiration is not None and mexpiration > expiration): mexpiration = expiration if use_udp: - (wire, _) = _udp_recv(s, 65535, expiration) + (wire, _) = _udp_recv(s, 65535, mexpiration) else: ldata = _net_read(s, 2, mexpiration) (l,) = struct.unpack("!H", ldata) @@ -984,3 +981,123 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if done and q.keyring and not r.had_tsig: raise dns.exception.FormError("missing TSIG") yield r + + +class UDPMode(enum.IntEnum): + """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`? + + NEVER means "never use UDP; always use TCP" + TRY_FIRST means "try to use UDP but fall back to TCP if needed" + ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed" + """ + NEVER = 0 + TRY_FIRST = 1 + ONLY = 2 + + +def inbound_xfr(where, txn_manager, query=None, + port=53, timeout=None, lifetime=None, source=None, + source_port=0, udp_mode=UDPMode.NEVER, + keyring=None, keyname=None, + keyalgorithm=dns.tsig.default_algorithm): + """Conduct an inbound transfer and apply it via a transaction from the + txn_manager. + + *where*, a ``str`` containing an IPv4 or IPv6 address, where + to send the message. + + *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager + for this transfer (typically a ``dns.zone.Zone``). + + *query*, the query to send. If not supplied, a default query is + constructed using information from the *txn_manager*. + + *port*, an ``int``, the port send the message to. The default is 53. + + *timeout*, a ``float``, the number of seconds to wait for each + response message. If None, the default, wait forever. + + *lifetime*, a ``float``, the total number of seconds to spend + doing the transfer. If ``None``, the default, then there is no + limit on the time the transfer may take. + + *source*, a ``str`` containing an IPv4 or IPv6 address, specifying + the source address. The default is the wildcard address. + + *source_port*, an ``int``, the port from which to send the message. + The default is 0. + + *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used + for IXFRs. The default is ``dns.UDPMode.NEVER``, i.e. only use + TCP. Other possibilites are ``dns.UDPMode.TRY_FIRST``, which + means "try UDP but fallback to TCP if needed", and + ``dns.UDPMode.ONLY``, which means "try UDP and raise + ``dns.xfr.UseTCP`` if it does not succeeed. + + *keyring*, a ``dict``, the keyring to use for TSIG. + + *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG + key to use. + + *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use. + + Raises on errors. + + """ + if query is None: + (query, serial) = dns.xfr.make_query(txn_manager) + rdtype = query.question[0].rdtype + is_ixfr = rdtype == dns.rdatatype.IXFR + origin = txn_manager.from_wire_origin() + wire = query.to_wire() + (af, destination, source) = _destination_and_source(where, port, + source, source_port) + (_, expiration) = _compute_times(lifetime) + retry = True + while retry: + retry = False + if is_ixfr and udp_mode != UDPMode.NEVER: + sock_type = socket.SOCK_DGRAM + is_udp = True + else: + sock_type = socket.SOCK_STREAM + is_udp = False + with _make_socket(af, sock_type, source) as s: + _connect(s, destination, expiration) + if is_udp: + _udp_send(s, wire, None, expiration) + else: + tcpmsg = struct.pack("!H", len(wire)) + wire + _net_write(s, tcpmsg, expiration) + with dns.xfr.Inbound(txn_manager, rdtype, serial) as inbound: + done = False + tsig_ctx = None + while not done: + (_, mexpiration) = _compute_times(timeout) + if mexpiration is None or \ + (expiration is not None and mexpiration > expiration): + mexpiration = expiration + if is_udp: + (rwire, _) = _udp_recv(s, 65535, mexpiration) + else: + ldata = _net_read(s, 2, mexpiration) + (l,) = struct.unpack("!H", ldata) + rwire = _net_read(s, l, mexpiration) + r = dns.message.from_wire(rwire, keyring=query.keyring, + request_mac=query.mac, xfr=True, + origin=origin, tsig_ctx=tsig_ctx, + multi=(not is_udp), + one_rr_per_rrset=is_ixfr) + try: + done = inbound.process_message(r, is_udp) + except dns.xfr.UseTCP: + assert is_udp # should not happen if we used TCP! + if udp_mode == UDPMode.ONLY: + raise + done = True + retry = True + udp_mode = UDPMode.NEVER + continue + tsig_ctx = r.tsig_ctx + if not retry and query.keyring and not r.had_tsig: + raise dns.exception.FormError("missing TSIG") diff --git a/dns/transaction.py b/dns/transaction.py index c1645c2..8aec2e8 100644 --- a/dns/transaction.py +++ b/dns/transaction.py @@ -57,6 +57,15 @@ class TransactionManager: """ raise NotImplementedError # pragma: no cover + def from_wire_origin(self): + """Origin to use in from_wire() calls. + """ + (absolute_origin, relativize, _) = self.origin_information() + if relativize: + return absolute_origin + else: + return None + class DeleteNotExact(dns.exception.DNSException): """Existing data did not match data specified by an exact delete.""" @@ -273,7 +282,9 @@ class Transaction: def _rdataset_from_args(self, method, deleting, args): try: arg = args.popleft() - if isinstance(arg, dns.rdataset.Rdataset): + if isinstance(arg, dns.rrset.RRset): + rdataset = arg.to_rdataset() + elif isinstance(arg, dns.rdataset.Rdataset): rdataset = arg else: if deleting: @@ -315,15 +326,17 @@ class Transaction: rrset = arg name = rrset.name # rrsets are also rdatasets, but they don't print the - # same, so convert. - rdataset = dns.rdataset.Rdataset(rrset.rdclass, rrset.rdtype, - rrset.covers, rrset.ttl) - rdataset.union_update(rrset) + # same and can't be stored in nodes, so convert. + rdataset = rrset.to_rdataset() else: raise TypeError(f'{method} requires a name or RRset ' + 'as the first argument') if rdataset.rdclass != self.manager.get_class(): raise ValueError(f'{method} has objects of wrong RdataClass') + if rdataset.rdtype == dns.rdatatype.SOA: + (_, _, origin) = self.manager.origin_information() + if name != origin: + raise ValueError(f'{method} has non-origin SOA') self._raise_if_not_empty(method, args) if not replace: existing = self._get_rdataset(name, rdataset.rdtype, diff --git a/dns/xfr.py b/dns/xfr.py new file mode 100644 index 0000000..311e60e --- /dev/null +++ b/dns/xfr.py @@ -0,0 +1,291 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +# Copyright (C) 2003-2017 Nominum, Inc. +# +# Permission to use, copy, modify, and distribute this software and its +# documentation for any purpose with or without fee is hereby granted, +# provided that the above copyright notice and this permission notice +# appear in all copies. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES +# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF +# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR +# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES +# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN +# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT +# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +import dns.exception +import dns.message +import dns.name +import dns.rcode +import dns.serial +import dns.rdatatype +import dns.zone + + +class TransferError(dns.exception.DNSException): + """A zone transfer response got a non-zero rcode.""" + + def __init__(self, rcode): + message = 'Zone transfer error: %s' % dns.rcode.to_text(rcode) + super().__init__(message) + self.rcode = rcode + + +class SerialWentBackwards(dns.exception.FormError): + """The current serial number is less than the serial we know.""" + + +class UseTCP(dns.exception.DNSException): + """This IXFR cannot be completed with UDP.""" + + +class Inbound: + """ + State machine for zone transfers. + """ + + def __init__(self, txn_manager, rdtype=dns.rdatatype.AXFR, + serial=None): + """Initialize an inbound zone transfer. + + *txn_manager* is a :py:class:`dns.transaction.TransactionManager`. + + *rdtype* can be `dns.rdatatype.AXFR` or `dns.rdatatype.IXFR` + + *serial* is the base serial number for IXFRs, and is required in + that case. + """ + self.txn_manager = txn_manager + self.txn = None + self.rdtype = rdtype + if rdtype == dns.rdatatype.IXFR and serial is None: + raise ValueError('a starting serial must be supplied for IXFRs') + self.serial = serial + (_, _, self.origin) = txn_manager.origin_information() + self.soa_rdataset = None + self.done = False + self.expecting_SOA = False + self.delete_mode = False + + def process_message(self, message, is_udp=False): + """Process one message in the transfer. + + The message should have the same relativization as was specified when + the `dns.xfr.Inbound` was created. The message should also have been + created with `one_rr_per_rrset=True` because order matters. + + *is_udp*, a ``bool`` indidicates if this message was received using + UDP. + + Returns `True` if the transfer is complete, and `False` otherwise. + """ + if self.txn is None: + replacement = self.rdtype == dns.rdatatype.AXFR + self.txn = self.txn_manager.writer(replacement) + rcode = message.rcode() + if rcode != dns.rcode.NOERROR: + raise TransferError(rcode) + # + # We don't require a question section, but if it is present is + # should be correct. + # + if len(message.question) > 0: + if message.question[0].name != self.origin: + raise dns.exception.FormError("wrong question name") + if message.question[0].rdtype != self.rdtype: + raise dns.exception.FormError("wrong question rdatatype") + answer_index = 0 + if self.soa_rdataset is None: + # + # This is the first message. We're expecting an SOA at + # the origin. + # + if not message.answer or message.answer[0].name != self.origin: + raise dns.exception.FormError("No answer or RRset not " + "for zone origin") + rrset = message.answer[0] + name = rrset.name + rdataset = rrset + if rdataset.rdtype != dns.rdatatype.SOA: + raise dns.exception.FormError("first RRset is not an SOA") + answer_index = 1 + self.soa_rdataset = rdataset.copy() + if self.rdtype == dns.rdatatype.IXFR: + if self.soa_rdataset[0].serial == self.serial: + # + # We're already up-to-date. + # + self.done = True + elif dns.serial.Serial(self.soa_rdataset[0].serial) < \ + self.serial: + # It went backwards! + print(dns.serial.Serial(self.soa_rdataset[0].serial), + self.serial) + raise SerialWentBackwards + else: + if is_udp and len(message.answer[answer_index:]) == 0: + # + # There are no more records, so this is the + # "truncated" response. Say to use TCP + # + raise UseTCP + # + # Note we're expecting another SOA so we can detect + # if this IXFR response is an AXFR-style response. + # + self.expecting_SOA = True + # + # Process the answer section (other than the initial SOA in + # the first message). + # + for rrset in message.answer[answer_index:]: + name = rrset.name + rdataset = rrset + if self.done: + raise dns.exception.FormError("answers after final SOA") + if rdataset.rdtype == dns.rdatatype.SOA and \ + name == self.origin: + # + # Every time we see an origin SOA delete_mode inverts + # + if self.rdtype == dns.rdatatype.IXFR: + self.delete_mode = not self.delete_mode + # + # If this SOA Rdataset is equal to the first we saw + # then we're finished. If this is an IXFR we also + # check that we're seeing the record in the expected + # part of the response. + # + if rdataset == self.soa_rdataset and \ + (self.rdtype == dns.rdatatype.AXFR or + (self.rdtype == dns.rdatatype.IXFR and + self.delete_mode)): + # + # This is the final SOA + # + if self.expecting_SOA: + # We got an empty IXFR sequence! + raise dns.exception.FormError('empty IXFR sequence') + if self.rdtype == dns.rdatatype.IXFR \ + and self.serial != rdataset[0].serial: + raise dns.exception.FormError('unexpected end of IXFR ' + 'sequence') + self.txn.replace(name, rdataset) + self.txn.commit() + self.txn = None + self.done = True + else: + # + # This is not the final SOA + # + self.expecting_SOA = False + if self.rdtype == dns.rdatatype.IXFR: + if self.delete_mode: + # This is the start of an IXFR deletion set + if rdataset[0].serial != self.serial: + raise dns.exception.FormError( + "IXFR base serial mismatch") + else: + # This is the start of an IXFR addition set + self.serial = rdataset[0].serial + self.txn.replace(name, rdataset) + else: + # We saw a non-final SOA for the origin in an AXFR. + raise dns.exception.FormError('unexpected origin SOA ' + 'in AXFR') + continue + if self.expecting_SOA: + # + # We made an IXFR request and are expecting another + # SOA RR, but saw something else, so this must be an + # AXFR response. + # + self.rdtype = dns.rdatatype.AXFR + self.expecting_SOA = False + self.delete_mode = False + self.txn.rollback() + self.txn = self.txn_manager.writer(True) + # + # Note we are falling through into the code below + # so whatever rdataset this was gets written. + # + # Add or remove the data + if self.delete_mode: + self.txn.delete_exact(name, rdataset) + else: + self.txn.add(name, rdataset) + if is_udp and not self.done: + # + # This is a UDP IXFR and we didn't get to done, and we didn't + # get the proper "truncated" response + # + raise dns.exception.FormError('unexpected end of UDP IXFR') + return self.done + + # + # Inbounds are context managers. + # + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.txn: + self.txn.rollback() + return False + + +def make_query(txn_manager, serial=0, + use_edns=None, ednsflags=None, payload=None, + request_payload=None, options=None, + keyring=None, keyname=None, + keyalgorithm=dns.tsig.default_algorithm): + """Make an AXFR or IXFR query. + + *txn_manager* is a ``dns.transaction.TransactionManager``, typically a + ``dns.zone.Zone``. + + *serial* is an ``int`` or ``None``. If 0, then IXFR will be + attempted using the most recent serial number from the + *txn_manager*; it is the caller's responsibility to ensure there + are no write transactions active that could invalidate the + retrieved serial. If a serial cannot be determined, AXFR will be + forced. Other integer values are the starting serial to use. + ``None`` forces an AXFR. + + Please see the documentation for :py:func:`dns.message.make_query` and + :py:func:`dns.message.Message.use_tsig` for details on the other parameters + to this function. + + Returns a `(query, serial)` tuple. + """ + (zone_origin, _, origin) = txn_manager.origin_information() + if serial is None: + rdtype = dns.rdatatype.AXFR + elif not isinstance(serial, int): + raise ValueError('serial is not an integer') + elif serial == 0: + with txn_manager.reader() as txn: + rdataset = txn.get(origin, 'SOA') + if rdataset: + serial = rdataset[0].serial + rdtype = dns.rdatatype.IXFR + else: + serial = None + rdtype = dns.rdatatype.AXFR + elif serial > 0 and serial < 4294967296: + rdtype = dns.rdatatype.IXFR + else: + raise ValueError('serial out-of-range') + q = dns.message.make_query(zone_origin, rdtype, txn_manager.get_class(), + use_edns, False, ednsflags, payload, + request_payload, options) + if serial is not None: + rrset = dns.rrset.from_text(zone_origin, 0, 'IN', 'SOA', + f'. . {serial} 0 0 0 0') + q.authority.append(rrset) + if keyring is not None: + q.use_tsig(keyring, keyname, algorithm=keyalgorithm) + return (q, serial) diff --git a/doc/async-query.rst b/doc/async-query.rst index e2466ea..7202bdf 100644 --- a/doc/async-query.rst +++ b/doc/async-query.rst @@ -9,8 +9,8 @@ processing their responses. If you want "stub resolver" behavior, then you should use the higher level ``dns.asyncresolver`` module; see :ref:`async_resolver`. -There is currently no support for zone transfers or DNS-over-HTTPS -using asynchronous I/O but we hope to offer this in the future. +There is currently no support for DNS-over-HTTPS using asynchronous +I/O but we hope to offer this in the future. UDP --- @@ -31,3 +31,8 @@ TLS --- .. autofunction:: dns.asyncquery.tls + +Zone Transfers +-------------- + +.. autofunction:: dns.asyncquery.inbound_xfr diff --git a/doc/inbound-xfr-class.rst b/doc/inbound-xfr-class.rst new file mode 100644 index 0000000..73eaf57 --- /dev/null +++ b/doc/inbound-xfr-class.rst @@ -0,0 +1,14 @@ +.. _inbound-xfr-class: + +The dns.xfr.Inbound Class and make_query() function +--------------------------------------------------- + +The ``Inbound`` class provides support for inbound DNS zone transfers, both +AXFR and IXFR. I/O is handled in other classes. When a message related +to the transfer arrives, the I/O code calls the ``process_message()`` method +which adds the content to the pending transaction. + +.. autoclass:: dns.xfr.Inbound + :members: + +.. autofunction:: dns.xfr.make_query diff --git a/doc/query.rst b/doc/query.rst index 08940b4..beb0869 100644 --- a/doc/query.rst +++ b/doc/query.rst @@ -41,4 +41,11 @@ HTTPS Zone Transfers -------------- +As of dnspython 2.1, ``dns.query.xfr`` is deprecated. Please use +``dns.query.inbound_xfr`` instead. + +.. autoclass:: dns.query.UDPMode + +.. autofunction:: dns.query.inbound_xfr + .. autofunction:: dns.query.xfr diff --git a/doc/zone-class.rst b/doc/zone-class.rst index bdaf884..48e138e 100644 --- a/doc/zone-class.rst +++ b/doc/zone-class.rst @@ -91,6 +91,15 @@ See below for more information on the ``Transaction`` API. A ``bool``, which is ``True`` if names in the zone should be relativized. +The TransactionManager Class +---------------------------- + +This is the abstract base class of all objects that support transactions. + +.. autoclass:: dns.transaction.TransactionManager + :members: + + The Transaction Class --------------------- diff --git a/doc/zone.rst b/doc/zone.rst index 777f08b..17d9e9d 100644 --- a/doc/zone.rst +++ b/doc/zone.rst @@ -8,3 +8,4 @@ DNS Zones zone-class zone-make + inbound-xfr-class diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 7fb353c..bb69b71 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -13,6 +13,7 @@ import dns.transaction import dns.versioned import dns.zone + class DB(dns.transaction.TransactionManager): def __init__(self): self.rdatasets = {} @@ -24,7 +25,7 @@ class DB(dns.transaction.TransactionManager): return Transaction(self, replacement, False) def origin_information(self): - return (None, True) + return (dns.name.from_text('example'), True, dns.name.empty) def get_class(self): return dns.rdataclass.IN @@ -224,6 +225,13 @@ def test_bad_parameters(db): with pytest.raises(TypeError): txn.delete(1) +def test_cannot_store_non_origin_soa(db): + with pytest.raises(ValueError): + with db.writer() as txn: + rrset = dns.rrset.from_text('foo', 300, 'in', 'SOA', + '. . 1 2 3 4 5') + txn.add(rrset) + example_text = """$TTL 3600 $ORIGIN example. @ soa foo bar 1 2 3 4 5 diff --git a/tests/test_xfr.py b/tests/test_xfr.py new file mode 100644 index 0000000..fbda5fa --- /dev/null +++ b/tests/test_xfr.py @@ -0,0 +1,714 @@ +# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license + +import asyncio + +import pytest + +import dns.asyncbackend +import dns.asyncquery +import dns.message +import dns.query +import dns.tsigkeyring +import dns.versioned +import dns.xfr + +# Some tests use a "nano nameserver" for testing. It requires trio +# and threading, so try to import it and if it doesn't work, skip +# those tests. +try: + from .nanonameserver import Server + _nanonameserver_available = True +except ImportError: + _nanonameserver_available = False + class Server(object): + pass + +axfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN AXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +bar.foo 300 IN MX 0 blaz.foo +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +@ 3600 IN SOA foo bar 1 2 3 4 5 +''' + +axfr1 = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN AXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +''' +axfr2 = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;ANSWER +bar.foo 300 IN MX 0 blaz.foo +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +@ 3600 IN SOA foo bar 1 2 3 4 5 +''' + +base = """@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +bar.foo 300 IN MX 0 blaz.foo +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +""" + +axfr_unexpected_origin = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN AXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN SOA foo bar 1 2 3 4 7 +''' + +ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 1 2 3 4 5 +bar.foo 300 IN MX 0 blaz.foo +ns2 3600 IN A 10.0.0.2 +@ 3600 IN SOA foo bar 2 2 3 4 5 +ns2 3600 IN A 10.0.0.4 +@ 3600 IN SOA foo bar 2 2 3 4 5 +@ 3600 IN SOA foo bar 3 2 3 4 5 +ns3 3600 IN A 10.0.0.3 +@ 3600 IN SOA foo bar 3 2 3 4 5 +@ 3600 IN NS ns2 +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 4 2 3 4 5 +''' + +compressed_ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 1 2 3 4 5 +bar.foo 300 IN MX 0 blaz.foo +ns2 3600 IN A 10.0.0.2 +@ 3600 IN NS ns2 +@ 3600 IN SOA foo bar 4 2 3 4 5 +ns2 3600 IN A 10.0.0.4 +ns3 3600 IN A 10.0.0.3 +@ 3600 IN SOA foo bar 4 2 3 4 5 +''' + +ixfr_expected = """@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN NS ns1 +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.4 +ns3 3600 IN A 10.0.0.3 +""" + +ixfr_first_message = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +''' + +ixfr_header = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;ANSWER +''' + +ixfr_body = [ + '@ 3600 IN SOA foo bar 1 2 3 4 5', + 'bar.foo 300 IN MX 0 blaz.foo', + 'ns2 3600 IN A 10.0.0.2', + '@ 3600 IN SOA foo bar 2 2 3 4 5', + 'ns2 3600 IN A 10.0.0.4', + '@ 3600 IN SOA foo bar 2 2 3 4 5', + '@ 3600 IN SOA foo bar 3 2 3 4 5', + 'ns3 3600 IN A 10.0.0.3', + '@ 3600 IN SOA foo bar 3 2 3 4 5', + '@ 3600 IN NS ns2', + '@ 3600 IN SOA foo bar 4 2 3 4 5', + '@ 3600 IN SOA foo bar 4 2 3 4 5', +] + +ixfrs = [ixfr_first_message] +ixfrs.extend([ixfr_header + l for l in ixfr_body]) + +good_empty_ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +''' + +retry_tcp_ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 5 2 3 4 5 +''' + +bad_empty_ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 4 2 3 4 5 +''' + +unexpected_end_ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 1 2 3 4 5 +bar.foo 300 IN MX 0 blaz.foo +ns2 3600 IN A 10.0.0.2 +@ 3600 IN NS ns2 +@ 3600 IN SOA foo bar 3 2 3 4 5 +ns2 3600 IN A 10.0.0.4 +ns3 3600 IN A 10.0.0.3 +@ 3600 IN SOA foo bar 4 2 3 4 5 +''' + +bad_serial_ixfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 4 2 3 4 5 +@ 3600 IN SOA foo bar 2 2 3 4 5 +bar.foo 300 IN MX 0 blaz.foo +ns2 3600 IN A 10.0.0.2 +@ 3600 IN NS ns2 +@ 3600 IN SOA foo bar 4 2 3 4 5 +ns2 3600 IN A 10.0.0.4 +ns3 3600 IN A 10.0.0.3 +@ 3600 IN SOA foo bar 4 2 3 4 5 +''' + +ixfr_axfr = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +bar.foo 300 IN MX 0 blaz.foo +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +@ 3600 IN SOA foo bar 1 2 3 4 5 +''' + +def test_basic_axfr(): + z = dns.versioned.Zone('example.') + m = dns.message.from_text(axfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: + done = xfr.process_message(m) + assert done + ez = dns.zone.from_text(base, 'example.') + assert z == ez + +def test_basic_axfr_two_parts(): + z = dns.versioned.Zone('example.') + m1 = dns.message.from_text(axfr1, origin=z.origin, + one_rr_per_rrset=True) + m2 = dns.message.from_text(axfr2, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: + done = xfr.process_message(m1) + assert not done + done = xfr.process_message(m2) + assert done + ez = dns.zone.from_text(base, 'example.') + assert z == ez + +def test_axfr_unexpected_origin(): + z = dns.versioned.Zone('example.') + m = dns.message.from_text(axfr_unexpected_origin, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_basic_ixfr(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + done = xfr.process_message(m) + assert done + ez = dns.zone.from_text(ixfr_expected, 'example.') + assert z == ez + +def test_compressed_ixfr(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(compressed_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + done = xfr.process_message(m) + assert done + ez = dns.zone.from_text(ixfr_expected, 'example.') + assert z == ez + +def test_basic_ixfr_many_parts(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + done = False + for text in ixfrs: + assert not done + m = dns.message.from_text(text, origin=z.origin, + one_rr_per_rrset=True) + done = xfr.process_message(m) + assert done + ez = dns.zone.from_text(ixfr_expected, 'example.') + assert z == ez + +def test_good_empty_ixfr(): + z = dns.zone.from_text(ixfr_expected, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(good_empty_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + done = xfr.process_message(m) + assert done + ez = dns.zone.from_text(ixfr_expected, 'example.') + assert z == ez + +def test_retry_tcp_ixfr(): + z = dns.zone.from_text(ixfr_expected, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(retry_tcp_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.xfr.UseTCP): + xfr.process_message(m, True) + +def test_bad_empty_ixfr(): + z = dns.zone.from_text(ixfr_expected, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_empty_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=3) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_serial_went_backwards_ixfr(): + z = dns.zone.from_text(ixfr_expected, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_empty_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=5) as xfr: + with pytest.raises(dns.xfr.SerialWentBackwards): + xfr.process_message(m) + +def test_ixfr_is_axfr(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(ixfr_axfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=0xffffffff) as xfr: + done = xfr.process_message(m) + assert done + ez = dns.zone.from_text(base, 'example.') + assert z == ez + +def test_ixfr_requires_serial(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + with pytest.raises(ValueError): + dns.xfr.Inbound(z, dns.rdatatype.IXFR) + +def test_ixfr_unexpected_end(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(unexpected_end_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_ixfr_bad_serial(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_serial_ixfr, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +refused = '''id 1 +opcode QUERY +rcode REFUSED +flags AA +;QUESTION +example. IN AXFR +''' + +bad_qname = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +not-example. IN IXFR +''' + +bad_qtype = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN AXFR +''' + +soa_not_first = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +bar.foo 300 IN MX 0 blaz.foo +''' + +soa_not_first_2 = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ANSWER +@ 300 IN MX 0 blaz.foo +''' + +no_answer = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN IXFR +;ADDITIONAL +bar.foo 300 IN MX 0 blaz.foo +''' + +axfr_answers_after_final_soa = '''id 1 +opcode QUERY +rcode NOERROR +flags AA +;QUESTION +example. IN AXFR +;ANSWER +@ 3600 IN SOA foo bar 1 2 3 4 5 +@ 3600 IN NS ns1 +@ 3600 IN NS ns2 +bar.foo 300 IN MX 0 blaz.foo +ns1 3600 IN A 10.0.0.1 +ns2 3600 IN A 10.0.0.2 +@ 3600 IN SOA foo bar 1 2 3 4 5 +ns3 3600 IN A 10.0.0.3 +''' + +def test_refused(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(refused, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.xfr.TransferError): + xfr.process_message(m) + +def test_bad_qname(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_qname, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_bad_qtype(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(bad_qtype, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_soa_not_first(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(soa_not_first, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + m = dns.message.from_text(soa_not_first_2, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_no_answer(): + z = dns.zone.from_text(base, 'example.', + zone_factory=dns.versioned.Zone) + m = dns.message.from_text(no_answer, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.IXFR, serial=1) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +def test_axfr_answers_after_final_soa(): + z = dns.versioned.Zone('example.') + m = dns.message.from_text(axfr_answers_after_final_soa, origin=z.origin, + one_rr_per_rrset=True) + with dns.xfr.Inbound(z, dns.rdatatype.AXFR) as xfr: + with pytest.raises(dns.exception.FormError): + xfr.process_message(m) + +keyring = dns.tsigkeyring.from_text( + { + 'keyname.': 'NjHwPsMKjdN++dOfE5iAiQ==' + } +) + +keyname = dns.name.from_text('keyname') + +def test_make_query_basic(): + z = dns.versioned.Zone('example.') + (q, s) = dns.xfr.make_query(z) + assert q.question[0].rdtype == dns.rdatatype.AXFR + assert s is None + (q, s) = dns.xfr.make_query(z, serial=None) + assert q.question[0].rdtype == dns.rdatatype.AXFR + assert s is None + (q, s) = dns.xfr.make_query(z, serial=10) + assert q.question[0].rdtype == dns.rdatatype.IXFR + assert q.authority[0].rdtype == dns.rdatatype.SOA + assert q.authority[0][0].serial == 10 + assert s == 10 + with z.writer() as txn: + txn.add('@', 300, dns.rdata.from_text('in', 'soa', '. . 1 2 3 4 5')) + (q, s) = dns.xfr.make_query(z) + assert q.question[0].rdtype == dns.rdatatype.IXFR + assert q.authority[0].rdtype == dns.rdatatype.SOA + assert q.authority[0][0].serial == 1 + assert s == 1 + (q, s) = dns.xfr.make_query(z, keyring=keyring, keyname=keyname) + assert q.question[0].rdtype == dns.rdatatype.IXFR + assert q.authority[0].rdtype == dns.rdatatype.SOA + assert q.authority[0][0].serial == 1 + assert s == 1 + assert q.keyname == keyname + + +def test_make_query_bad_serial(): + z = dns.versioned.Zone('example.') + with pytest.raises(ValueError): + dns.xfr.make_query(z, serial='hi') + with pytest.raises(ValueError): + dns.xfr.make_query(z, serial=-1) + with pytest.raises(ValueError): + dns.xfr.make_query(z, serial=4294967296) + + +class XFRNanoNameserver(Server): + + def __init__(self): + super().__init__(origin=dns.name.from_text('example')) + + def handle(self, request): + try: + if request.message.question[0].rdtype == dns.rdatatype.IXFR: + text = ixfr + else: + text = axfr + r = dns.message.from_text(text, one_rr_per_rrset=True, + origin=self.origin) + r.id = request.message.id + return r + except Exception: + pass + +@pytest.mark.skipif(not _nanonameserver_available, + reason="requires nanonameserver") +def test_sync_inbound_xfr(): + with XFRNanoNameserver() as ns: + zone = dns.versioned.Zone('example') + dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + expected = dns.zone.from_text(ixfr_expected, 'example') + assert zone == expected + +async def async_inbound_xfr(): + with XFRNanoNameserver() as ns: + zone = dns.versioned.Zone('example') + await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + expected = dns.zone.from_text(ixfr_expected, 'example') + assert zone == expected + +@pytest.mark.skipif(not _nanonameserver_available, + reason="requires nanonameserver") +def test_asyncio_inbound_xfr(): + dns.asyncbackend.set_default_backend('asyncio') + async def run(): + await async_inbound_xfr() + try: + runner = asyncio.run + except AttributeError: + # this is only needed for 3.6 + def old_runner(awaitable): + loop = asyncio.get_event_loop() + return loop.run_until_complete(awaitable) + runner = old_runner + runner(run()) + +# +# We don't need to do this as it's all generic code, but +# just for extra caution we do it for each backend. +# + +try: + import trio + + @pytest.mark.skipif(not _nanonameserver_available, + reason="requires nanonameserver") + def test_trio_inbound_xfr(): + dns.asyncbackend.set_default_backend('trio') + async def run(): + await async_inbound_xfr() + trio.run(run) +except ImportError: + pass + +try: + import curio + + @pytest.mark.skipif(not _nanonameserver_available, + reason="requires nanonameserver") + def test_curio_inbound_xfr(): + dns.asyncbackend.set_default_backend('curio') + async def run(): + await async_inbound_xfr() + curio.run(run) +except ImportError: + pass + + +class UDPXFRNanoNameserver(Server): + + def __init__(self): + super().__init__(origin=dns.name.from_text('example')) + self.did_truncation = False + + def handle(self, request): + try: + if request.message.question[0].rdtype == dns.rdatatype.IXFR: + if self.did_truncation: + text = ixfr + else: + text = retry_tcp_ixfr + self.did_truncation = True + else: + text = axfr + r = dns.message.from_text(text, one_rr_per_rrset=True, + origin=self.origin) + r.id = request.message.id + return r + except Exception: + pass + +@pytest.mark.skipif(not _nanonameserver_available, + reason="requires nanonameserver") +def test_sync_retry_tcp_inbound_xfr(): + with UDPXFRNanoNameserver() as ns: + zone = dns.versioned.Zone('example') + dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + dns.query.inbound_xfr(ns.tcp_address[0], zone, port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + expected = dns.zone.from_text(ixfr_expected, 'example') + assert zone == expected + +async def udp_async_inbound_xfr(): + with UDPXFRNanoNameserver() as ns: + zone = dns.versioned.Zone('example') + await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + await dns.asyncquery.inbound_xfr(ns.tcp_address[0], zone, + port=ns.tcp_address[1], + udp_mode=dns.query.UDPMode.TRY_FIRST) + expected = dns.zone.from_text(ixfr_expected, 'example') + assert zone == expected + +@pytest.mark.skipif(not _nanonameserver_available, + reason="requires nanonameserver") +def test_asyncio_retry_tcp_inbound_xfr(): + dns.asyncbackend.set_default_backend('asyncio') + async def run(): + await udp_async_inbound_xfr() + try: + runner = asyncio.run + except AttributeError: + def old_runner(awaitable): + loop = asyncio.get_event_loop() + return loop.run_until_complete(awaitable) + runner = old_runner + runner(run()) |