# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""Tests for dns.query.

Some tests send queries to public resolvers and are skipped when the
Internet is not reachable; others talk to a local "nano nameserver" and
are skipped when that helper cannot be imported.
"""

import socket
import sys
import time
import unittest

try:
    import ssl
    have_ssl = True
except Exception:
    have_ssl = False

import dns.exception
import dns.flags
import dns.inet
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.rrset
import dns.tsigkeyring
import dns.zone

# Some tests require the internet to be available to run, so let's
# skip those if it's not there.
_network_available = True
try:
    socket.gethostbyname('dnspython.org')
except socket.gaierror:
    _network_available = False

# Some tests use a "nano nameserver" for testing.  It requires trio
# and threading, so try to import it and if it doesn't work, skip
# those tests.
try:
    from .nanonameserver import Server
    _nanonameserver_available = True
except ImportError:
    _nanonameserver_available = False

    class Server:
        """Placeholder base class so the server subclasses below can
        still be defined when the nano nameserver is unavailable."""

# Probe for IPv4 and IPv6
query_addresses = []
for (af, address) in ((socket.AF_INET, '8.8.8.8'),
                      (socket.AF_INET6, '2001:4860:4860::8888')):
    try:
        with socket.socket(af, socket.SOCK_DGRAM) as s:
            # Connecting a UDP socket is supposed to return ENETUNREACH if
            # no route to the network is present.
            s.connect((address, 53))
        query_addresses.append(address)
    except Exception:
        # Best effort: an address family that cannot be probed is simply
        # left out of the list of addresses to test against.
        pass

keyring = dns.tsigkeyring.from_text({'name' : 'tDz6cfXXGtNivRpQ98hr6A=='})


@unittest.skipIf(not _network_available, "Internet not reachable")
class QueryTests(unittest.TestCase):
    """Queries sent to each address in ``query_addresses``."""

    def _assert_dns_google_addresses(self, response, qname):
        """Assert that *response* answers *qname* with an IN A rrset
        containing both 8.8.8.8 and 8.8.4.4."""
        rrs = response.get_rrset(response.answer, qname,
                                 dns.rdataclass.IN, dns.rdatatype.A)
        self.assertIsNotNone(rrs)
        seen = {rdata.address for rdata in rrs}
        self.assertIn('8.8.8.8', seen)
        self.assertIn('8.8.4.4', seen)

    def testQueryUDP(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.udp(q, address)
            self._assert_dns_google_addresses(response, qname)

    def testQueryUDPWithSocket(self):
        for address in query_addresses:
            with socket.socket(dns.inet.af_for_address(address),
                               socket.SOCK_DGRAM) as s:
                s.setblocking(0)
                qname = dns.name.from_text('dns.google.')
                q = dns.message.make_query(qname, dns.rdatatype.A)
                response = dns.query.udp(q, address, sock=s)
                self._assert_dns_google_addresses(response, qname)

    def testQueryTCP(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.tcp(q, address)
            self._assert_dns_google_addresses(response, qname)

    def testQueryTCPWithSocket(self):
        for address in query_addresses:
            with socket.socket(dns.inet.af_for_address(address),
                               socket.SOCK_STREAM) as s:
                ll = dns.inet.low_level_address_tuple((address, 53))
                s.connect(ll)
                s.setblocking(0)
                qname = dns.name.from_text('dns.google.')
                q = dns.message.make_query(qname, dns.rdatatype.A)
                # The socket is already connected, so no destination is
                # passed.
                response = dns.query.tcp(q, None, sock=s)
                self._assert_dns_google_addresses(response, qname)

    # Guarded like testQueryTLSWithSocket so both TLS tests are skipped
    # together when the ssl module could not be imported.
    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLS(self):
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            response = dns.query.tls(q, address)
            self._assert_dns_google_addresses(response, qname)

    @unittest.skipUnless(have_ssl, "No SSL support")
    def testQueryTLSWithSocket(self):
        for address in query_addresses:
            with socket.socket(dns.inet.af_for_address(address),
                               socket.SOCK_STREAM) as base_s:
                ll = dns.inet.low_level_address_tuple((address, 853))
                base_s.connect(ll)
                ctx = ssl.create_default_context()
                with ctx.wrap_socket(base_s,
                                     server_hostname='dns.google') as s:
                    s.setblocking(0)
                    qname = dns.name.from_text('dns.google.')
                    q = dns.message.make_query(qname, dns.rdatatype.A)
                    response = dns.query.tls(q, None, sock=s)
                    self._assert_dns_google_addresses(response, qname)

    def testQueryUDPFallback(self):
        # The test expects the DNSKEY query for the root to need TCP.
        for address in query_addresses:
            qname = dns.name.from_text('.')
            q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
            (_, tcp) = dns.query.udp_with_fallback(q, address)
            self.assertTrue(tcp)

    def testQueryUDPFallbackWithSocket(self):
        for address in query_addresses:
            af = dns.inet.af_for_address(address)
            with socket.socket(af, socket.SOCK_DGRAM) as udp_s:
                udp_s.setblocking(0)
                with socket.socket(af, socket.SOCK_STREAM) as tcp_s:
                    ll = dns.inet.low_level_address_tuple((address, 53))
                    tcp_s.connect(ll)
                    tcp_s.setblocking(0)
                    qname = dns.name.from_text('.')
                    q = dns.message.make_query(qname, dns.rdatatype.DNSKEY)
                    (_, tcp) = dns.query.udp_with_fallback(q, address,
                                                           udp_sock=udp_s,
                                                           tcp_sock=tcp_s)
                    self.assertTrue(tcp)

    def testQueryUDPFallbackNoFallback(self):
        # The test expects the A query for dns.google to be answered
        # over UDP.
        for address in query_addresses:
            qname = dns.name.from_text('dns.google.')
            q = dns.message.make_query(qname, dns.rdatatype.A)
            (_, tcp) = dns.query.udp_with_fallback(q, address)
            self.assertFalse(tcp)

    def testUDPReceiveQuery(self):
        # Send a query between two loopback sockets and check that the
        # reported peer address is the sender's bound address.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as listener:
            listener.bind(('127.0.0.1', 0))
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sender:
                sender.bind(('127.0.0.1', 0))
                q = dns.message.make_query('dns.google', dns.rdatatype.A)
                dns.query.send_udp(sender, q, listener.getsockname())
                expiration = time.time() + 2
                (q, _, addr) = dns.query.receive_udp(listener,
                                                     expiration=expiration)
                self.assertEqual(addr, sender.getsockname())


# for brevity
_d_and_s = dns.query._destination_and_source


class DestinationAndSourceTests(unittest.TestCase):
    """Tests of dns.query._destination_and_source()."""

    # The IPv4 and IPv6 variants previously shared one method name, so
    # the second definition replaced the first and the IPv4 case never
    # ran.
    def test_af_inferred_from_where_v4(self):
        (af, _, _) = _d_and_s('1.2.3.4', 53, None, 0)
        self.assertEqual(af, socket.AF_INET)

    def test_af_inferred_from_where_v6(self):
        (af, _, _) = _d_and_s('1::2', 53, None, 0)
        self.assertEqual(af, socket.AF_INET6)

    def test_af_inferred_from_source(self):
        (af, _, _) = _d_and_s('https://example/dns-query', 443,
                              '1.2.3.4', 0, False)
        self.assertEqual(af, socket.AF_INET)

    def test_af_mismatch(self):
        with self.assertRaises(ValueError):
            _d_and_s('1::2', 53, '1.2.3.4', 0)

    def test_source_port_but_no_af_inferred(self):
        with self.assertRaises(ValueError):
            _d_and_s('https://example/dns-query', 443, None, 12345, False)

    def test_where_must_be_an_address(self):
        with self.assertRaises(ValueError):
            _d_and_s('not a valid address', 53, '1.2.3.4', 0)

    def test_destination_is_none_of_where_url(self):
        (_, d, _) = _d_and_s('https://example/dns-query', 443, None, 0,
                             False)
        self.assertEqual(d, None)

    def test_v4_wildcard_source_set(self):
        (_, _, s) = _d_and_s('1.2.3.4', 53, None, 12345)
        self.assertEqual(s, ('0.0.0.0', 12345))

    def test_v6_wildcard_source_set(self):
        (_, _, s) = _d_and_s('1::2', 53, None, 12345)
        self.assertEqual(s, ('::', 12345, 0, 0))


class AddressesEqualTestCase(unittest.TestCase):
    """Tests of dns.query._addresses_equal()."""

    def test_v4(self):
        self.assertTrue(dns.query._addresses_equal(socket.AF_INET,
                                                   ('10.0.0.1', 53),
                                                   ('10.0.0.1', 53)))
        self.assertFalse(dns.query._addresses_equal(socket.AF_INET,
                                                    ('10.0.0.1', 53),
                                                    ('10.0.0.2', 53)))

    def test_v6(self):
        # Two different textual spellings of the same IPv6 address are
        # expected to compare equal.
        self.assertTrue(dns.query._addresses_equal(socket.AF_INET6,
                                                   ('1::1', 53),
                                                   ('0001:0000::1', 53)))
        self.assertFalse(dns.query._addresses_equal(socket.AF_INET6,
                                                    ('::1', 53),
                                                    ('::2', 53)))

    def test_mixed(self):
        self.assertFalse(dns.query._addresses_equal(socket.AF_INET,
                                                    ('10.0.0.1', 53),
                                                    ('::2', 53)))


axfr_zone = '''
$ORIGIN example.
$TTL 300
@ SOA ns1 root 1 7200 900 1209600 86400
@ NS ns1
@ NS ns2
ns1 A 10.0.0.1
ns2 A 10.0.0.1
'''


class AXFRNanoNameserver(Server):
    """Nano nameserver that answers every request with ``axfr_zone``
    laid out as a three-message zone transfer."""

    def handle(self, request):
        """Return the list of messages making up the transfer.

        The first and last messages carry only the zone's SOA; the
        middle one carries every other rrset in the zone.
        """
        self.zone = dns.zone.from_text(axfr_zone, origin='example')
        self.origin = self.zone.origin
        items = []
        soa = self.zone.find_rrset(dns.name.empty, dns.rdatatype.SOA)

        # Opening SOA.
        response = dns.message.make_response(request.message)
        response.flags |= dns.flags.AA
        response.answer.append(soa)
        items.append(response)

        # Everything except the SOA at the origin.
        response = dns.message.make_response(request.message)
        response.question = []
        response.flags |= dns.flags.AA
        for (name, rdataset) in self.zone.iterate_rdatasets():
            if rdataset.rdtype == dns.rdatatype.SOA and \
               name == dns.name.empty:
                continue
            rrset = dns.rrset.RRset(name, rdataset.rdclass, rdataset.rdtype,
                                    rdataset.covers)
            rrset.update(rdataset)
            response.answer.append(rrset)
        items.append(response)

        # Closing SOA.
        response = dns.message.make_response(request.message)
        response.question = []
        response.flags |= dns.flags.AA
        response.answer.append(soa)
        items.append(response)
        return items


ixfr_message = '''id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
deleted.example. 300 IN A 10.0.0.1
changed.example. 300 IN A 10.0.0.2
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
changed.example. 300 IN A 10.0.0.4
added.example. 300 IN A 10.0.0.3
example. 300 SOA ns1.example. root.example. 3 7200 900 1209600 86400
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
added2.example. 300 IN A 10.0.0.5
example. 300 IN SOA ns1.example. root.example. 4 7200 900 1209600 86400
'''

ixfr_trailing_junk = ixfr_message + 'junk.example. 300 IN A 10.0.0.6'

ixfr_up_to_date_message = '''id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN IXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 2 7200 900 1209600 86400
'''

axfr_trailing_junk = '''id 12345
opcode QUERY
rcode NOERROR
flags AA
;QUESTION
example. IN AXFR
;ANSWER
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
added.example. 300 IN A 10.0.0.3
added2.example. 300 IN A 10.0.0.5
changed.example. 300 IN A 10.0.0.4
example. 300 IN SOA ns1.example. root.example. 3 7200 900 1209600 86400
junk.example. 300 IN A 10.0.0.6
'''


class IXFRNanoNameserver(Server):
    """Nano nameserver that answers every request with one canned
    message parsed from text."""

    def __init__(self, response_text):
        """Remember *response_text*, the text form of the reply."""
        super().__init__()
        self.response_text = response_text

    def handle(self, request):
        """Return the canned message with its id set to the request's
        id, or None if the message could not be built."""
        try:
            r = dns.message.from_text(self.response_text,
                                      one_rr_per_rrset=True)
            r.id = request.message.id
            return r
        except Exception:
            # Deliberately best effort.
            # NOTE(review): presumably the base Server treats None as
            # "no custom response" -- confirm against nanonameserver.
            return None


@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class XfrTests(unittest.TestCase):
    """Zone transfer tests run against local nano nameservers."""

    def test_axfr(self):
        expected = dns.zone.from_text(axfr_zone, origin='example')
        with AXFRNanoNameserver() as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                port=ns.tcp_address[1])
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_tsig(self):
        expected = dns.zone.from_text(axfr_zone, origin='example')
        with AXFRNanoNameserver(keyring=keyring) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                port=ns.tcp_address[1],
                                keyring=keyring, keyname='name')
            zone = dns.zone.from_xfr(xfr)
            self.assertEqual(zone, expected)

    def test_axfr_udp(self):
        # Requesting an AXFR with use_udp=True is expected to raise
        # ValueError.
        with self.assertRaises(ValueError):
            with AXFRNanoNameserver() as ns:
                xfr = dns.query.xfr(ns.udp_address[0], 'example',
                                    port=ns.udp_address[1], use_udp=True)
                list(xfr)

    def test_axfr_bad_rcode(self):
        with self.assertRaises(dns.query.TransferError):
            # We just use Server here as by default it will refuse.
            with Server() as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    port=ns.tcp_address[1])
                list(xfr)

    def test_axfr_trailing_junk(self):
        # we use the IXFR server here as it returns messages
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(axfr_trailing_junk) as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    dns.rdatatype.AXFR,
                                    port=ns.tcp_address[1])
                list(xfr)

    def test_ixfr_tcp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                dns.rdatatype.IXFR,
                                port=ns.tcp_address[1],
                                serial=2,
                                relativize=False)
            messages = list(xfr)
            self.assertEqual(len(messages), 1)
            expected = dns.message.from_text(ixfr_message,
                                             one_rr_per_rrset=True)
            # The server copies the request id into its reply, so align
            # the ids before comparing.
            expected.id = messages[0].id
            self.assertEqual(messages[0], expected)

    def test_ixfr_udp(self):
        with IXFRNanoNameserver(ixfr_message) as ns:
            xfr = dns.query.xfr(ns.udp_address[0], 'example',
                                dns.rdatatype.IXFR,
                                port=ns.udp_address[1],
                                serial=2,
                                relativize=False, use_udp=True)
            messages = list(xfr)
            self.assertEqual(len(messages), 1)
            expected = dns.message.from_text(ixfr_message,
                                             one_rr_per_rrset=True)
            expected.id = messages[0].id
            self.assertEqual(messages[0], expected)

    def test_ixfr_up_to_date(self):
        with IXFRNanoNameserver(ixfr_up_to_date_message) as ns:
            xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                dns.rdatatype.IXFR,
                                port=ns.tcp_address[1],
                                serial=2,
                                relativize=False)
            messages = list(xfr)
            self.assertEqual(len(messages), 1)
            expected = dns.message.from_text(ixfr_up_to_date_message,
                                             one_rr_per_rrset=True)
            expected.id = messages[0].id
            self.assertEqual(messages[0], expected)

    def test_ixfr_trailing_junk(self):
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(ixfr_trailing_junk) as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    dns.rdatatype.IXFR,
                                    port=ns.tcp_address[1],
                                    serial=2,
                                    relativize=False)
                list(xfr)

    def test_ixfr_base_serial_mismatch(self):
        # The canned IXFR is relative to serial 2, but serial 1 is
        # requested.
        with self.assertRaises(dns.exception.FormError):
            with IXFRNanoNameserver(ixfr_message) as ns:
                xfr = dns.query.xfr(ns.tcp_address[0], 'example',
                                    dns.rdatatype.IXFR,
                                    port=ns.tcp_address[1],
                                    serial=1,
                                    relativize=False)
                list(xfr)


class TSIGNanoNameserver(Server):
    """Nano nameserver that answers IN A questions with 1.2.3.4 and
    refuses everything else."""

    def handle(self, request):
        """Return a response to *request*.

        The response starts out REFUSED and is upgraded to an
        authoritative NOERROR answer only for IN A questions.
        """
        response = dns.message.make_response(request.message)
        response.set_rcode(dns.rcode.REFUSED)
        response.flags |= dns.flags.RA
        try:
            if request.qtype == dns.rdatatype.A and \
               request.qclass == dns.rdataclass.IN:
                rrs = dns.rrset.from_text(request.qname, 300,
                                          'IN', 'A', '1.2.3.4')
                response.answer.append(rrs)
                response.set_rcode(dns.rcode.NOERROR)
                response.flags |= dns.flags.AA
        except Exception:
            # Best effort: on any failure return the response as built
            # so far.
            pass
        return response


@unittest.skipIf(not _nanonameserver_available, "nanonameserver required")
class TsigTests(unittest.TestCase):
    """TSIG-signed query against a local nano nameserver."""

    def test_tsig(self):
        with TSIGNanoNameserver(keyring=keyring) as ns:
            qname = dns.name.from_text('example.com')
            q = dns.message.make_query(qname, 'A')
            q.use_tsig(keyring=keyring, keyname='name')
            response = dns.query.udp(q, ns.udp_address[0],
                                     port=ns.udp_address[1])
            self.assertTrue(response.had_tsig)
            rrs = response.get_rrset(response.answer, qname,
                                     dns.rdataclass.IN, dns.rdatatype.A)
            self.assertIsNotNone(rrs)
            seen = {rdata.address for rdata in rrs}
            self.assertIn('1.2.3.4', seen)


@unittest.skipIf(sys.platform == 'win32',
                 'low level tests do not work on win32')
class LowLevelWaitTests(unittest.TestCase):
    """Tests of dns.query._wait_for()."""

    def test_wait_for(self):
        # Create the pair before arming any cleanup, so that a failure
        # in socketpair() is reported as itself rather than being masked
        # by a NameError while closing sockets that were never created.
        (left, right) = socket.socketpair()
        with left, right:
            # already expired
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(left, True, True, True, 0)
            # simple timeout
            with self.assertRaises(dns.exception.Timeout):
                dns.query._wait_for(left, False, False, False,
                                    time.time() + 0.05)
            # writable no timeout (not hanging is passing)
            dns.query._wait_for(left, False, True, False, None)