summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@nominum.com>2011-07-13 12:54:54 -0700
committerBob Halley <halley@nominum.com>2011-07-13 12:54:54 -0700
commit37884a38dd59f643862262df386c8971aeb0f70f (patch)
treecadea094dd023a0591b1e68bfc4b8ce8b7f74b46
parentc1bc09d236fe7784a303cd601a8a973d1758e3a6 (diff)
downloaddnspython-37884a38dd59f643862262df386c8971aeb0f70f.tar.gz
add dns.resolver.override_system_resolver() and dns.resolver.restore_system_resolver()
-rw-r--r--ChangeLog12
-rw-r--r--dns/resolver.py235
2 files changed, 247 insertions, 0 deletions
diff --git a/ChangeLog b/ChangeLog
index 08d93d8..df20c87 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,15 @@
+2011-07-13 Bob Halley <halley@dnspython.org>
+
+ * dns/resolver.py: dns.resolver.override_system_resolver()
+ overrides the socket module's versions of getaddrinfo(),
+ getnameinfo(), getfqdn(), gethostbyname(), gethostbyname_ex() and
+ gethostbyaddr() with an implementation which uses a dnspython stub
+ resolver instead of the system's stub resolver. This can be
+ useful in testing situations where you want to control the
+ resolution behavior of python code without having to change the
+ system's resolver settings (e.g. /etc/resolv.conf).
+ dns.resolver.restore_system_resolver() undoes the change.
+
2011-07-08 Bob Halley <halley@dnspython.org>
* dns/ipv4.py: dnspython now provides its own, stricter, versions
diff --git a/dns/resolver.py b/dns/resolver.py
index e711dee..ab36994 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -23,12 +23,15 @@ import sys
import time
import dns.exception
+import dns.ipv4
+import dns.ipv6
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
+import dns.reversename
if sys.platform == 'win32':
import _winreg
@@ -800,3 +803,235 @@ def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
name = name.parent()
except dns.name.NoParent:
raise NoRootSOA
+
+#
+# Support for overriding the system resolver for all python code in the
+# running process.
+#
+
+_protocols_for_socktype = {
+ socket.SOCK_DGRAM : [socket.SOL_UDP],
+ socket.SOCK_STREAM : [socket.SOL_TCP],
+ }
+
+_resolver = None
+_original_getaddrinfo = socket.getaddrinfo
+_original_getnameinfo = socket.getnameinfo
+_original_getfqdn = socket.getfqdn
+_original_gethostbyname = socket.gethostbyname
+_original_gethostbyname_ex = socket.gethostbyname_ex
+_original_gethostbyaddr = socket.gethostbyaddr
+
+def _getaddrinfo(host=None, service=None, family=socket.AF_UNSPEC, socktype=0,
+ proto=0, flags=0):
+ if flags & (socket.AI_ADDRCONFIG|socket.AI_V4MAPPED) != 0:
+ raise NotImplementedError
+ if host is None and service is None:
+ raise socket.gaierror(socket.EAI_NONAME)
+ v6addrs = []
+ v4addrs = []
+ canonical_name = None
+ try:
+ # Is host None or a V6 address literal?
+ if host is None:
+ canonical_name = 'localhost'
+ if flags & socket.AI_PASSIVE != 0:
+ v6addrs.append('::')
+ v4addrs.append('0.0.0.0')
+ else:
+ v6addrs.append('::1')
+ v4addrs.append('127.0.0.1')
+ else:
+ parts = host.split('%')
+ if len(parts) == 2:
+ ahost = parts[0]
+ else:
+ ahost = host
+ addr = dns.ipv6.inet_aton(ahost)
+ v6addrs.append(host)
+ canonical_name = host
+ except:
+ try:
+ # Is it a V4 address literal?
+ addr = dns.ipv4.inet_aton(host)
+ v4addrs.append(host)
+ canonical_name = host
+ except:
+ if flags & socket.AI_NUMERICHOST == 0:
+ try:
+ qname = None
+ if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+ v6 = _resolver.query(host, dns.rdatatype.AAAA,
+ raise_on_no_answer=False)
+ # Note that setting host ensures we query the same name
+ # for A as we did for AAAA.
+ host = v6.qname
+ canonical_name = v6.canonical_name.to_text(True)
+ if v6.rrset is not None:
+ for rdata in v6.rrset:
+ v6addrs.append(rdata.address)
+ if family == socket.AF_INET or family == socket.AF_UNSPEC:
+ v4 = _resolver.query(host, dns.rdatatype.A,
+ raise_on_no_answer=False)
+ host = v4.qname
+ canonical_name = v4.canonical_name.to_text(True)
+ if v4.rrset is not None:
+ for rdata in v4.rrset:
+ v4addrs.append(rdata.address)
+ except dns.resolver.NXDOMAIN:
+ raise socket.gaierror(socket.EAI_NONAME)
+ except:
+ raise socket.gaierror(socket.EAI_SYSTEM)
+ port = None
+ try:
+ # Is it a port literal?
+ if service is None:
+ port = 0
+ else:
+ port = int(service)
+ except:
+ if flags & socket.AI_NUMERICSERV == 0:
+ try:
+ port = socket.getservbyname(service)
+ except:
+ pass
+ if port is None:
+ raise socket.gaierror(socket.EAI_NONAME)
+ tuples = []
+ if socktype == 0:
+ socktypes = [socket.SOCK_DGRAM, socket.SOCK_STREAM]
+ else:
+ socktypes = [socktype]
+ if flags & socket.AI_CANONNAME != 0:
+ cname = canonical_name
+ else:
+ cname = ''
+ if family == socket.AF_INET6 or family == socket.AF_UNSPEC:
+ for addr in v6addrs:
+ for socktype in socktypes:
+ for proto in _protocols_for_socktype[socktype]:
+ tuples.append((socket.AF_INET6, socktype, proto,
+ cname, (addr, port, 0, 0)))
+ if family == socket.AF_INET or family == socket.AF_UNSPEC:
+ for addr in v4addrs:
+ for socktype in socktypes:
+ for proto in _protocols_for_socktype[socktype]:
+ tuples.append((socket.AF_INET, socktype, proto,
+ cname, (addr, port)))
+ if len(tuples) == 0:
+ raise socket.gaierror(socket.EAI_NONAME)
+ return tuples
+
+def _getnameinfo(sockaddr, flags=0):
+ host = sockaddr[0]
+ port = sockaddr[1]
+ if len(sockaddr) == 4:
+ scope = sockaddr[3]
+ family = socket.AF_INET6
+ else:
+ scope = None
+ family = socket.AF_INET
+ tuples = _getaddrinfo(host, port, family, socket.SOCK_STREAM,
+ socket.SOL_TCP, 0)
+ if len(tuples) > 1:
+ raise socket.error('sockaddr resolved to multiple addresses')
+ addr = tuples[0][4][0]
+ if flags & socket.NI_DGRAM:
+ pname = 'udp'
+ else:
+ pname = 'tcp'
+ qname = dns.reversename.from_address(addr)
+ if flags & socket.NI_NUMERICHOST == 0:
+ try:
+ answer = _resolver.query(qname, 'PTR')
+ hostname = answer.rrset[0].target.to_text(True)
+ except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer):
+ if flags & socket.NI_NAMEREQD:
+ raise socket.gaierror(socket.EAI_NONAME)
+ hostname = addr
+ if scope is not None:
+ hostname += '%' + str(scope)
+ else:
+ hostname = addr
+ if scope is not None:
+ hostname += '%' + str(scope)
+ if flags & socket.NI_NUMERICSERV:
+ service = str(port)
+ else:
+ service = socket.getservbyport(port, pname)
+ return (hostname, service)
+
+def _getfqdn(name=None):
+ if name is None:
+ name = socket.gethostname()
+ return _getnameinfo(_getaddrinfo(name, 80)[0][4])[0]
+
+def _gethostbyname(name):
+ return _gethostbyname_ex(name)[2][0]
+
+def _gethostbyname_ex(name):
+ aliases = []
+ addresses = []
+ tuples = _getaddrinfo(name, 0, socket.AF_INET, socket.SOCK_STREAM,
+ socket.SOL_TCP, socket.AI_CANONNAME)
+ canonical = tuples[0][3]
+ for item in tuples:
+ addresses.append(item[4][0])
+ # XXX we just ignore aliases
+ return (canonical, aliases, addresses)
+
+def _gethostbyaddr(ip):
+ try:
+ addr = dns.ipv6.inet_aton(ip)
+ sockaddr = (ip, 80, 0, 0)
+ family = socket.AF_INET6
+ except:
+ sockaddr = (ip, 80)
+ family = socket.AF_INET
+ (name, port) = _getnameinfo(sockaddr, socket.NI_NAMEREQD)
+ aliases = []
+ addresses = []
+ tuples = _getaddrinfo(name, 0, family, socket.SOCK_STREAM, socket.SOL_TCP,
+ socket.AI_CANONNAME)
+ canonical = tuples[0][3]
+ for item in tuples:
+ addresses.append(item[4][0])
+ # XXX we just ignore aliases
+ return (canonical, aliases, addresses)
+
+def override_system_resolver(resolver=None):
+ """Override the system resolver routines in the socket module with
+ versions which use dnspython's resolver.
+
+ This can be useful in testing situations where you want to control
+ the resolution behavior of python code without having to change
+ the system's resolver settings (e.g. /etc/resolv.conf).
+
+ The resolver to use may be specified; if it's not, the default
+ resolver will be used.
+
+ @param resolver: the resolver to use
+ @type resolver: dns.resolver.Resolver object or None
+ """
+ if resolver is None:
+ resolver = get_default_resolver()
+ global _resolver
+ _resolver = resolver
+ socket.getaddrinfo = _getaddrinfo
+ socket.getnameinfo = _getnameinfo
+ socket.getfqdn = _getfqdn
+ socket.gethostbyname = _gethostbyname
+ socket.gethostbyname_ex = _gethostbyname_ex
+ socket.gethostbyaddr = _gethostbyaddr
+
+def restore_system_resolver():
+ """Undo the effects of override_system_resolver().
+ """
+ global _resolver
+ _resolver = None
+ socket.getaddrinfo = _original_getaddrinfo
+ socket.getnameinfo = _original_getnameinfo
+ socket.getfqdn = _original_getfqdn
+ socket.gethostbyname = _original_gethostbyname
+ socket.gethostbyname_ex = _original_gethostbyname_ex
+ socket.gethostbyaddr = _original_gethostbyaddr