diff options
author | Yury Selivanov <yury@magic.io> | 2016-06-08 12:33:31 -0400 |
---|---|---|
committer | Yury Selivanov <yury@magic.io> | 2016-06-08 12:33:31 -0400 |
commit | f1c6fa986647791977d974bd43119b46a7a3cdbc (patch) | |
tree | 191f22cd29bf8bdf269be09fd14ff5a7f7cbf660 /Lib | |
parent | 7d7a11b5d700c54260c517d0fb57fe1caf591e31 (diff) | |
download | cpython-git-f1c6fa986647791977d974bd43119b46a7a3cdbc.tar.gz |
Issue #27136: Fix DNS static resolution; don't use it in getaddrinfo
Patch by A. Jesse Jiryu Davis
Diffstat (limited to 'Lib')
-rw-r--r-- | Lib/asyncio/base_events.py | 116 | ||||
-rw-r--r-- | Lib/asyncio/proactor_events.py | 9 | ||||
-rw-r--r-- | Lib/asyncio/selector_events.py | 24 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_base_events.py | 91 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_events.py | 19 | ||||
-rw-r--r-- | Lib/test/test_asyncio/test_selector_events.py | 20 |
6 files changed, 116 insertions, 163 deletions
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py index 2b2c18536d..172a463ef8 100644 --- a/Lib/asyncio/base_events.py +++ b/Lib/asyncio/base_events.py @@ -16,10 +16,8 @@ to modify the meaning of the API call itself. import collections import concurrent.futures -import functools import heapq import inspect -import ipaddress import itertools import logging import os @@ -86,12 +84,14 @@ if hasattr(socket, 'SOCK_CLOEXEC'): _SOCKET_TYPE_MASK |= socket.SOCK_CLOEXEC -@functools.lru_cache(maxsize=1024, typed=True) def _ipaddr_info(host, port, family, type, proto): - # Try to skip getaddrinfo if "host" is already an IP. Since getaddrinfo - # blocks on an exclusive lock on some platforms, users might handle name - # resolution in their own code and pass in resolved IPs. - if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or host is None: + # Try to skip getaddrinfo if "host" is already an IP. Users might have + # handled name resolution in their own code and pass in resolved IPs. + if not hasattr(socket, 'inet_pton'): + return + + if proto not in {0, socket.IPPROTO_TCP, socket.IPPROTO_UDP} or \ + host is None: return None type &= ~_SOCKET_TYPE_MASK @@ -123,59 +123,42 @@ def _ipaddr_info(host, port, family, type, proto): # Might be a service name like "http". port = socket.getservbyname(port) - if hasattr(socket, 'inet_pton'): - if family == socket.AF_UNSPEC: - afs = [socket.AF_INET, socket.AF_INET6] - else: - afs = [family] - - for af in afs: - # Linux's inet_pton doesn't accept an IPv6 zone index after host, - # like '::1%lo0', so strip it. If we happen to make an invalid - # address look valid, we fail later in sock.connect or sock.bind. - try: - if af == socket.AF_INET6: - socket.inet_pton(af, host.partition('%')[0]) - else: - socket.inet_pton(af, host) - return af, type, proto, '', (host, port) - except OSError: - pass + if family == socket.AF_UNSPEC: + afs = [socket.AF_INET, socket.AF_INET6] + else: + afs = [family] - # "host" is not an IP address. + if isinstance(host, bytes): + host = host.decode('idna') + if '%' in host: + # Linux's inet_pton doesn't accept an IPv6 zone index after host, + # like '::1%lo0'. return None - # No inet_pton. (On Windows it's only available since Python 3.4.) - # Even though getaddrinfo with AI_NUMERICHOST would be non-blocking, it - # still requires a lock on some platforms, and waiting for that lock could - # block the event loop. Use ipaddress instead, it's just text parsing. - try: - addr = ipaddress.IPv4Address(host) - except ValueError: + for af in afs: try: - addr = ipaddress.IPv6Address(host.partition('%')[0]) - except ValueError: - return None + socket.inet_pton(af, host) + # The host has already been resolved. + return af, type, proto, '', (host, port) + except OSError: + pass - af = socket.AF_INET if addr.version == 4 else socket.AF_INET6 - if family not in (socket.AF_UNSPEC, af): - # "host" is wrong IP version for "family". - return None - - return af, type, proto, '', (host, port) + # "host" is not an IP address. + return None -def _check_resolved_address(sock, address): - # Ensure that the address is already resolved to avoid the trap of hanging - # the entire event loop when the address requires doing a DNS lookup. - - if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: - return - +def _ensure_resolved(address, *, family=0, type=socket.SOCK_STREAM, proto=0, + flags=0, loop): host, port = address[:2] - if _ipaddr_info(host, port, sock.family, sock.type, sock.proto) is None: - raise ValueError("address must be resolved (IP address)," - " got host %r" % host) + info = _ipaddr_info(host, port, family, type, proto) + if info is not None: + # "host" is already a resolved IP. + fut = loop.create_future() + fut.set_result([info]) + return fut + else: + return loop.getaddrinfo(host, port, family=family, type=type, + proto=proto, flags=flags) def _run_until_complete_cb(fut): @@ -602,12 +585,7 @@ class BaseEventLoop(events.AbstractEventLoop): def getaddrinfo(self, host, port, *, family=0, type=0, proto=0, flags=0): - info = _ipaddr_info(host, port, family, type, proto) - if info is not None: - fut = self.create_future() - fut.set_result([info]) - return fut - elif self._debug: + if self._debug: return self.run_in_executor(None, self._getaddrinfo_debug, host, port, family, type, proto, flags) else: @@ -656,14 +634,14 @@ class BaseEventLoop(events.AbstractEventLoop): raise ValueError( 'host/port and sock can not be specified at the same time') - f1 = self.getaddrinfo( - host, port, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) + f1 = _ensure_resolved((host, port), family=family, + type=socket.SOCK_STREAM, proto=proto, + flags=flags, loop=self) fs = [f1] if local_addr is not None: - f2 = self.getaddrinfo( - *local_addr, family=family, - type=socket.SOCK_STREAM, proto=proto, flags=flags) + f2 = _ensure_resolved(local_addr, family=family, + type=socket.SOCK_STREAM, proto=proto, + flags=flags, loop=self) fs.append(f2) else: f2 = None @@ -798,9 +776,9 @@ class BaseEventLoop(events.AbstractEventLoop): assert isinstance(addr, tuple) and len(addr) == 2, ( '2-tuple is expected') - infos = yield from self.getaddrinfo( - *addr, family=family, type=socket.SOCK_DGRAM, - proto=proto, flags=flags) + infos = yield from _ensure_resolved( + addr, family=family, type=socket.SOCK_DGRAM, + proto=proto, flags=flags, loop=self) if not infos: raise OSError('getaddrinfo() returned empty list') @@ -888,9 +866,9 @@ class BaseEventLoop(events.AbstractEventLoop): @coroutine def _create_server_getaddrinfo(self, host, port, family, flags): - infos = yield from self.getaddrinfo(host, port, family=family, + infos = yield from _ensure_resolved((host, port), family=family, type=socket.SOCK_STREAM, - flags=flags) + flags=flags, loop=self) if not infos: raise OSError('getaddrinfo({!r}) returned empty list'.format(host)) return infos diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py index eb92458ada..3ac314c0cc 100644 --- a/Lib/asyncio/proactor_events.py +++ b/Lib/asyncio/proactor_events.py @@ -440,14 +440,7 @@ class BaseProactorEventLoop(base_events.BaseEventLoop): return self._proactor.send(sock, data) def sock_connect(self, sock, address): - try: - base_events._check_resolved_address(sock, address) - except ValueError as err: - fut = self.create_future() - fut.set_exception(err) - return fut - else: - return self._proactor.connect(sock, address) + return self._proactor.connect(sock, address) def sock_accept(self, sock): return self._proactor.accept(sock) diff --git a/Lib/asyncio/selector_events.py b/Lib/asyncio/selector_events.py index b34fee34df..fb7ab2108e 100644 --- a/Lib/asyncio/selector_events.py +++ b/Lib/asyncio/selector_events.py @@ -385,24 +385,28 @@ class BaseSelectorEventLoop(base_events.BaseEventLoop): def sock_connect(self, sock, address): """Connect to a remote socket at address. - The address must be already resolved to avoid the trap of hanging the - entire event loop when the address requires doing a DNS lookup. For - example, it must be an IP address, not a hostname, for AF_INET and - AF_INET6 address families. Use getaddrinfo() to resolve the hostname - asynchronously. - This method is a coroutine. """ if self._debug and sock.gettimeout() != 0: raise ValueError("the socket must be non-blocking") + fut = self.create_future() + if hasattr(socket, 'AF_UNIX') and sock.family == socket.AF_UNIX: + self._sock_connect(fut, sock, address) + else: + resolved = base_events._ensure_resolved(address, loop=self) + resolved.add_done_callback( + lambda resolved: self._on_resolved(fut, sock, resolved)) + + return fut + + def _on_resolved(self, fut, sock, resolved): try: - base_events._check_resolved_address(sock, address) - except ValueError as err: - fut.set_exception(err) + _, _, _, _, address = resolved.result()[0] + except Exception as exc: + fut.set_exception(exc) else: self._sock_connect(fut, sock, address) - return fut def _sock_connect(self, fut, sock, address): fd = sock.fileno() diff --git a/Lib/test/test_asyncio/test_base_events.py b/Lib/test/test_asyncio/test_base_events.py index e800ec4340..0807dfbf4c 100644 --- a/Lib/test/test_asyncio/test_base_events.py +++ b/Lib/test/test_asyncio/test_base_events.py @@ -45,6 +45,7 @@ def mock_socket_module(): m_socket.socket = mock.MagicMock() m_socket.socket.return_value = test_utils.mock_nonblocking_socket() + m_socket.getaddrinfo._is_coroutine = False return m_socket @@ -56,14 +57,6 @@ def patch_socket(f): class BaseEventTests(test_utils.TestCase): - def setUp(self): - super().setUp() - base_events._ipaddr_info.cache_clear() - - def tearDown(self): - base_events._ipaddr_info.cache_clear() - super().tearDown() - def test_ipaddr_info(self): UNSPEC = socket.AF_UNSPEC INET = socket.AF_INET @@ -79,6 +72,10 @@ class BaseEventTests(test_utils.TestCase): self.assertEqual( (INET, STREAM, TCP, '', ('1.2.3.4', 1)), + base_events._ipaddr_info(b'1.2.3.4', 1, INET, STREAM, TCP)) + + self.assertEqual( + (INET, STREAM, TCP, '', ('1.2.3.4', 1)), base_events._ipaddr_info('1.2.3.4', 1, UNSPEC, STREAM, TCP)) self.assertEqual( @@ -116,8 +113,7 @@ class BaseEventTests(test_utils.TestCase): base_events._ipaddr_info('::3', 1, INET, STREAM, TCP)) # IPv6 address with zone index. - self.assertEqual( - (INET6, STREAM, TCP, '', ('::3%lo0', 1)), + self.assertIsNone( base_events._ipaddr_info('::3%lo0', 1, INET6, STREAM, TCP)) def test_port_parameter_types(self): @@ -169,31 +165,10 @@ class BaseEventTests(test_utils.TestCase): @patch_socket def test_ipaddr_info_no_inet_pton(self, m_socket): del m_socket.inet_pton - self.test_ipaddr_info() - - def test_check_resolved_address(self): - sock = socket.socket(socket.AF_INET) - with sock: - base_events._check_resolved_address(sock, ('1.2.3.4', 1)) - - sock = socket.socket(socket.AF_INET6) - with sock: - base_events._check_resolved_address(sock, ('::3', 1)) - base_events._check_resolved_address(sock, ('::3%lo0', 1)) - with self.assertRaises(ValueError): - base_events._check_resolved_address(sock, ('foo', 1)) - - def test_check_resolved_sock_type(self): - # Ensure we ignore extra flags in sock.type. - if hasattr(socket, 'SOCK_NONBLOCK'): - sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_NONBLOCK) - with sock: - base_events._check_resolved_address(sock, ('1.2.3.4', 1)) - - if hasattr(socket, 'SOCK_CLOEXEC'): - sock = socket.socket(type=socket.SOCK_STREAM | socket.SOCK_CLOEXEC) - with sock: - base_events._check_resolved_address(sock, ('1.2.3.4', 1)) + self.assertIsNone(base_events._ipaddr_info('1.2.3.4', 1, + socket.AF_INET, + socket.SOCK_STREAM, + socket.IPPROTO_TCP)) class BaseEventLoopTests(test_utils.TestCase): @@ -1042,11 +1017,6 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): self.loop = asyncio.new_event_loop() self.set_event_loop(self.loop) - def tearDown(self): - # Clear mocked constants like AF_INET from the cache. - base_events._ipaddr_info.cache_clear() - super().tearDown() - @patch_socket def test_create_connection_multiple_errors(self, m_socket): @@ -1195,10 +1165,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): if not allow_inet_pton: del m_socket.inet_pton - def getaddrinfo(*args, **kw): - self.fail('should not have called getaddrinfo') - - m_socket.getaddrinfo = getaddrinfo + m_socket.getaddrinfo = socket.getaddrinfo sock = m_socket.socket.return_value self.loop.add_reader = mock.Mock() @@ -1210,9 +1177,9 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): t, p = self.loop.run_until_complete(coro) try: sock.connect.assert_called_with(('1.2.3.4', 80)) - m_socket.socket.assert_called_with(family=m_socket.AF_INET, - proto=m_socket.IPPROTO_TCP, - type=m_socket.SOCK_STREAM) + _, kwargs = m_socket.socket.call_args + self.assertEqual(kwargs['family'], m_socket.AF_INET) + self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) finally: t.close() test_utils.run_briefly(self.loop) # allow transport to close @@ -1221,10 +1188,15 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): coro = self.loop.create_connection(asyncio.Protocol, '::2', 80) t, p = self.loop.run_until_complete(coro) try: - sock.connect.assert_called_with(('::2', 80)) - m_socket.socket.assert_called_with(family=m_socket.AF_INET6, - proto=m_socket.IPPROTO_TCP, - type=m_socket.SOCK_STREAM) + # Without inet_pton we use getaddrinfo, which transforms ('::2', 80) + # to ('::0.0.0.2', 80, 0, 0). The last 0s are flow info, scope id. + [address] = sock.connect.call_args[0] + host, port = address[:2] + self.assertRegex(host, r'::(0\.)*2') + self.assertEqual(port, 80) + _, kwargs = m_socket.socket.call_args + self.assertEqual(kwargs['family'], m_socket.AF_INET6) + self.assertEqual(kwargs['type'], m_socket.SOCK_STREAM) finally: t.close() test_utils.run_briefly(self.loop) # allow transport to close @@ -1256,6 +1228,21 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): self.assertRaises( OSError, self.loop.run_until_complete, coro) + @patch_socket + def test_create_connection_bluetooth(self, m_socket): + # See http://bugs.python.org/issue27136, fallback to getaddrinfo when + # we can't recognize an address is resolved, e.g. a Bluetooth address. + addr = ('00:01:02:03:04:05', 1) + + def getaddrinfo(host, port, *args, **kw): + assert (host, port) == addr + return [(999, 1, 999, '', (addr, 1))] + + m_socket.getaddrinfo = getaddrinfo + sock = m_socket.socket() + coro = self.loop.sock_connect(sock, addr) + self.loop.run_until_complete(coro) + def test_create_connection_ssl_server_hostname_default(self): self.loop.getaddrinfo = mock.Mock() @@ -1369,7 +1356,7 @@ class BaseEventLoopWithSelectorTests(test_utils.TestCase): getaddrinfo = self.loop.getaddrinfo = mock.Mock() getaddrinfo.return_value = [] - f = self.loop.create_server(MyProto, '0.0.0.0', 0) + f = self.loop.create_server(MyProto, 'python.org', 0) self.assertRaises(OSError, self.loop.run_until_complete, f) @patch_socket diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py index d52213ceb2..d0777758a7 100644 --- a/Lib/test/test_asyncio/test_events.py +++ b/Lib/test/test_asyncio/test_events.py @@ -1610,25 +1610,6 @@ class EventLoopTestsMixin: {'clock_resolution': self.loop._clock_resolution, 'selector': self.loop._selector.__class__.__name__}) - def test_sock_connect_address(self): - addresses = [(socket.AF_INET, ('www.python.org', 80))] - if support.IPV6_ENABLED: - addresses.extend(( - (socket.AF_INET6, ('www.python.org', 80)), - (socket.AF_INET6, ('www.python.org', 80, 0, 0)), - )) - - for family, address in addresses: - for sock_type in (socket.SOCK_STREAM, socket.SOCK_DGRAM): - sock = socket.socket(family, sock_type) - with sock: - sock.setblocking(False) - connect = self.loop.sock_connect(sock, address) - with self.assertRaises(ValueError) as cm: - self.loop.run_until_complete(connect) - self.assertIn('address must be resolved', - str(cm.exception)) - def test_remove_fds_after_closing(self): loop = self.create_event_loop() callback = lambda: None diff --git a/Lib/test/test_asyncio/test_selector_events.py b/Lib/test/test_asyncio/test_selector_events.py index 77e72e5705..8ad55358b1 100644 --- a/Lib/test/test_asyncio/test_selector_events.py +++ b/Lib/test/test_asyncio/test_selector_events.py @@ -343,9 +343,11 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): f = self.loop.sock_connect(sock, ('127.0.0.1', 8080)) self.assertIsInstance(f, asyncio.Future) - self.assertEqual( - (f, sock, ('127.0.0.1', 8080)), - self.loop._sock_connect.call_args[0]) + self.loop._run_once() + future_in, sock_in, address_in = self.loop._sock_connect.call_args[0] + self.assertEqual(future_in, f) + self.assertEqual(sock_in, sock) + self.assertEqual(address_in, ('127.0.0.1', 8080)) def test_sock_connect_timeout(self): # asyncio issue #205: sock_connect() must unregister the socket on @@ -359,6 +361,7 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): # first call to sock_connect() registers the socket fut = self.loop.sock_connect(sock, ('127.0.0.1', 80)) + self.loop._run_once() self.assertTrue(sock.connect.called) self.assertTrue(self.loop.add_writer.called) self.assertEqual(len(fut._callbacks), 1) @@ -376,7 +379,10 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): sock = mock.Mock() sock.fileno.return_value = 10 - self.loop._sock_connect(f, sock, ('127.0.0.1', 8080)) + resolved = self.loop.create_future() + resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.IPPROTO_TCP, '', ('127.0.0.1', 8080))]) + self.loop._sock_connect(f, sock, resolved) self.assertTrue(f.done()) self.assertIsNone(f.result()) self.assertTrue(sock.connect.called) @@ -402,9 +408,13 @@ class BaseSelectorEventLoopTests(test_utils.TestCase): sock.connect.side_effect = BlockingIOError sock.getsockopt.return_value = 0 address = ('127.0.0.1', 8080) + resolved = self.loop.create_future() + resolved.set_result([(socket.AF_INET, socket.SOCK_STREAM, + socket.IPPROTO_TCP, '', address)]) f = asyncio.Future(loop=self.loop) - self.loop._sock_connect(f, sock, address) + self.loop._sock_connect(f, sock, resolved) + self.loop._run_once() self.assertTrue(self.loop.add_writer.called) self.assertEqual(10, self.loop.add_writer.call_args[0][0]) |