summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authortomc797 <34632752+tomc797@users.noreply.github.com>2018-10-11 23:08:38 -0700
committerAsif Saif Uddin <auvipy@gmail.com>2018-10-12 12:08:38 +0600
commit2e6d6e945dae5088b480428580821ad20f491e9b (patch)
treeb6ec47ee97de8fce827ae52189bd74ef501eefc1
parentdc5b07c2320f932b2dd8342d0f4d578c095600cf (diff)
downloadpy-amqp-2e6d6e945dae5088b480428580821ad20f491e9b.tar.gz
Issue#209: Connection.connect, Connection.close don't close socket on exception (#210)
* If there is an exception raised on Connection.connect or Connection.close, ensure that the underlying transport socket is closed * Added more informative message for exception * Fix code style to satisfy travis checks * Fix code style to satisfy travis checks * Provided more informative exception message when remote unexpectedly hangs up * Test the setting of SO_RCVTIMEO, SO_SNDTIMEO * callback is soley invoked when connected From the unit tests, I gleened that the callback should only be invoked when Connection.connected is true * Code style fixes * Code style fixes * Code style fixes * Code fix, wrap line continuation differently * Added three tests to improve coverage Test to ensure connect is idempotent. Tests to confirm EOF behavior.
-rw-r--r--amqp/connection.py44
-rw-r--r--amqp/transport.py65
-rw-r--r--t/unit/test_connection.py21
-rw-r--r--t/unit/test_transport.py60
4 files changed, 141 insertions, 49 deletions
diff --git a/amqp/connection.py b/amqp/connection.py
index eb720a6..ea56a2d 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -294,18 +294,23 @@ class Connection(AbstractChannel):
#
if self.connected:
return callback() if callback else None
- self.transport = self.Transport(
- self.host, self.connect_timeout, self.ssl,
- self.read_timeout, self.write_timeout,
- socket_settings=self.socket_settings,
- )
- self.transport.connect()
- self.on_inbound_frame = self.frame_handler_cls(
- self, self.on_inbound_method)
- self.frame_writer = self.frame_writer_cls(self, self.transport)
-
- while not self._handshake_complete:
- self.drain_events(timeout=self.connect_timeout)
+ try:
+ self.transport = self.Transport(
+ self.host, self.connect_timeout, self.ssl,
+ self.read_timeout, self.write_timeout,
+ socket_settings=self.socket_settings,
+ )
+ self.transport.connect()
+ self.on_inbound_frame = self.frame_handler_cls(
+ self, self.on_inbound_method)
+ self.frame_writer = self.frame_writer_cls(self, self.transport)
+
+ while not self._handshake_complete:
+ self.drain_events(timeout=self.connect_timeout)
+
+ except (OSError, IOError, SSLError):
+ self.collect()
+ raise
def _warn_force_connect(self, attr):
warnings.warn(AMQPDeprecationWarning(
@@ -559,11 +564,16 @@ class Connection(AbstractChannel):
# already closed
return
- return self.send_method(
- spec.Connection.Close, argsig,
- (reply_code, reply_text, method_sig[0], method_sig[1]),
- wait=spec.Connection.CloseOk,
- )
+ try:
+ return self.send_method(
+ spec.Connection.Close, argsig,
+ (reply_code, reply_text, method_sig[0], method_sig[1]),
+ wait=spec.Connection.CloseOk,
+ )
+ except (OSError, IOError, SSLError):
+ # close connection
+ self.collect()
+ raise
def _on_close(self, reply_code, reply_text, class_id, method_id):
"""Request a connection close.
diff --git a/amqp/transport.py b/amqp/transport.py
index ca2a095..90dffd7 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -60,12 +60,10 @@ def to_host_port(host, default=AMQP_PORT):
class _AbstractTransport(object):
"""Common superclass for TCP and SSL transports."""
- connected = False
-
def __init__(self, host, connect_timeout=None,
read_timeout=None, write_timeout=None,
socket_settings=None, raise_on_initial_eintr=True, **kwargs):
- self.connected = True
+ self.connected = False
self.sock = None
self.raise_on_initial_eintr = raise_on_initial_eintr
self._read_buffer = EMPTY_BUFFER
@@ -76,10 +74,24 @@ class _AbstractTransport(object):
self.socket_settings = socket_settings
def connect(self):
- self._connect(self.host, self.port, self.connect_timeout)
- self._init_socket(
- self.socket_settings, self.read_timeout, self.write_timeout,
- )
+ try:
+ # are we already connected?
+ if self.connected:
+ return
+ self._connect(self.host, self.port, self.connect_timeout)
+ self._init_socket(
+ self.socket_settings, self.read_timeout, self.write_timeout,
+ )
+ # we've sent the banner; signal connect
+ # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner
+ # has _not_ been sent
+ self.connected = True
+ except (OSError, IOError, SSLError):
+ # if not fully connected, close socket, and reraise error
+ if self.sock and not self.connected:
+ self.sock.close()
+ self.sock = None
+ raise
@contextmanager
def having_timeout(self, timeout):
@@ -160,26 +172,21 @@ class _AbstractTransport(object):
return
def _init_socket(self, socket_settings, read_timeout, write_timeout):
- try:
- self.sock.settimeout(None) # set socket back to blocking mode
- self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
- self._set_socket_options(socket_settings)
-
- # set socket timeouts
- for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout),
- (socket.SO_RCVTIMEO, read_timeout)):
- if interval is not None:
- self.sock.setsockopt(
- socket.SOL_SOCKET, timeout,
- pack('ll', interval, 0),
- )
- self._setup_transport()
-
- self._write(AMQP_PROTOCOL_HEADER)
- except (OSError, IOError, socket.error) as exc:
- if get_errno(exc) not in _UNAVAIL:
- self.connected = False
- raise
+ self.sock.settimeout(None) # set socket back to blocking mode
+ self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
+ self._set_socket_options(socket_settings)
+
+ # set socket timeouts
+ for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout),
+ (socket.SO_RCVTIMEO, read_timeout)):
+ if interval is not None:
+ self.sock.setsockopt(
+ socket.SOL_SOCKET, timeout,
+ pack('ll', interval, 0),
+ )
+ self._setup_transport()
+
+ self._write(AMQP_PROTOCOL_HEADER)
def _get_tcp_socket_defaults(self, sock):
tcp_opts = {}
@@ -370,7 +377,7 @@ class SSLTransport(_AbstractTransport):
continue
raise
if not s:
- raise IOError('Socket closed')
+ raise IOError('Server unexpectedly closed connection')
rbuf += s
except: # noqa
self._read_buffer = rbuf
@@ -423,7 +430,7 @@ class TCPTransport(_AbstractTransport):
continue
raise
if not s:
- raise IOError('Socket closed')
+ raise IOError('Server unexpectedly closed connection')
rbuf += s
except: # noqa
self._read_buffer = rbuf
diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py
index a0de6d1..b34ba7b 100644
--- a/t/unit/test_connection.py
+++ b/t/unit/test_connection.py
@@ -4,7 +4,7 @@ import socket
import warnings
import pytest
-from case import ContextMock, Mock, call
+from case import ContextMock, Mock, call, patch
from amqp import Connection, spec
from amqp.connection import SSLError
@@ -127,6 +127,14 @@ class test_Connection:
assert self.conn.connect(callback) == callback.return_value
callback.assert_called_with()
+ def test_connect__socket_error(self):
+ self.conn = Connection()
+ self.conn.collect = Mock(name='collect')
+ with patch('socket.socket', side_effect=socket.error):
+ with pytest.raises(socket.error):
+ self.conn.connect()
+ self.conn.collect.assert_called_with()
+
def test_on_start(self):
self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN',
'en_US en_GB')
@@ -279,6 +287,7 @@ class test_Connection:
for i, channel in items(channels):
if i:
channel.collect.assert_called_with()
+ assert self.conn._transport is None
def test_collect__channel_raises_socket_error(self):
self.conn.channels = self.conn.channels = {1: Mock(name='c1')}
@@ -376,6 +385,7 @@ class test_Connection:
)
def test_close(self):
+ self.conn.collect = Mock(name='collect')
self.conn.close(reply_text='foo', method_sig=spec.Channel.Open)
self.conn.send_method.assert_called_with(
spec.Connection.Close, 'BsBB',
@@ -387,6 +397,15 @@ class test_Connection:
self.conn.transport = None
self.conn.close()
+ def test_close__socket_error(self):
+ self.conn.send_method = Mock(name='send_method',
+ side_effect=socket.error)
+ self.conn.collect = Mock(name='collect')
+ with pytest.raises(socket.error):
+ self.conn.close()
+ self.conn.send_method.assert_called()
+ self.conn.collect.assert_called_with()
+
def test_on_close(self):
self.conn._x_close_ok = Mock(name='_x_close_ok')
with pytest.raises(NotFound):
diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py
index 3e2223a..9e4e27c 100644
--- a/t/unit/test_transport.py
+++ b/t/unit/test_transport.py
@@ -21,7 +21,10 @@ class MockSocket(object):
self.sa = None
def setsockopt(self, family, key, value):
- if not isinstance(value, int):
+ if (family == socket.SOL_SOCKET and
+ key in (socket.SO_RCVTIMEO, socket.SO_SNDTIMEO)):
+ self.options[key] = value
+ elif not isinstance(value, int):
raise socket.error()
self.options[key] = value
@@ -202,6 +205,18 @@ class test_socket_options:
assert opts
+ def test_set_sockopt_opts_timeout(self):
+ # tests socket options SO_RCVTIMEO and SO_SNDTIMEO
+ # this test is soley for coverage as socket.settimeout
+ # is pythonic way to have timeouts
+ self.transp = transport.Transport(
+ self.host, self.connect_timeout,
+ )
+ self.transp.read_timeout = 0xdead
+ self.transp.write_timeout = 0xbeef
+ with patch('socket.socket', return_value=MockSocket()):
+ self.transp.connect()
+
class test_AbstractTransport:
@@ -242,8 +257,9 @@ class test_AbstractTransport:
self.t.close()
sock.shutdown.assert_called_with(socket.SHUT_RDWR)
sock.close.assert_called_with()
- assert self.t.sock is None
+ assert self.t.sock is None and self.t.connected is False
self.t.close()
+ assert self.t.sock is None and self.t.connected is False
def test_read_frame__timeout(self):
self.t._read = Mock()
@@ -299,6 +315,19 @@ class test_AbstractTransport:
with pytest.raises(UnexpectedFrame):
self.t.read_frame()
+ def transport_read_EOF(self):
+ for host, ssl in (('localhost:5672', False),
+ ('localhost:5671', True),):
+ self.t = transport.Transport(host, ssl)
+ self.t.sock = Mock(name='socket')
+ self.t.connected = True
+ self.t._quick_recv = Mock(name='recv', return_value='')
+ with pytest.raises(
+ IOError,
+ match=r'.*Server unexpectedly closed connection.*'
+ ):
+ self.t.read_frame()
+
def test_write__success(self):
self.t._write = Mock()
self.t.write('foo')
@@ -350,6 +379,7 @@ class test_AbstractTransport_connect:
with patch('socket.socket', side_effect=socket.error):
with pytest.raises(socket.error):
self.t.connect()
+ assert self.t.sock is None and self.t.connected is False
def test_connect_socket_initialization_fails(self):
with patch('socket.socket', side_effect=socket.error), \
@@ -362,6 +392,7 @@ class test_AbstractTransport_connect:
]):
with pytest.raises(socket.error):
self.t.connect()
+ assert self.t.sock is None and self.t.connected is False
def test_connect_multiple_addr_entries_fails(self):
with patch('socket.socket', return_value=MockSocket()) as sock_mock, \
@@ -452,6 +483,15 @@ class test_AbstractTransport_connect:
self.t.connect()
assert cloexec_mock.called
+ def test_connect_already_connected(self):
+ assert not self.t.connected
+ with patch('socket.socket', return_value=MockSocket()):
+ self.t.connect()
+ assert self.t.connected
+ sock_obj = self.t.sock
+ self.t.connect()
+ assert self.t.connected and self.t.sock is sock_obj
+
class test_SSLTransport:
@@ -506,6 +546,14 @@ class test_SSLTransport:
self.t._shutdown_transport()
assert self.t.sock is sock.unwrap()
+ def test_read_EOF(self):
+ self.t.sock = Mock(name='SSLSocket')
+ self.t.connected = True
+ self.t._quick_recv = Mock(name='recv', return_value='')
+ with pytest.raises(IOError,
+ match=r'.*Server unexpectedly closed connection.*'):
+ self.t._read(64)
+
class test_TCPTransport:
@@ -527,3 +575,11 @@ class test_TCPTransport:
assert self.t._write is self.t.sock.sendall
assert self.t._read_buffer is not None
assert self.t._quick_recv is self.t.sock.recv
+
+ def test_read_EOF(self):
+ self.t.sock = Mock(name='socket')
+ self.t.connected = True
+ self.t._quick_recv = Mock(name='recv', return_value='')
+ with pytest.raises(IOError,
+ match=r'.*Server unexpectedly closed connection.*'):
+ self.t._read(64)