summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--amqp/transport.py70
-rw-r--r--t/unit/test_transport.py40
2 files changed, 22 insertions, 88 deletions
diff --git a/amqp/transport.py b/amqp/transport.py
index b5a0d4b..2761f09 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -169,59 +169,27 @@ class _AbstractTransport:
sock.settimeout(prev)
def _connect(self, host, port, timeout):
- e = None
-
- # Below we are trying to avoid additional DNS requests for AAAA if A
- # succeeds. This helps a lot in case when a hostname has an IPv4 entry
- # in /etc/hosts but not IPv6. Without the (arguably somewhat twisted)
- # logic below, getaddrinfo would attempt to resolve the hostname for
- # both IP versions, which would make the resolver talk to configured
- # DNS servers. If those servers are for some reason not available
- # during resolution attempt (either because of system misconfiguration,
- # or network connectivity problem), resolution process locks the
- # _connect call for extended time.
- addr_types = (socket.AF_INET, socket.AF_INET6)
- addr_types_num = len(addr_types)
- for n, family in enumerate(addr_types):
- # first, resolve the address for a single address family
+ entries = socket.getaddrinfo(
+ host, port, socket.AF_UNSPEC, socket.SOCK_STREAM, SOL_TCP,
+ )
+ for i, res in enumerate(entries):
+ af, socktype, proto, canonname, sa = res
try:
- entries = socket.getaddrinfo(
- host, port, family, socket.SOCK_STREAM, SOL_TCP)
- entries_num = len(entries)
- except socket.gaierror:
- # we may have depleted all our options
- if n + 1 >= addr_types_num:
- # if getaddrinfo succeeded before for another address
- # family, reraise the previous socket.error since it's more
- # relevant to users
- raise (e
- if e is not None
- else socket.error(
- "failed to resolve broker hostname"))
- continue # pragma: no cover
-
- # now that we have address(es) for the hostname, connect to broker
- for i, res in enumerate(entries):
- af, socktype, proto, _, sa = res
+ self.sock = socket.socket(af, socktype, proto)
try:
- self.sock = socket.socket(af, socktype, proto)
- try:
- set_cloexec(self.sock, True)
- except NotImplementedError:
- pass
- self.sock.settimeout(timeout)
- self.sock.connect(sa)
- except OSError as ex:
- e = ex
- if self.sock is not None:
- self.sock.close()
- self.sock = None
- # we may have depleted all our options
- if i + 1 >= entries_num and n + 1 >= addr_types_num:
- raise
- else:
- # hurray, we established connection
- return
+ set_cloexec(self.sock, True)
+ except NotImplementedError:
+ pass
+ self.sock.settimeout(timeout)
+ self.sock.connect(sa)
+ except socket.error:
+ if self.sock:
+ self.sock.close()
+ self.sock = None
+ if i + 1 >= len(entries):
+ raise
+ else:
+ break
def _init_socket(self, socket_settings, read_timeout, write_timeout):
self.sock.settimeout(None) # set socket back to blocking mode
diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py
index 348b6c2..e9c7114 100644
--- a/t/unit/test_transport.py
+++ b/t/unit/test_transport.py
@@ -520,54 +520,20 @@ class test_AbstractTransport_connect:
side_effect=(socket.error, None)):
self.t.connect()
- def test_connect_short_curcuit_on_INET_succeed(self):
+ def test_connect_calls_getaddrinfo_with_af_unspec(self):
with patch('socket.socket', return_value=MockSocket()), \
- patch('socket.getaddrinfo',
- side_effect=[
- [(socket.AF_INET, 1, socket.IPPROTO_TCP,
- '', ('127.0.0.1', 5672))],
- [(socket.AF_INET6, 1, socket.IPPROTO_TCP,
- '', ('::1', 5672))]
- ]) as getaddrinfo:
+ patch('socket.getaddrinfo') as getaddrinfo:
self.t.sock = Mock()
self.t.close()
self.t.connect()
getaddrinfo.assert_called_with(
- 'localhost', 5672, socket.AF_INET, ANY, ANY)
-
- def test_connect_short_curcuit_on_INET_fails(self):
- with patch('socket.socket', return_value=MockSocket()) as sock_mock, \
- patch('socket.getaddrinfo',
- side_effect=[
- [(socket.AF_INET, 1, socket.IPPROTO_TCP,
- '', ('127.0.0.1', 5672))],
- [(socket.AF_INET6, 1, socket.IPPROTO_TCP,
- '', ('::1', 5672))]
- ]) as getaddrinfo:
- self.t.sock = Mock()
- self.t.close()
- with patch.object(sock_mock.return_value, 'connect',
- side_effect=(socket.error, None)):
- self.t.connect()
- getaddrinfo.assert_has_calls(
- [call('localhost', 5672, addr_type, ANY, ANY)
- for addr_type in (socket.AF_INET, socket.AF_INET6)])
+ 'localhost', 5672, socket.AF_UNSPEC, ANY, ANY)
def test_connect_getaddrinfo_raises_gaierror(self):
with patch('socket.getaddrinfo', side_effect=socket.gaierror):
with pytest.raises(socket.error):
self.t.connect()
- def test_connect_getaddrinfo_raises_gaierror_once_recovers(self):
- with patch('socket.socket', return_value=MockSocket()), \
- patch('socket.getaddrinfo',
- side_effect=[
- socket.gaierror,
- [(socket.AF_INET6, 1, socket.IPPROTO_TCP,
- '', ('::1', 5672))]
- ]):
- self.t.connect()
-
def test_connect_survives_not_implemented_set_cloexec(self):
with patch('socket.socket', return_value=MockSocket()), \
patch('socket.getaddrinfo',