diff options
author | Floris Bruynooghe <flub@devork.be> | 2013-11-07 03:06:34 +0000 |
---|---|---|
committer | Sergey Shepelev <temotor@gmail.com> | 2015-02-20 01:00:53 +0300 |
commit | a6ce444265d36fb23361809d40da73caf4864487 (patch) | |
tree | 3756347f32e92fa2af861c2975eccc1aa623170c | |
parent | 184dce2104b5a4ea24a8482013655140442fadc9 (diff) | |
download | eventlet-a6ce444265d36fb23361809d40da73caf4864487.tar.gz |
greendns: IPv6 support, improved handling of /etc/hostsbb-40-greendns-ipv6
https://github.com/eventlet/eventlet/issues/8
https://bitbucket.org/eventlet/eventlet/issue/105/name-resolution-needs-to-support-ipv6
https://bitbucket.org/eventlet/eventlet/pull-request/40/improve-asynchronous-ipv6-address
-rw-r--r-- | eventlet/support/greendns.py | 546 | ||||
-rw-r--r-- | tests/__init__.py | 16 | ||||
-rw-r--r-- | tests/greendns_test.py | 798 |
3 files changed, 1198 insertions, 162 deletions
diff --git a/eventlet/support/greendns.py b/eventlet/support/greendns.py index c357866..07d13c5 100644 --- a/eventlet/support/greendns.py +++ b/eventlet/support/greendns.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python -''' - greendns - non-blocking DNS support for Eventlet +'''greendns - non-blocking DNS support for Eventlet ''' # Portions of this code taken from the gogreen project: @@ -35,187 +33,447 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. import struct -import sys from eventlet import patcher from eventlet.green import _socket_nodns -from eventlet.green import time +from eventlet.green import os from eventlet.green import select +from eventlet.green import time +from eventlet.support import six + dns = patcher.import_patched('dns', socket=_socket_nodns, time=time, select=select) for pkg in ('dns.query', 'dns.exception', 'dns.inet', 'dns.message', - 'dns.rdatatype', 'dns.resolver', 'dns.reversename'): - setattr(dns, pkg.split('.')[1], patcher.import_patched( - pkg, - socket=_socket_nodns, - time=time, - select=select)) + 'dns.rdatatype', 'dns.resolver', 'dns.reversename', + 'dns.rdataclass', 'dns.name', 'dns.rrset', 'dns.rdtypes', + 'dns.ipv4', 'dns.ipv6'): + setattr(dns, pkg.split('.')[1], + patcher.import_patched(pkg, socket=_socket_nodns, + time=time, select=select)) +for pkg in ['dns.rdtypes.IN', 'dns.rdtypes.ANY']: + setattr(dns.rdtypes, pkg.split('.')[-1], + patcher.import_patched(pkg, socket=_socket_nodns, + time=time, select=select)) +for pkg in ['dns.rdtypes.IN.A', 'dns.rdtypes.IN.AAAA']: + setattr(dns.rdtypes.IN, pkg.split('.')[-1], + patcher.import_patched(pkg, socket=_socket_nodns, + time=time, select=select)) +for pkg in ['dns.rdtypes.ANY.CNAME']: + setattr(dns.rdtypes.ANY, pkg.split('.')[-1], + patcher.import_patched(pkg, socket=_socket_nodns, + time=time, select=select)) + socket = _socket_nodns DNS_QUERY_TIMEOUT = 10.0 +HOSTS_TTL = 10.0 +EAI_EAGAIN_ERROR = socket.gaierror(socket.EAI_AGAIN, 'Lookup timed out') +EAI_NODATA_ERROR = socket.gaierror(socket.EAI_NODATA, 'No address associated with hostname') +EAI_NONAME_ERROR = socket.gaierror(socket.EAI_NONAME, 'Name or service not known') -# -# Resolver instance used to perfrom DNS lookups. -# -class FakeAnswer(list): - expiration = 0 +def is_ipv4_addr(host): + """Return True if host is a valid IPv4 address""" + if not isinstance(host, six.string_types): + return False + try: + dns.ipv4.inet_aton(host) + except dns.exception.SyntaxError: + return False + else: + return True -class FakeRecord(object): - pass +def is_ipv6_addr(host): + """Return True if host is a valid IPv6 address""" + if not isinstance(host, six.string_types): + return False + try: + dns.ipv6.inet_aton(host) + except dns.exception.SyntaxError: + return False + else: + return True -class ResolverProxy(object): - def __init__(self, *args, **kwargs): - self._resolver = None - self._filename = kwargs.get('filename', '/etc/resolv.conf') - self._hosts = {} - if kwargs.pop('dev', False): - self._load_etc_hosts() - - def _load_etc_hosts(self): + +def is_ip_addr(host): + """Return True if host is a valid IPv4 or IPv6 address""" + return is_ipv4_addr(host) or is_ipv6_addr(host) + + +class HostsAnswer(dns.resolver.Answer): + """Answer class for HostsResolver object""" + + def __init__(self, qname, rdtype, rdclass, rrset, raise_on_no_answer=True): + """Create a new answer + + :qname: A dns.name.Name instance of the query name + :rdtype: The rdatatype of the query + :rdclass: The rdataclass of the query + :rrset: The dns.rrset.RRset with the response, must have ttl attribute + :raise_on_no_answer: Whether to raise dns.resolver.NoAnswer if no + answer. + """ + self.response = None + self.qname = qname + self.rdtype = rdtype + self.rdclass = rdclass + self.canonical_name = qname + if not rrset and raise_on_no_answer: + raise dns.resolver.NoAnswer() + self.rrset = rrset + self.expiration = (time.time() + + rrset.ttl if hasattr(rrset, 'ttl') else 0) + + +class HostsResolver(object): + """Class to parse the hosts file + + Attributes + ---------- + + :fname: The filename of the hosts file in use. + :interval: The time between checking for hosts file modification + """ + + def __init__(self, fname=None, interval=HOSTS_TTL): + self._v4 = {} # name -> ipv4 + self._v6 = {} # name -> ipv6 + self._aliases = {} # name -> cannonical_name + self.interval = interval + self.fname = fname + if fname is None: + if os.name == 'posix': + self.fname = '/etc/hosts' + elif os.name == 'nt': + self.fname = os.path.expandvars( + r'%SystemRoot%\system32\drivers\etc\hosts') + self._last_load = 0 + if self.fname: + self._load() + + def _readlines(self): + """Read the contents of the hosts file + + Return list of lines, comment lines and empty lines are + excluded. + + Note that this performs disk I/O so can be blocking. + """ + lines = [] try: - fd = open('/etc/hosts', 'r') - contents = fd.read() - fd.close() + with open(self.fname, 'rU') as fp: + for line in fp: + line = line.strip() + if line and line[0] != '#': + lines.append(line) except (IOError, OSError): - return - contents = [line for line in contents.split('\n') if line and not line[0] == '#'] - for line in contents: - line = line.replace('\t', ' ') - parts = line.split(' ') - parts = [p for p in parts if p] - if not len(parts): + pass + return lines + + def _load(self): + """Load hosts file + + This will unconditionally (re)load the data from the hosts + file. + """ + lines = self._readlines() + self._v4.clear() + self._v6.clear() + self._aliases.clear() + for line in lines: + parts = line.split() + if len(parts) < 2: continue - ip = parts[0] - for part in parts[1:]: - self._hosts[part] = ip + ip = parts.pop(0) + if is_ipv4_addr(ip): + ipmap = self._v4 + elif is_ipv6_addr(ip): + if ip.startswith('fe80'): + # Do not use link-local addresses, OSX stores these here + continue + ipmap = self._v6 + else: + continue + cname = parts.pop(0) + ipmap[cname] = ip + for alias in parts: + ipmap[alias] = ip + self._aliases[alias] = cname + self._last_load = time.time() + + def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True): + """Query the hosts file + + The known rdtypes are dns.rdatatype.A, dns.rdatatype.AAAA and + dns.rdatatype.CNAME. + + The ``rdclass`` parameter must be dns.rdataclass.IN while the + ``tcp`` and ``source`` parameters are ignored. + + Return a HostAnswer instance or raise a dns.resolver.NoAnswer + exception. + """ + now = time.time() + if self._last_load + self.interval < now: + self._load() + rdclass = dns.rdataclass.IN + if isinstance(qname, six.string_types): + name = qname + qname = dns.name.from_text(qname) + else: + name = str(qname) + rrset = dns.rrset.RRset(qname, rdclass, rdtype) + rrset.ttl = self._last_load + self.interval - now + if rdclass == dns.rdataclass.IN and rdtype == dns.rdatatype.A: + addr = self._v4.get(name) + if not addr and qname.is_absolute(): + addr = self._v4.get(name[:-1]) + if addr: + rrset.add(dns.rdtypes.IN.A.A(rdclass, rdtype, addr)) + elif rdclass == dns.rdataclass.IN and rdtype == dns.rdatatype.AAAA: + addr = self._v6.get(name) + if not addr and qname.is_absolute(): + addr = self._v6.get(name[:-1]) + if addr: + rrset.add(dns.rdtypes.IN.AAAA.AAAA(rdclass, rdtype, addr)) + elif rdclass == dns.rdataclass.IN and rdtype == dns.rdatatype.CNAME: + cname = self._aliases.get(name) + if not cname and qname.is_absolute(): + cname = self._aliases.get(name[:-1]) + if cname: + rrset.add(dns.rdtypes.ANY.CNAME.CNAME( + rdclass, rdtype, dns.name.from_text(cname))) + return HostsAnswer(qname, rdtype, rdclass, rrset, raise_on_no_answer) + + def getaliases(self, hostname): + """Return a list of all the aliases of a given cname""" + # Due to the way store aliases this is a bit inefficient, this + # clearly was an afterthought. But this is only used by + # gethostbyname_ex so it's probably fine. + aliases = [] + if hostname in self._aliases: + cannon = self._aliases[hostname] + else: + cannon = hostname + aliases.append(cannon) + for alias, cname in self._aliases.iteritems(): + if cannon == cname: + aliases.append(alias) + aliases.remove(hostname) + return aliases - def clear(self): - self._resolver = None - - def query(self, *args, **kwargs): - if self._resolver is None: - self._resolver = dns.resolver.Resolver(filename=self._filename) - self._resolver.cache = dns.resolver.Cache() - - query = args[0] - if query is None: - args = list(args) - query = args[0] = '0.0.0.0' - if self._hosts and self._hosts.get(query): - answer = FakeAnswer() - record = FakeRecord() - setattr(record, 'address', self._hosts[query]) - answer.append(record) - return answer - return self._resolver.query(*args, **kwargs) -# -# cache -# -resolver = ResolverProxy(dev=True) +class ResolverProxy(object): + """Resolver class which can also use /etc/hosts -def resolve(name): - error = None - rrset = None + Initialise with a HostsResolver instance in order for it to also + use the hosts file. + """ - if rrset is None or time.time() > rrset.expiration: - try: - rrset = resolver.query(name) - except dns.exception.Timeout: - error = (socket.EAI_AGAIN, 'Lookup timed out') - except dns.exception.DNSException: - error = (socket.EAI_NODATA, 'No address associated with hostname') - else: - pass - # responses.insert(name, rrset) + def __init__(self, hosts_resolver=None, filename='/etc/resolv.conf'): + """Initialise the resolver proxy + + :param hosts_resolver: An instance of HostsResolver to use. + + :param filename: The filename containing the resolver + configuration. The default value is correct for both UNIX + and Windows, on Windows it will result in the configuration + being read from the Windows registry. + """ + self._hosts = hosts_resolver + self._filename = filename + self._resolver = dns.resolver.Resolver(filename=self._filename) + self._resolver.cache = dns.resolver.LRUCache() - if error: - if rrset is None: - raise socket.gaierror(error) + def clear(self): + self._resolver = dns.resolver.Resolver(filename=self._filename) + self._resolver.cache = dns.resolver.Cache() + + def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, + tcp=False, source=None, raise_on_no_answer=True): + """Query the resolver, using /etc/hosts if enabled""" + if qname is None: + qname = '0.0.0.0' + if rdclass == dns.rdataclass.IN and self._hosts: + try: + return self._hosts.query(qname, rdtype) + except dns.resolver.NoAnswer: + pass + return self._resolver.query(qname, rdtype, rdclass, + tcp, source, raise_on_no_answer) + + def getaliases(self, hostname): + """Return a list of all the aliases of a given hostname""" + if self._hosts: + aliases = self._hosts.getaliases(hostname) else: - sys.stderr.write('DNS error: %r %r\n' % (name, error)) - return rrset + aliases = [] + while True: + try: + ans = self._resolver.query(hostname, dns.rdatatype.CNAME) + except (dns.resolver.NoAnswer, dns.resolver.NXDOMAIN): + break + else: + aliases.extend(str(rr.target) for rr in ans.rrset) + hostname = ans[0].target + return aliases -# -# methods -# -def getaliases(host): - """Checks for aliases of the given hostname (cname records) - returns a list of alias targets - will return an empty list if no aliases +resolver = ResolverProxy(hosts_resolver=HostsResolver()) + + +def resolve(name, family=socket.AF_INET, raises=True): + """Resolve a name for a given family using the global resolver proxy + + This method is called by the global getaddrinfo() function. + + Return a dns.resolver.Answer instance. If there is no answer it's + rrset will be emtpy. """ - cnames = [] - error = None + if family == socket.AF_INET: + rdtype = dns.rdatatype.A + elif family == socket.AF_INET6: + rdtype = dns.rdatatype.AAAA + else: + raise socket.gaierror(socket.EAI_FAMILY, + 'Address family not supported') + try: + try: + return resolver.query(name, rdtype, raise_on_no_answer=raises) + except dns.resolver.NXDOMAIN: + if not raises: + return HostsAnswer(dns.name.Name(name), + rdtype, dns.rdataclass.IN, None, False) + raise + except dns.exception.Timeout: + raise EAI_EAGAIN_ERROR + except dns.exception.DNSException: + raise EAI_NODATA_ERROR + +def resolve_cname(host): + """Return the canonical name of a hostname""" try: - answers = dns.resolver.query(host, 'cname') + ans = resolver.query(host, dns.rdatatype.CNAME) + except dns.resolver.NoAnswer: + return host except dns.exception.Timeout: - error = (socket.EAI_AGAIN, 'Lookup timed out') + raise EAI_EAGAIN_ERROR except dns.exception.DNSException: - error = (socket.EAI_NODATA, 'No address associated with hostname') + raise EAI_NODATA_ERROR else: - for record in answers: - cnames.append(str(answers[0].target)) + return str(ans[0].target) + - if error: - sys.stderr.write('DNS error: %r %r\n' % (host, error)) +def getaliases(host): + """Return a list of for aliases for the given hostname - return cnames + This method does translate the dnspython exceptions into + socket.gaierror exceptions. If no aliases are available an empty + list will be returned. + """ + try: + return resolver.getaliases(host) + except dns.exception.Timeout: + raise EAI_EAGAIN_ERROR + except dns.exception.DNSException: + raise EAI_NODATA_ERROR -def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0): - """Replacement for Python's socket.getaddrinfo. +def _getaddrinfo_lookup(host, family, flags): + """Resolve a hostname to a list of addresses - Currently only supports IPv4. At present, flags are not - implemented. + Helper function for getaddrinfo. """ - socktype = socktype or socket.SOCK_STREAM + if flags & socket.AI_NUMERICHOST: + raise EAI_NONAME_ERROR + addrs = [] + if family == socket.AF_UNSPEC: + for qfamily in [socket.AF_INET6, socket.AF_INET]: + answer = resolve(host, qfamily, 0) + if answer.rrset: + addrs.extend([rr.address for rr in answer.rrset]) + elif family == socket.AF_INET6 and flags & socket.AI_V4MAPPED: + answer = resolve(host, socket.AF_INET6, 0) + if answer.rrset: + addrs = [rr.address for rr in answer.rrset] + if not addrs or flags & socket.AI_ALL: + answer = resolve(host, socket.AF_INET, 0) + if answer.rrset: + addrs = ['::ffff:' + rr.address for rr in answer.rrset] + else: + answer = resolve(host, family, 0) + if answer.rrset: + addrs = [rr.address for rr in answer.rrset] + return str(answer.qname), addrs - if is_ipv4_addr(host): - return [(socket.AF_INET, socktype, proto, '', (host, port))] - rrset = resolve(host) - value = [] +def getaddrinfo(host, port, family=0, socktype=0, proto=0, flags=0): + """Replacement for Python's socket.getaddrinfo - for rr in rrset: - value.append((socket.AF_INET, socktype, proto, '', (rr.address, port))) - return value + This does the A and AAAA lookups asynchronously after which it + calls the OS' getaddrinfo(3) using the AI_NUMERICHOST flag. This + flag ensures getaddrinfo(3) does not use the network itself and + allows us to respect all the other arguments like the native OS. + """ + if isinstance(host, six.string_types): + host = host.encode('idna') + if host is not None and not is_ip_addr(host): + qname, addrs = _getaddrinfo_lookup(host, family, flags) + else: + qname = host + addrs = [host] + aiflags = (flags | socket.AI_NUMERICHOST) & (0xffff ^ socket.AI_CANONNAME) + res = [] + err = None + for addr in addrs: + try: + ai = socket.getaddrinfo(addr, port, family, + socktype, proto, aiflags) + except socket.error as e: + if flags & socket.AI_ADDRCONFIG: + err = e + continue + raise + res.extend(ai) + if not res: + if err: + raise err + raise socket.gaierror(socket.EAI_NONAME, 'No address found') + if flags & socket.AI_CANONNAME: + if not is_ip_addr(qname): + qname = resolve_cname(qname).decode('idna') + ai = res[0] + res[0] = (ai[0], ai[1], ai[2], qname, ai[4]) + return res def gethostbyname(hostname): - """Replacement for Python's socket.gethostbyname. - - Currently only supports IPv4. - """ + """Replacement for Python's socket.gethostbyname""" if is_ipv4_addr(hostname): return hostname - rrset = resolve(hostname) return rrset[0].address def gethostbyname_ex(hostname): - """Replacement for Python's socket.gethostbyname_ex. - - Currently only supports IPv4. - """ + """Replacement for Python's socket.gethostbyname_ex""" if is_ipv4_addr(hostname): return (hostname, [], [hostname]) - - rrset = resolve(hostname) - addrs = [] - - for rr in rrset: - addrs.append(rr.address) - return (hostname, [], addrs) + ans = resolve(hostname) + aliases = getaliases(hostname) + addrs = [rr.address for rr in ans.rrset] + qname = str(ans.qname) + if qname[-1] == '.': + qname = qname[:-1] + return (qname, aliases, addrs) def getnameinfo(sockaddr, flags): @@ -232,12 +490,11 @@ def getnameinfo(sockaddr, flags): raise TypeError('getnameinfo() argument 1 must be a tuple') else: # must be ipv6 sockaddr, pretending we don't know how to resolve it - raise socket.gaierror(-2, 'name or service not known') + raise EAI_NONAME_ERROR if (flags & socket.NI_NAMEREQD) and (flags & socket.NI_NUMERICHOST): # Conflicting flags. Punt. - raise socket.gaierror( - (socket.EAI_NONAME, 'Name or service not known')) + raise EAI_NONAME_ERROR if is_ipv4_addr(host): try: @@ -248,11 +505,10 @@ def getnameinfo(sockaddr, flags): host = rrset[0].target.to_text(omit_final_dot=True) except dns.exception.Timeout: if flags & socket.NI_NAMEREQD: - raise socket.gaierror((socket.EAI_AGAIN, 'Lookup timed out')) + raise EAI_EAGAIN_ERROR except dns.exception.DNSException: if flags & socket.NI_NAMEREQD: - raise socket.gaierror( - (socket.EAI_NONAME, 'Name or service not known')) + raise EAI_NONAME_ERROR else: try: rrset = resolver.query(host) @@ -261,32 +517,18 @@ def getnameinfo(sockaddr, flags): if flags & socket.NI_NUMERICHOST: host = rrset[0].address except dns.exception.Timeout: - raise socket.gaierror((socket.EAI_AGAIN, 'Lookup timed out')) + raise EAI_EAGAIN_ERROR except dns.exception.DNSException: raise socket.gaierror( (socket.EAI_NODATA, 'No address associated with hostname')) - if not (flags & socket.NI_NUMERICSERV): - proto = (flags & socket.NI_DGRAM) and 'udp' or 'tcp' - port = socket.getservbyport(port, proto) + if not (flags & socket.NI_NUMERICSERV): + proto = (flags & socket.NI_DGRAM) and 'udp' or 'tcp' + port = socket.getservbyport(port, proto) return (host, port) -def is_ipv4_addr(host): - """is_ipv4_addr returns true if host is a valid IPv4 address in - dotted quad notation. - """ - try: - d1, d2, d3, d4 = map(int, host.split('.')) - except (ValueError, AttributeError): - return False - - if 0 <= d1 <= 255 and 0 <= d2 <= 255 and 0 <= d3 <= 255 and 0 <= d4 <= 255: - return True - return False - - def _net_read(sock, count, expiration): """coro friendly replacement for dns.query._net_write Read the specified number of bytes from sock. Keep trying until we @@ -326,8 +568,8 @@ def _net_write(sock, data, expiration): raise dns.exception.Timeout -def udp(q, where, timeout=DNS_QUERY_TIMEOUT, port=53, af=None, source=None, - source_port=0, ignore_unexpected=False): +def udp(q, where, timeout=DNS_QUERY_TIMEOUT, port=53, + af=None, source=None, source_port=0, ignore_unexpected=False): """coro friendly replacement for dns.query.udp Return the response obtained after sending a query via UDP. diff --git a/tests/__init__.py b/tests/__init__.py index 9b96e7c..f2e7e4e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,6 +1,7 @@ # package is named tests, not test, so it won't be confused with test in stdlib from __future__ import print_function +import contextlib import errno import gc import os @@ -22,6 +23,21 @@ from eventlet import tpool main = unittest.main +@contextlib.contextmanager +def assert_raises(exc_type): + try: + yield + except exc_type: + pass + else: + name = str(exc_type) + try: + name = exc_type.__name__ + except AttributeError: + pass + assert False, 'Expected exception {0}'.format(name) + + def skipped(func): """ Decorator that marks a function as skipped. Uses nose's SkipTest exception if installed. Without nose, this will count skipped tests as passing tests.""" diff --git a/tests/greendns_test.py b/tests/greendns_test.py index 7a64b1e..c921951 100644 --- a/tests/greendns_test.py +++ b/tests/greendns_test.py @@ -1,13 +1,791 @@ -from nose.plugins.skip import SkipTest +"""Tests for the eventlet.support.greendns module""" +import os +import socket +import tempfile +import time -def test_greendns_getnameinfo_resolve_port(): - try: - from eventlet.support import greendns - except ImportError: - raise SkipTest('greendns requires package dnspython') +from tests import assert_raises, mock, skip_unless, LimitedTestCase - # https://bitbucket.org/eventlet/eventlet/issue/152 - _, port1 = greendns.getnameinfo(('127.0.0.1', 80), 0) - _, port2 = greendns.getnameinfo(('localhost', 80), 0) - assert port1 == port2 == 'http' +try: + import dns.rdatatype + import dns.rdtypes.IN.A + import dns.rdtypes.IN.AAAA + import dns.resolver + import dns.rrset + from eventlet.support import greendns + greendns_available = True +except ImportError: + greendns_available = False + greendns = mock.Mock() + + +def greendns_requirement(_f): + """We want to skip tests if greendns is not installed. + """ + return greendns_available + + +class TestHostsResolver(LimitedTestCase): + + def _make_host_resolver(self): + """Returns a HostResolver instance + + The hosts file will be empty but accessible as a py.path.local + instance using the ``hosts`` attribute. + """ + hosts = tempfile.NamedTemporaryFile() + hr = greendns.HostsResolver(fname=hosts.name) + hr.hosts = hosts + hr._last_stat = 0 + return hr + + @skip_unless(greendns_requirement) + def test_default_fname(self): + hr = greendns.HostsResolver() + assert os.path.exists(hr.fname) + + @skip_unless(greendns_requirement) + def test_readlines_lines(self): + hr = self._make_host_resolver() + hr.hosts.write('line0\n') + hr.hosts.flush() + assert hr._readlines() == ['line0'] + hr._last_stat = 0 + hr.hosts.write('line1\n') + hr.hosts.flush() + assert hr._readlines() == ['line0', 'line1'] + hr._last_stat = 0 + hr.hosts.write('#comment0\nline0\n #comment1\nline1') + assert hr._readlines() == ['line0', 'line1'] + + @skip_unless(greendns_requirement) + def test_readlines_missing_file(self): + hr = self._make_host_resolver() + hr.hosts.close() + hr._last_stat = 0 + assert hr._readlines() == [] + + @skip_unless(greendns_requirement) + def test_load_no_contents(self): + hr = self._make_host_resolver() + hr._load() + assert not hr._v4 + assert not hr._v6 + assert not hr._aliases + + @skip_unless(greendns_requirement) + def test_load_v4_v6_cname_aliases(self): + hr = self._make_host_resolver() + hr.hosts.write('1.2.3.4 v4.example.com v4\n' + 'dead:beef::1 v6.example.com v6\n') + hr.hosts.flush() + hr._load() + assert hr._v4 == {'v4.example.com': '1.2.3.4', 'v4': '1.2.3.4'} + assert hr._v6 == {'v6.example.com': 'dead:beef::1', + 'v6': 'dead:beef::1'} + assert hr._aliases == {'v4': 'v4.example.com', + 'v6': 'v6.example.com'} + + @skip_unless(greendns_requirement) + def test_load_v6_link_local(self): + hr = self._make_host_resolver() + hr.hosts.write('fe80:: foo\n' + 'fe80:dead:beef::1 bar\n') + hr.hosts.flush() + hr._load() + assert not hr._v4 + assert not hr._v6 + + @skip_unless(greendns_requirement) + def test_query_A(self): + hr = self._make_host_resolver() + hr._v4 = {'v4.example.com': '1.2.3.4'} + ans = hr.query('v4.example.com') + assert ans[0].address == '1.2.3.4' + + @skip_unless(greendns_requirement) + def test_query_ans_types(self): + # This assumes test_query_A above succeeds + hr = self._make_host_resolver() + hr._v4 = {'v4.example.com': '1.2.3.4'} + hr._last_stat = time.time() + ans = hr.query('v4.example.com') + assert isinstance(ans, greendns.dns.resolver.Answer) + assert ans.response is None + assert ans.qname == dns.name.from_text('v4.example.com') + assert ans.rdtype == dns.rdatatype.A + assert ans.rdclass == dns.rdataclass.IN + assert ans.canonical_name == dns.name.from_text('v4.example.com') + assert ans.expiration + assert isinstance(ans.rrset, dns.rrset.RRset) + assert ans.rrset.rdtype == dns.rdatatype.A + assert ans.rrset.rdclass == dns.rdataclass.IN + ttl = greendns.HOSTS_TTL + assert ttl - 1 <= ans.rrset.ttl <= ttl + 1 + rr = ans.rrset[0] + assert isinstance(rr, greendns.dns.rdtypes.IN.A.A) + assert rr.rdtype == dns.rdatatype.A + assert rr.rdclass == dns.rdataclass.IN + assert rr.address == '1.2.3.4' + + @skip_unless(greendns_requirement) + def test_query_AAAA(self): + hr = self._make_host_resolver() + hr._v6 = {'v6.example.com': 'dead:beef::1'} + ans = hr.query('v6.example.com', dns.rdatatype.AAAA) + assert ans[0].address == 'dead:beef::1' + + @skip_unless(greendns_requirement) + def test_query_unknown_raises(self): + hr = self._make_host_resolver() + with assert_raises(greendns.dns.resolver.NoAnswer): + hr.query('example.com') + + @skip_unless(greendns_requirement) + def test_query_unknown_no_raise(self): + hr = self._make_host_resolver() + ans = hr.query('example.com', raise_on_no_answer=False) + assert isinstance(ans, greendns.dns.resolver.Answer) + assert ans.response is None + assert ans.qname == dns.name.from_text('example.com') + assert ans.rdtype == dns.rdatatype.A + assert ans.rdclass == dns.rdataclass.IN + assert ans.canonical_name == dns.name.from_text('example.com') + assert ans.expiration + assert isinstance(ans.rrset, greendns.dns.rrset.RRset) + assert ans.rrset.rdtype == dns.rdatatype.A + assert ans.rrset.rdclass == dns.rdataclass.IN + assert len(ans.rrset) == 0 + + @skip_unless(greendns_requirement) + def test_query_CNAME(self): + hr = self._make_host_resolver() + hr._aliases = {'host': 'host.example.com'} + ans = hr.query('host', dns.rdatatype.CNAME) + assert ans[0].target == dns.name.from_text('host.example.com') + assert str(ans[0].target) == 'host.example.com.' + + @skip_unless(greendns_requirement) + def test_query_unknown_type(self): + hr = self._make_host_resolver() + with assert_raises(greendns.dns.resolver.NoAnswer): + hr.query('example.com', dns.rdatatype.MX) + + @skip_unless(greendns_requirement) + def test_getaliases(self): + hr = self._make_host_resolver() + hr._aliases = {'host': 'host.example.com', + 'localhost': 'host.example.com'} + res = set(hr.getaliases('host')) + assert res == set(['host.example.com', 'localhost']) + + @skip_unless(greendns_requirement) + def test_getaliases_unknown(self): + hr = self._make_host_resolver() + assert hr.getaliases('host.example.com') == [] + + @skip_unless(greendns_requirement) + def test_getaliases_fqdn(self): + hr = self._make_host_resolver() + hr._aliases = {'host': 'host.example.com'} + res = set(hr.getaliases('host.example.com')) + assert res == set(['host']) + + +def _make_mock_base_resolver(): + """A mocked base resolver class""" + class RR(object): + pass + + class Resolver(object): + aliases = ['cname.example.com'] + raises = None + rr = RR() + + def query(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + if self.raises: + raise self.raises() + if hasattr(self, 'rrset'): + rrset = self.rrset + else: + rrset = [self.rr] + return greendns.HostsAnswer('foo', 1, 1, rrset, False) + + def getaliases(self, *args, **kwargs): + return self.aliases + + return Resolver + + +class TestProxyResolver(LimitedTestCase): + + @skip_unless(greendns_requirement) + def test_clear(self): + rp = greendns.ResolverProxy() + resolver = rp._resolver + rp.clear() + assert rp._resolver != resolver + + @skip_unless(greendns_requirement) + def _make_mock_hostsresolver(self): + """A mocked HostsResolver""" + base_resolver = _make_mock_base_resolver() + base_resolver.rr.address = '1.2.3.4' + return base_resolver() + + @skip_unless(greendns_requirement) + def _make_mock_resolver(self): + """A mocked Resolver""" + base_resolver = _make_mock_base_resolver() + base_resolver.rr.address = '5.6.7.8' + return base_resolver() + + @skip_unless(greendns_requirement) + def test_hosts(self): + hostsres = self._make_mock_hostsresolver() + rp = greendns.ResolverProxy(hostsres) + ans = rp.query('host.example.com') + assert ans[0].address == '1.2.3.4' + + @skip_unless(greendns_requirement) + def test_hosts_noanswer(self): + hostsres = self._make_mock_hostsresolver() + res = self._make_mock_resolver() + rp = greendns.ResolverProxy(hostsres) + rp._resolver = res + hostsres.raises = greendns.dns.resolver.NoAnswer + ans = rp.query('host.example.com') + assert ans[0].address == '5.6.7.8' + + @skip_unless(greendns_requirement) + def test_resolver(self): + res = self._make_mock_resolver() + rp = greendns.ResolverProxy() + rp._resolver = res + ans = rp.query('host.example.com') + assert ans[0].address == '5.6.7.8' + + @skip_unless(greendns_requirement) + def test_noanswer(self): + res = self._make_mock_resolver() + rp = greendns.ResolverProxy() + rp._resolver = res + res.raises = greendns.dns.resolver.NoAnswer + with assert_raises(greendns.dns.resolver.NoAnswer): + rp.query('host.example.com') + + @skip_unless(greendns_requirement) + def test_nxdomain(self): + res = self._make_mock_resolver() + rp = greendns.ResolverProxy() + rp._resolver = res + res.raises = greendns.dns.resolver.NXDOMAIN + with assert_raises(greendns.dns.resolver.NXDOMAIN): + rp.query('host.example.com') + + @skip_unless(greendns_requirement) + def test_noanswer_hosts(self): + hostsres = self._make_mock_hostsresolver() + res = self._make_mock_resolver() + rp = greendns.ResolverProxy(hostsres) + rp._resolver = res + hostsres.raises = greendns.dns.resolver.NoAnswer + res.raises = greendns.dns.resolver.NoAnswer + with assert_raises(greendns.dns.resolver.NoAnswer): + rp.query('host.example.com') + + def _make_mock_resolver_aliases(self): + + class RR(object): + target = 'host.example.com' + + class Resolver(object): + call_count = 0 + exc_type = greendns.dns.resolver.NoAnswer + + def query(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + self.call_count += 1 + if self.call_count < 2: + return greendns.HostsAnswer(args[0], 1, 5, [RR()], False) + else: + raise self.exc_type() + + return Resolver() + + @skip_unless(greendns_requirement) + def test_getaliases(self): + aliases_res = self._make_mock_resolver_aliases() + rp = greendns.ResolverProxy() + rp._resolver = aliases_res + aliases = set(rp.getaliases('alias.example.com')) + assert aliases == set(['host.example.com']) + + @skip_unless(greendns_requirement) + def test_getaliases_fqdn(self): + aliases_res = self._make_mock_resolver_aliases() + rp = greendns.ResolverProxy() + rp._resolver = aliases_res + rp._resolver.call_count = 1 + assert rp.getaliases('host.example.com') == [] + + @skip_unless(greendns_requirement) + def test_getaliases_nxdomain(self): + aliases_res = self._make_mock_resolver_aliases() + rp = greendns.ResolverProxy() + rp._resolver = aliases_res + rp._resolver.call_count = 1 + rp._resolver.exc_type = greendns.dns.resolver.NXDOMAIN + assert rp.getaliases('host.example.com') == [] + + +class TestResolve(LimitedTestCase): + + def setUp(self): + base_resolver = _make_mock_base_resolver() + base_resolver.rr.address = '1.2.3.4' + self._old_resolver = greendns.resolver + greendns.resolver = base_resolver() + + def tearDown(self): + greendns.resolver = self._old_resolver + + @skip_unless(greendns_requirement) + def test_A(self): + ans = greendns.resolve('host.example.com', socket.AF_INET) + assert ans[0].address == '1.2.3.4' + assert greendns.resolver.args == ('host.example.com', dns.rdatatype.A) + + @skip_unless(greendns_requirement) + def test_AAAA(self): + greendns.resolver.rr.address = 'dead:beef::1' + ans = greendns.resolve('host.example.com', socket.AF_INET6) + assert ans[0].address == 'dead:beef::1' + assert greendns.resolver.args == ('host.example.com', dns.rdatatype.AAAA) + + @skip_unless(greendns_requirement) + def test_unknown_rdtype(self): + with assert_raises(socket.gaierror): + greendns.resolve('host.example.com', socket.AF_INET6 + 1) + + @skip_unless(greendns_requirement) + def test_timeout(self): + greendns.resolver.raises = greendns.dns.exception.Timeout + with assert_raises(socket.gaierror): + greendns.resolve('host.example.com') + + @skip_unless(greendns_requirement) + def test_exc(self): + greendns.resolver.raises = greendns.dns.exception.DNSException + with assert_raises(socket.gaierror): + greendns.resolve('host.example.com') + + @skip_unless(greendns_requirement) + def test_noraise_noanswer(self): + greendns.resolver.rrset = None + ans = greendns.resolve('example.com', raises=False) + assert not ans.rrset + + @skip_unless(greendns_requirement) + def test_noraise_nxdomain(self): + greendns.resolver.raises = greendns.dns.resolver.NXDOMAIN + ans = greendns.resolve('example.com', raises=False) + assert not ans.rrset + + +class TestResolveCname(LimitedTestCase): + + def setUp(self): + base_resolver = _make_mock_base_resolver() + base_resolver.rr.target = 'cname.example.com' + self._old_resolver = greendns.resolver + greendns.resolver = base_resolver() + + def tearDown(self): + greendns.resolver = self._old_resolver + + @skip_unless(greendns_requirement) + def test_success(self): + cname = greendns.resolve_cname('alias.example.com') + assert cname == 'cname.example.com' + + @skip_unless(greendns_requirement) + def test_timeout(self): + greendns.resolver.raises = greendns.dns.exception.Timeout + with assert_raises(socket.gaierror): + greendns.resolve_cname('alias.example.com') + + @skip_unless(greendns_requirement) + def test_nodata(self): + greendns.resolver.raises = greendns.dns.exception.DNSException + with assert_raises(socket.gaierror): + greendns.resolve_cname('alias.example.com') + + @skip_unless(greendns_requirement) + def test_no_answer(self): + greendns.resolver.raises = greendns.dns.resolver.NoAnswer + assert greendns.resolve_cname('host.example.com') == 'host.example.com' + + +def _make_mock_resolve(): + """A stubbed out resolve function + + This monkeypatches the greendns.resolve() function with a mock. + You must give it answers by calling .add(). + """ + + class MockAnswer(list): + pass + + class MockResolve(object): + + def __init__(self): + self.answers = {} + + def __call__(self, name, family=socket.AF_INET, raises=True): + qname = dns.name.from_text(name) + try: + rrset = self.answers[name][family] + except KeyError: + if raises: + raise greendns.dns.resolver.NoAnswer() + rrset = dns.rrset.RRset(qname, 1, 1) + ans = MockAnswer() + ans.qname = qname + ans.rrset = rrset + ans.extend(rrset.items) + return ans + + def add(self, name, addr): + """Add an address to a name and family""" + try: + rdata = dns.rdtypes.IN.A.A(dns.rdataclass.IN, + dns.rdatatype.A, addr) + family = socket.AF_INET + except (socket.error, dns.exception.SyntaxError): + rdata = dns.rdtypes.IN.AAAA.AAAA(dns.rdataclass.IN, + dns.rdatatype.AAAA, addr) + family = socket.AF_INET6 + family_dict = self.answers.setdefault(name, {}) + rrset = family_dict.get(family) + if not rrset: + family_dict[family] = rrset = dns.rrset.RRset( + dns.name.from_text(name), rdata.rdclass, rdata.rdtype) + rrset.add(rdata) + + resolve = MockResolve() + return resolve + + +class TestGetaddrinfo(LimitedTestCase): + + def _make_mock_resolve_cname(self): + """A stubbed out cname function""" + + class ResolveCname(object): + qname = None + cname = 'cname.example.com' + + def __call__(self, host): + self.qname = host + return self.cname + + resolve_cname = ResolveCname() + return resolve_cname + + def setUp(self): + self._old_resolve = greendns.resolve + self._old_resolve_cname = greendns.resolve_cname + self._old_orig_getaddrinfo = greendns.socket.getaddrinfo + + def tearDown(self): + greendns.resolve = self._old_resolve + greendns.resolve_cname = self._old_resolve_cname + greendns.socket.getaddrinfo = self._old_orig_getaddrinfo + + @skip_unless(greendns_requirement) + def test_getaddrinfo_inet(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '127.0.0.2') + res = greendns.getaddrinfo('example.com', 'ssh', socket.AF_INET) + addr = ('127.0.0.2', 22) + tcp = (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, addr) + udp = (socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP, addr) + assert tcp in [ai[:3] + (ai[4],) for ai in res] + assert udp in [ai[:3] + (ai[4],) for ai in res] + + @skip_unless(greendns_requirement) + def test_getaddrinfo_inet6(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '::1') + res = greendns.getaddrinfo('example.com', 'ssh', socket.AF_INET6) + addr = ('::1', 22, 0, 0) + tcp = (socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP, addr) + udp = (socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP, addr) + assert tcp in [ai[:3] + (ai[4],) for ai in res] + assert udp in [ai[:3] + (ai[4],) for ai in res] + + @skip_unless(greendns_requirement) + def test_getaddrinfo(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '127.0.0.2') + greendns.resolve.add('example.com', '::1') + res = greendns.getaddrinfo('example.com', 'ssh') + addr = ('127.0.0.2', 22) + tcp = (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, addr) + udp = (socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP, addr) + addr = ('::1', 22, 0, 0) + tcp6 = (socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP, addr) + udp6 = (socket.AF_INET6, socket.SOCK_DGRAM, socket.IPPROTO_UDP, addr) + filt_res = [ai[:3] + (ai[4],) for ai in res] + assert tcp in filt_res + assert udp in filt_res + assert tcp6 in filt_res + assert udp6 in filt_res + + @skip_unless(greendns_requirement) + def test_getaddrinfo_only_a_ans(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '1.2.3.4') + res = greendns.getaddrinfo('example.com', 0) + addr = [('1.2.3.4', 0)] * len(res) + assert addr == [ai[-1] for ai in res] + + @skip_unless(greendns_requirement) + def test_getaddrinfo_only_aaaa_ans(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', 'dead:beef::1') + res = greendns.getaddrinfo('example.com', 0) + addr = [('dead:beef::1', 0, 0, 0)] * len(res) + assert addr == [ai[-1] for ai in res] + + @skip_unless(greendns_requirement) + def test_canonname(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('host.example.com', '1.2.3.4') + greendns.resolve_cname = self._make_mock_resolve_cname() + res = greendns.getaddrinfo('host.example.com', 0, + 0, 0, 0, socket.AI_CANONNAME) + assert res[0][3] == 'cname.example.com' + + @skip_unless(greendns_requirement) + def test_host_none(self): + res = greendns.getaddrinfo(None, 80) + for addr in set(ai[-1] for ai in res): + assert addr in [('127.0.0.1', 80), ('::1', 80, 0, 0)] + + @skip_unless(greendns_requirement) + def test_host_none_passive(self): + res = greendns.getaddrinfo(None, 80, 0, 0, 0, socket.AI_PASSIVE) + for addr in set(ai[-1] for ai in res): + assert addr in [('0.0.0.0', 80), ('::', 80, 0, 0)] + + @skip_unless(greendns_requirement) + def test_v4mapped(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '1.2.3.4') + res = greendns.getaddrinfo('example.com', 80, + socket.AF_INET6, 0, 0, socket.AI_V4MAPPED) + addrs = set(ai[-1] for ai in res) + assert addrs == set([('::ffff:1.2.3.4', 80, 0, 0)]) + + @skip_unless(greendns_requirement) + def test_v4mapped_all(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '1.2.3.4') + greendns.resolve.add('example.com', 'dead:beef::1') + res = greendns.getaddrinfo('example.com', 80, socket.AF_INET6, 0, 0, + socket.AI_V4MAPPED | socket.AI_ALL) + addrs = set(ai[-1] for ai in res) + for addr in addrs: + assert addr in [('::ffff:1.2.3.4', 80, 0, 0), + ('dead:beef::1', 80, 0, 0)] + + @skip_unless(greendns_requirement) + def test_numericserv(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '1.2.3.4') + with assert_raises(socket.gaierror): + greendns.getaddrinfo('example.com', 'www', 0, 0, 0, socket.AI_NUMERICSERV) + + @skip_unless(greendns_requirement) + def test_numerichost(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '1.2.3.4') + with assert_raises(socket.gaierror): + greendns.getaddrinfo('example.com', 80, 0, 0, 0, socket.AI_NUMERICHOST) + + @skip_unless(greendns_requirement) + def test_noport(self): + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('example.com', '1.2.3.4') + ai = greendns.getaddrinfo('example.com', None) + assert ai[0][-1][1] == 0 + + @skip_unless(greendns_requirement) + def test_AI_ADDRCONFIG(self): + # When the users sets AI_ADDRCONFIG but only has an IPv4 + # address configured we will iterate over the results, but the + # call for the IPv6 address will fail rather then return an + # empty list. In that case we should catch the exception and + # only return the ones which worked. + def getaddrinfo(addr, port, family, socktype, proto, aiflags): + if addr == '127.0.0.1': + return [(socket.AF_INET, 1, 0, '', ('127.0.0.1', 0))] + elif addr == '::1' and aiflags & socket.AI_ADDRCONFIG: + raise socket.error(socket.EAI_ADDRFAMILY, + 'Address family for hostname not supported') + elif addr == '::1' and not aiflags & socket.AI_ADDRCONFIG: + return [(socket.AF_INET6, 1, 0, '', ('::1', 0, 0, 0))] + greendns.socket.getaddrinfo = getaddrinfo + greendns.resolve = _make_mock_resolve() + greendns.resolve.add('localhost', '127.0.0.1') + greendns.resolve.add('localhost', '::1') + res = greendns.getaddrinfo('localhost', None, + 0, 0, 0, socket.AI_ADDRCONFIG) + assert res == [(socket.AF_INET, 1, 0, '', ('127.0.0.1', 0))] + + @skip_unless(greendns_requirement) + def test_AI_ADDRCONFIG_noaddr(self): + # If AI_ADDRCONFIG is used but there is no address we need to + # get an exception, not an empty list. + def getaddrinfo(addr, port, family, socktype, proto, aiflags): + raise socket.error(socket.EAI_ADDRFAMILY, + 'Address family for hostname not supported') + greendns.socket.getaddrinfo = getaddrinfo + greendns.resolve = _make_mock_resolve() + try: + greendns.getaddrinfo('::1', None, 0, 0, 0, socket.AI_ADDRCONFIG) + except socket.error as e: + assert e.errno == socket.EAI_ADDRFAMILY + + +class TestIsIpAddr(LimitedTestCase): + + @skip_unless(greendns_requirement) + def test_isv4(self): + assert greendns.is_ipv4_addr('1.2.3.4') + + @skip_unless(greendns_requirement) + def test_isv4_false(self): + assert not greendns.is_ipv4_addr('260.0.0.0') + + @skip_unless(greendns_requirement) + def test_isv6(self): + assert greendns.is_ipv6_addr('dead:beef::1') + + @skip_unless(greendns_requirement) + def test_isv6_invalid(self): + assert not greendns.is_ipv6_addr('foobar::1') + + @skip_unless(greendns_requirement) + def test_v4(self): + assert greendns.is_ip_addr('1.2.3.4') + + @skip_unless(greendns_requirement) + def test_v4_illegal(self): + assert not greendns.is_ip_addr('300.0.0.1') + + @skip_unless(greendns_requirement) + def test_v6_addr(self): + assert greendns.is_ip_addr('::1') + + @skip_unless(greendns_requirement) + def test_isv4_none(self): + assert not greendns.is_ipv4_addr(None) + + @skip_unless(greendns_requirement) + def test_isv6_none(self): + assert not greendns.is_ipv6_addr(None) + + @skip_unless(greendns_requirement) + def test_none(self): + assert not greendns.is_ip_addr(None) + + +class TestGethostbyname(LimitedTestCase): + + def setUp(self): + self._old_resolve = greendns.resolve + greendns.resolve = _make_mock_resolve() + + def tearDown(self): + greendns.resolve = self._old_resolve + + @skip_unless(greendns_requirement) + def test_ipaddr(self): + assert greendns.gethostbyname('1.2.3.4') == '1.2.3.4' + + @skip_unless(greendns_requirement) + def test_name(self): + greendns.resolve.add('host.example.com', '1.2.3.4') + assert greendns.gethostbyname('host.example.com') == '1.2.3.4' + + +class TestGetaliases(LimitedTestCase): + + def _make_mock_resolver(self): + base_resolver = _make_mock_base_resolver() + resolver = base_resolver() + resolver.aliases = ['cname.example.com'] + return resolver + + def setUp(self): + self._old_resolver = greendns.resolver + greendns.resolver = self._make_mock_resolver() + + def tearDown(self): + greendns.resolver = self._old_resolver + + @skip_unless(greendns_requirement) + def test_getaliases(self): + assert greendns.getaliases('host.example.com') == ['cname.example.com'] + + +class TestGethostbyname_ex(LimitedTestCase): + + def _make_mock_getaliases(self): + + class GetAliases(object): + aliases = ['cname.example.com'] + + def __call__(self, *args, **kwargs): + return self.aliases + + getaliases = GetAliases() + return getaliases + + def setUp(self): + self._old_resolve = greendns.resolve + greendns.resolve = _make_mock_resolve() + self._old_getaliases = greendns.getaliases + + def tearDown(self): + greendns.resolve = self._old_resolve + greendns.getaliases = self._old_getaliases + + @skip_unless(greendns_requirement) + def test_ipaddr(self): + res = greendns.gethostbyname_ex('1.2.3.4') + assert res == ('1.2.3.4', [], ['1.2.3.4']) + + @skip_unless(greendns_requirement) + def test_name(self): + greendns.resolve.add('host.example.com', '1.2.3.4') + greendns.getaliases = self._make_mock_getaliases() + greendns.getaliases.aliases = [] + res = greendns.gethostbyname_ex('host.example.com') + assert res == ('host.example.com', [], ['1.2.3.4']) + + @skip_unless(greendns_requirement) + def test_multiple_addrs(self): + greendns.resolve.add('host.example.com', '1.2.3.4') + greendns.resolve.add('host.example.com', '1.2.3.5') + greendns.getaliases = self._make_mock_getaliases() + greendns.getaliases.aliases = [] + res = greendns.gethostbyname_ex('host.example.com') + assert res == ('host.example.com', [], ['1.2.3.4', '1.2.3.5']) |