# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import socket
import sys
import time
import unittest

try:
    import ssl

    have_ssl = True
except Exception:
    have_ssl = False

# dns.edns, dns.flags, dns.rcode and dns.rrset are referenced below, so they
# are imported explicitly instead of relying on other dns modules pulling
# them in as a side effect.
import dns.edns
import dns.exception
import dns.flags
import dns.inet
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsigkeyring
import dns.zone

import tests.util

# Some tests use a "nano nameserver" for testing.  It requires trio
# and threading, so try to import it and if it doesn't work, skip
# those tests.
try:
    from .nanonameserver import Server

    _nanonameserver_available = True
except ImportError:
    _nanonameserver_available = False

    class Server:
        """Placeholder base so the server subclasses below still define."""


# Public resolver addresses queried by QueryTests, one per usable family.
query_addresses = []
if tests.util.have_ipv4():
    query_addresses.append("8.8.8.8")
if tests.util.have_ipv6():
    query_addresses.append("2001:4860:4860::8888")

# TSIG keyring shared by the nano nameservers and their clients.
keyring = dns.tsigkeyring.from_text({"name": "tDz6cfXXGtNivRpQ98hr6A=="})


@unittest.skipIf(not tests.util.is_internet_reachable(), "Internet not reachable")
class QueryTests(unittest.TestCase):
    """Live queries against the addresses in ``query_addresses``."""

    def _assert_dns_google_addresses(self, response, qname):
        """Assert that *response* answers *qname* with both expected A records."""
        rrs = response.get_rrset(
            response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
        )
        self.assertIsNotNone(rrs)
        seen = {rdata.address for rdata in rrs}
        self.assertIn("8.8.8.8", seen)
        self.assertIn("8.8.4.4", seen)

    def testQueryUDP(self):
        for address in query_addresses:
            qname = dns.name.from_text("dns.google.")
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.udp(q, address, timeout=2)
            self._assert_dns_google_addresses(response, qname)

    def testQueryUDPWithSocket(self):
        for address in query_addresses:
            with socket.socket(
                dns.inet.af_for_address(address), socket.SOCK_DGRAM
            ) as s:
                s.setblocking(0)
                qname = dns.name.from_text("dns.google.")
                q = dns.message.make_query(qname, dns.rdatatype.A)
                response = dns.query.udp(q, address, sock=s, timeout=2)
                self._assert_dns_google_addresses(response, qname)

    def testQueryTCP(self):
        for address in query_addresses:
            qname = dns.name.from_text("dns.google.")
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.tcp(q, address, timeout=2)
            self._assert_dns_google_addresses(response, qname)

    def testQueryTCPWithSocket(self):
        for address in query_addresses:
            with socket.socket(
                dns.inet.af_for_address(address), socket.SOCK_STREAM
            ) as s:
                ll = dns.inet.low_level_address_tuple((address, 53))
                # Connect with a timeout, then hand dns.query a
                # non-blocking socket.
                s.settimeout(2)
                s.connect(ll)
                s.setblocking(0)
                qname = dns.name.from_text("dns.google.")
                q = dns.message.make_query(qname, dns.rdatatype.A)
                response = dns.query.tcp(q, None, sock=s, timeout=2)
                self._assert_dns_google_addresses(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLS(self):
        for address in query_addresses:
            qname = dns.name.from_text("dns.google.")
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.tls(q, address, timeout=2)
            self._assert_dns_google_addresses(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSWithSocket(self):
        for address in query_addresses:
            with socket.socket(
                dns.inet.af_for_address(address), socket.SOCK_STREAM
            ) as base_s:
                ll = dns.inet.low_level_address_tuple((address, 853))
                base_s.settimeout(2)
                base_s.connect(ll)
                ctx = ssl.create_default_context()
                ctx.minimum_version = ssl.TLSVersion.TLSv1_2
                with ctx.wrap_socket(
                    base_s, server_hostname="dns.google"
                ) as s:  # lgtm[py/insecure-protocol]
                    s.setblocking(0)
                    qname = dns.name.from_text("dns.google.")
                    q = dns.message.make_query(qname, dns.rdatatype.A)
                    response = dns.query.tls(q, None, sock=s, timeout=2)
                    self._assert_dns_google_addresses(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSwithPadding(self):
        for address in query_addresses:
            qname = dns.name.from_text("dns.google.")
            q = dns.message.make_query(qname, dns.rdatatype.A, use_edns=0, pad=128)
            response = dns.query.tls(q, address, timeout=2)
            self._assert_dns_google_addresses(response, qname)
            # the response should have a padding option
            self.assertIsNotNone(response.opt)
            has_pad = any(
                o.otype == dns.edns.OptionType.PADDING
                for o in response.opt[0].options
            )
            self.assertTrue(has_pad)

    def testQueryUDPFallback(self):
        for address in query_addresses:
            qname = dns.name.from_text(".")
            q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
            self.assertTrue(tcp)

    def testQueryUDPFallbackWithSocket(self):
        for address in query_addresses:
            af = dns.inet.af_for_address(address)
            with socket.socket(af, socket.SOCK_DGRAM) as udp_s:
                udp_s.setblocking(0)
                with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
                    ll = dns.inet.low_level_address_tuple((address, 53))
                    tcp_s.settimeout(2)
                    tcp_s.connect(ll)
                    tcp_s.setblocking(0)
                    qname = dns.name.from_text(".")
                    q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
                    (_, tcp) = dns.query.udp_with_fallback(
                        q, address, udp_sock=udp_s, tcp_sock=tcp_s, timeout=2
                    )
                    self.assertTrue(tcp)

    def testQueryUDPFallbackNoFallback(self):
        for address in query_addresses:
            qname = dns.name.from_text("dns.google.")
            q = dns.message.make_query(qname, dns.rdatatype.A)
            (_, tcp) = dns.query.udp_with_fallback(q, address, timeout=2)
            self.assertFalse(tcp)

    def testUDPReceiveQuery(self):
        # Loop a query between two local UDP sockets and check that the
        # reported peer address is the sender's bound address.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
            listener.bind(("127.0.0.1", 0))
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
                sender.bind(("127.0.0.1", 0))
                q = dns.message.make_query("dns.google", dns.rdatatype.A)
                dns.query.send_udp(sender, q, listener.getsockname())
                expiration = time.time() + 2
                (q, _, addr) = dns.query.receive_udp(listener, expiration=expiration)
                self.assertEqual(addr, sender.getsockname())


# for brevity
_d_and_s = dns.query._destination_and_source


class DestinationAndSourceTests(unittest.TestCase):
    """Tests for ``dns.query._destination_and_source``."""

    # This IPv4 case used to share its name with the IPv6 case below, so it
    # was overwritten in the class body and never ran.
    def test_af_inferred_from_where_v4(self):
        (af, _, _) = _d_and_s("1.2.3.4", 53, None, 0)
        self.assertEqual(af, socket.AF_INET)

    def test_af_inferred_from_where(self):
        (af, _, _) = _d_and_s("1::2", 53, None, 0)
        self.assertEqual(af, socket.AF_INET6)

    def test_af_inferred_from_source(self):
        (af, _, _) = _d_and_s("https://example/dns-query", 443, "1.2.3.4", 0, False)
        self.assertEqual(af, socket.AF_INET)

    def test_af_mismatch(self):
        with self.assertRaises(ValueError):
            _d_and_s("1::2", 53, "1.2.3.4", 0)

    def test_source_port_but_no_af_inferred(self):
        with self.assertRaises(ValueError):
            _d_and_s("https://example/dns-query", 443, None, 12345, False)

    def test_where_must_be_an_address(self):
        with self.assertRaises(ValueError):
            _d_and_s("not a valid address", 53, "1.2.3.4", 0)

    def test_destination_is_none_of_where_url(self):
        (_, d, _) = _d_and_s("https://example/dns-query", 443, None, 0, False)
        self.assertIsNone(d)

    def test_v4_wildcard_source_set(self):
        (_, _, s) = _d_and_s("1.2.3.4", 53, None, 12345)
        self.assertEqual(s, ("0.0.0.0", 12345))

    def test_v6_wildcard_source_set(self):
        (_, _, s) = _d_and_s("1::2", 53, None, 12345)
        self.assertEqual(s, ("::", 12345, 0, 0))


class AddressesEqualTestCase(unittest.TestCase):
    """Tests for ``dns.query._addresses_equal``."""

    def test_v4(self):
        self.assertTrue(
            dns.query._addresses_equal(
                socket.AF_INET, ("10.0.0.1", 53), ("10.0.0.1", 53)
            )
        )
        self.assertFalse(
            dns.query._addresses_equal(
                socket.AF_INET, ("10.0.0.1", 53), ("10.0.0.2", 53)
            )
        )

    def test_v6(self):
        # Two textual spellings of the same IPv6 address must compare equal.
        self.assertTrue(
            dns.query._addresses_equal(
                socket.AF_INET6, ("1::1", 53), ("0001:0000::1", 53)
            )
        )
        self.assertFalse(
            dns.query._addresses_equal(socket.AF_INET6, ("::1", 53), ("::2", 53))
        )

    def test_mixed(self):
        self.assertFalse(
            dns.query._addresses_equal(socket.AF_INET, ("10.0.0.1", 53), ("::2", 53))
        )


axfr_zone = """
$TTL 300
@ SOA ns1 root 1 7200 900 1209600 86400
@ NS ns1
@ NS ns2
ns1 A 10.0.0.1
ns2 A 10.0.0.1
"""


class AXFRNanoNameserver(Server):
    """Nano nameserver that answers every request with ``axfr_zone`` as AXFR."""

    def handle(self, request):
        """Return three messages: opening SOA, zone contents, closing SOA."""
        self.zone = dns.zone.from_text(axfr_zone, origin=self.origin)
        self.origin = self.zone.origin
        items = []
        soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)

        # First message: keeps the question and carries the opening SOA.
        response = dns.message.make_response(request.message)
        response.flags |= dns.flags.AA
        response.answer.append(soa)
        items.append(response)

        # Second message: every rdataset except the apex SOA.
        response = dns.message.make_response(request.message)
        response.question = []
        response.flags |= dns.flags.AA
        for name, rdataset in self.zone.iterate_rdatasets():
            if rdataset.rdtype == dns.rdatatype.SOA and name == dns.name.empty:
                continue
            rrset = dns.rrset.RRset(
                name, rdataset.rdclass, rdataset.rdtype, rdataset.covers
            )
            rrset.update(rdataset)
            response.answer.append(rrset)
        items.append(response)

        # Third message: the closing SOA.
        response = dns.message.make_response(request.message)
        response.question = []
        response.flags |= dns.flags.AA
        response.answer.append(soa)
        items.append(response)
        return items


ixfr_message = """id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
deleted.example. 300 IN A 10.0.0.1
changed.example. 300 IN A 10.0.0.2
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
changed.example. 300 IN A 10.0.0.4
added.example. 300 IN A 10.0.0.3
example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
added2.example. 300 IN A 10.0.0.5
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
"""

ixfr_trailing_junk = ixfr_message + "junk.example. 300 IN A 10.0.0.6"

ixfr_up_to_date_message = """id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
"""

axfr_trailing_junk = """id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN AXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
added.example. 300 IN A 10.0.0.3
added2.example. 300 IN A 10.0.0.5
changed.example. 300 IN A 10.0.0.4
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
junk.example. 300 IN A 10.0.0.6
"""


class IXFRNanoNameserver(Server):
    """Nano nameserver that replies with a fixed message given as text."""

    def __init__(self, response_text):
        super().__init__()
        self.response_text = response_text

    def handle(self, request):
        """Return the canned message with its id set to the request's id."""
        try:
            r = dns.message.from_text(self.response_text, one_rr_per_rrset=True)
            r.id = request.message.id
            return r
        except Exception:
            # Deliberately best-effort: a failure here yields no message.
            # NOTE(review): presumably the server then sends no reply;
            # confirm against nanonameserver.
            return None


@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class XfrTests(unittest.TestCase):
    """Zone transfer tests run against local nano nameservers."""

    def test_axfr(self):
        expected = dns.zone.from_text(axfr_zone, origin="example")
        with AXFRNanoNameserver(origin="example") as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], "example", port=ns.tcp_address[1])
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin="example")
        with AXFRNanoNameserver(origin="example", keyring=keyring) as ns:
            xfr = dns.query.xfr(
                ns.tcp_address[0],
                "example",
                port=ns.tcp_address[1],
                keyring=keyring,
                keyname="name",
            )
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_root_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin=".")
        with AXFRNanoNameserver(origin=".", keyring=keyring) as ns:
            xfr = dns.query.xfr(
                ns.tcp_address[0],
                ".",
                port=ns.tcp_address[1],
                keyring=keyring,
                keyname="name",
            )
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_udp(self):
        with self.assertRaises(ValueError):
            with AXFRNanoNameserver(origin="example") as ns:
                xfr = dns.query.xfr(
                    ns.udp_address[0], "example", port=ns.udp_address[1], use_udp=True
                )
                list(xfr)

    def test_axfr_bad_rcode(self):
        with self.assertRaises(dns.query.TransferError):
            # We just use Server here as by default it will refuse.
            with Server() as ns:
                xfr = dns.query.xfr(
                    ns.tcp_address[0], "example", port=ns.tcp_address[1]
                )
                list(xfr)

    def test_axfr_trailing_junk(self):
        # we use the IXFR server here as it returns messages
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(axfr_trailing_junk) as ns:
                xfr = dns.query.xfr(
                    ns.tcp_address[0],
                    "example",
                    dns.rdatatype.AXFR,
                    port=ns.tcp_address[1],
                )
                list(xfr)

    def test_ixfr_tcp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            xfr = dns.query.xfr(
                ns.tcp_address[0],
                "example",
                dns.rdatatype.IXFR,
                port=ns.tcp_address[1],
                serial=2,
                relativize=False,
            )
            messages = list(xfr)
            self.assertEqual(len(messages), 1)
            expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True)
            # The id is copied from the request, so align it before comparing.
            expected.id = messages[0].id
            self.assertEqual(messages[0], expected)

    def test_ixfr_udp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            xfr = dns.query.xfr(
                ns.udp_address[0],
                "example",
                dns.rdatatype.IXFR,
                port=ns.udp_address[1],
                serial=2,
                relativize=False,
                use_udp=True,
            )
            messages = list(xfr)
            self.assertEqual(len(messages), 1)
            expected = dns.message.from_text(ixfr_message, one_rr_per_rrset=True)
            expected.id = messages[0].id
            self.assertEqual(messages[0], expected)

    def test_ixfr_up_to_date(self):
        with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
            xfr = dns.query.xfr(
                ns.tcp_address[0],
                "example",
                dns.rdatatype.IXFR,
                port=ns.tcp_address[1],
                serial=2,
                relativize=False,
            )
            messages = list(xfr)
            self.assertEqual(len(messages), 1)
            expected = dns.message.from_text(
                ixfr_up_to_date_message, one_rr_per_rrset=True
            )
            expected.id = messages[0].id
            self.assertEqual(messages[0], expected)

    def test_ixfr_trailing_junk(self):
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
                xfr = dns.query.xfr(
                    ns.tcp_address[0],
                    "example",
                    dns.rdatatype.IXFR,
                    port=ns.tcp_address[1],
                    serial=2,
                    relativize=False,
                )
                list(xfr)

    def test_ixfr_base_serial_mismatch(self):
        # The canned IXFR starts from serial 2, but the client asks from 1.
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(ixfr_message) as ns:
                xfr = dns.query.xfr(
                    ns.tcp_address[0],
                    "example",
                    dns.rdatatype.IXFR,
                    port=ns.tcp_address[1],
                    serial=1,
                    relativize=False,
                )
                list(xfr)


class TSIGNanoNameserver(Server):
    """Nano nameserver that answers IN A queries with 1.2.3.4."""

    def handle(self, request):
        """Return NOERROR with one A record for IN A queries, else REFUSED."""
        response = dns.message.make_response(request.message)
        response.set_rcode(dns.rcode.REFUSED)
        response.flags |= dns.flags.RA
        try:
            if request.qtype == dns.rdatatype.A and request.qclass == dns.rdataclass.IN:
                rrs = dns.rrset.from_text(request.qname, 300, "IN", "A", "1.2.3.4")
                response.answer.append(rrs)
                response.set_rcode(dns.rcode.NOERROR)
                response.flags |= dns.flags.AA
        except Exception:
            # Deliberately best-effort: on any failure the response built so
            # far is still returned.
            pass
        return response


@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class TsigTests(unittest.TestCase):
    """TSIG-signed query against a local nano nameserver."""

    def test_tsig(self):
        with TSIGNanoNameserver(keyring=keyring) as ns:
            qname = dns.name.from_text("example.com")
            q = dns.message.make_query(qname, "A")
            q.use_tsig(keyring=keyring, keyname="name")
            response = dns.query.udp(q, ns.udp_address[0], port=ns.udp_address[1])
            self.assertTrue(response.had_tsig)
            rrs = response.get_rrset(
                response.answer, qname, dns.rdataclass.IN, dns.rdatatype.A
            )
            self.assertIsNotNone(rrs)
            seen = {rdata.address for rdata in rrs}
            self.assertIn("1.2.3.4", seen)


@unittest.skipIf(sys.platform == "win32", "low level tests do not work on win32")
class LowLevelWaitTests(unittest.TestCase):
    """Tests for ``dns.query._wait_for``."""

    def test_wait_for(self):
        # Create the pair before entering the cleanup scope: if socketpair()
        # fails, its error is reported directly instead of being masked by a
        # NameError raised while closing sockets that were never created.
        left, right = socket.socketpair()
        with left, right:
            # already expired
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(left, True, True, True, 0)
            # simple timeout
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(left, False, False, False, time.time() + 0.05)
            # writable no timeout (not hanging is passing)
            dns.query._wait_for(left, False, True, False, None)


class MiscTests(unittest.TestCase):
    """Tests for ``dns.query._matches_destination``."""

    def test_matches_destination(self):
        self.assertTrue(
            dns.query._matches_destination(
                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1234), True
            )
        )
        self.assertTrue(
            dns.query._matches_destination(
                socket.AF_INET6, ("1::2", 1234), ("0001::2", 1234), True
            )
        )
        self.assertTrue(
            dns.query._matches_destination(
                socket.AF_INET, ("10.0.0.1", 1234), None, True
            )
        )
        self.assertFalse(
            dns.query._matches_destination(
                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.2", 1234), True
            )
        )
        self.assertFalse(
            dns.query._matches_destination(
                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), True
            )
        )
        # With the last argument False, a mismatch raises instead of
        # returning False.
        with self.assertRaises(dns.query.UnexpectedSource):
            dns.query._matches_destination(
                socket.AF_INET, ("10.0.0.1", 1234), ("10.0.0.1", 1235), False
            )