diff options
-rw-r--r-- | amqp/transport.py | 70 | ||||
-rw-r--r-- | t/unit/test_transport.py | 40 |
2 files changed, 22 insertions, 88 deletions
diff --git a/amqp/transport.py b/amqp/transport.py index b5a0d4b..2761f09 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -169,59 +169,27 @@ class _AbstractTransport: sock.settimeout(prev) def _connect(self, host, port, timeout): - e = None - - # Below we are trying to avoid additional DNS requests for AAAA if A - # succeeds. This helps a lot in case when a hostname has an IPv4 entry - # in /etc/hosts but not IPv6. Without the (arguably somewhat twisted) - # logic below, getaddrinfo would attempt to resolve the hostname for - # both IP versions, which would make the resolver talk to configured - # DNS servers. If those servers are for some reason not available - # during resolution attempt (either because of system misconfiguration, - # or network connectivity problem), resolution process locks the - # _connect call for extended time. - addr_types = (socket.AF_INET, socket.AF_INET6) - addr_types_num = len(addr_types) - for n, family in enumerate(addr_types): - # first, resolve the address for a single address family + entries = socket.getaddrinfo( + host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, SOL_TCP, + ) + for i, res in enumerate(entries): + af, socktype, proto, canonname, sa = res try: - entries = socket.getaddrinfo( - host, port, family, socket.SOCK_STREAM, SOL_TCP) - entries_num = len(entries) - except socket.gaierror: - # we may have depleted all our options - if n + 1 >= addr_types_num: - # if getaddrinfo succeeded before for another address - # family, reraise the previous socket.error since it's more - # relevant to users - raise (e - if e is not None - else socket.error( - "failed to resolve broker hostname")) - continue # pragma: no cover - - # now that we have address(es) for the hostname, connect to broker - for i, res in enumerate(entries): - af, socktype, proto, _, sa = res + self.sock = socket.socket(af, socktype, proto) try: - self.sock = socket.socket(af, socktype, proto) - try: - set_cloexec(self.sock, True) - except NotImplementedError: - pass - self.sock.settimeout(timeout) - self.sock.connect(sa) - except OSError as ex: - e = ex - if self.sock is not None: - self.sock.close() - self.sock = None - # we may have depleted all our options - if i + 1 >= entries_num and n + 1 >= addr_types_num: - raise - else: - # hurray, we established connection - return + set_cloexec(self.sock, True) + except NotImplementedError: + pass + self.sock.settimeout(timeout) + self.sock.connect(sa) + except socket.error: + if self.sock: + self.sock.close() + self.sock = None + if i + 1 >= len(entries): + raise + else: + break def _init_socket(self, socket_settings, read_timeout, write_timeout): self.sock.settimeout(None) # set socket back to blocking mode diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py index 348b6c2..e9c7114 100644 --- a/t/unit/test_transport.py +++ b/t/unit/test_transport.py @@ -520,54 +520,20 @@ class test_AbstractTransport_connect: side_effect=(socket.error, None)): self.t.connect() - def test_connect_short_curcuit_on_INET_succeed(self): + def test_connect_calls_getaddrinfo_with_af_unspec(self): with patch('socket.socket', return_value=MockSocket()), \ - patch('socket.getaddrinfo', - side_effect=[ - [(socket.AF_INET, 1, socket.IPPROTO_TCP, - '', ('127.0.0.1', 5672))], - [(socket.AF_INET6, 1, socket.IPPROTO_TCP, - '', ('::1', 5672))] - ]) as getaddrinfo: + patch('socket.getaddrinfo') as getaddrinfo: self.t.sock = Mock() self.t.close() self.t.connect() getaddrinfo.assert_called_with( - 'localhost', 5672, socket.AF_INET, ANY, ANY) - - def test_connect_short_curcuit_on_INET_fails(self): - with patch('socket.socket', return_value=MockSocket()) as sock_mock, \ - patch('socket.getaddrinfo', - side_effect=[ - [(socket.AF_INET, 1, socket.IPPROTO_TCP, - '', ('127.0.0.1', 5672))], - [(socket.AF_INET6, 1, socket.IPPROTO_TCP, - '', ('::1', 5672))] - ]) as getaddrinfo: - self.t.sock = Mock() - self.t.close() - with patch.object(sock_mock.return_value, 'connect', - side_effect=(socket.error, None)): - self.t.connect() - getaddrinfo.assert_has_calls( - [call('localhost', 5672, addr_type, ANY, ANY) - for addr_type in (socket.AF_INET, socket.AF_INET6)]) + 'localhost', 5672, socket.AF_UNSPEC, ANY, ANY) def test_connect_getaddrinfo_raises_gaierror(self): with patch('socket.getaddrinfo', side_effect=socket.gaierror): with pytest.raises(socket.error): self.t.connect() - def test_connect_getaddrinfo_raises_gaierror_once_recovers(self): - with patch('socket.socket', return_value=MockSocket()), \ - patch('socket.getaddrinfo', - side_effect=[ - socket.gaierror, - [(socket.AF_INET6, 1, socket.IPPROTO_TCP, - '', ('::1', 5672))] - ]): - self.t.connect() - def test_connect_survives_not_implemented_set_cloexec(self): with patch('socket.socket', return_value=MockSocket()), \ patch('socket.getaddrinfo', |