diff options
author | tomc797 <34632752+tomc797@users.noreply.github.com> | 2018-10-11 23:08:38 -0700 |
---|---|---|
committer | Asif Saif Uddin <auvipy@gmail.com> | 2018-10-12 12:08:38 +0600 |
commit | 2e6d6e945dae5088b480428580821ad20f491e9b (patch) | |
tree | b6ec47ee97de8fce827ae52189bd74ef501eefc1 | |
parent | dc5b07c2320f932b2dd8342d0f4d578c095600cf (diff) | |
download | py-amqp-2e6d6e945dae5088b480428580821ad20f491e9b.tar.gz |
Issue#209: Connection.connect, Connection.close don't close socket on exception (#210)
* If there is an exception raised on Connection.connect or Connection.close,
ensure that the underlying transport socket is closed
* Added more informative message for exception
* Fix code style to satisfy travis checks
* Fix code style to satisfy travis checks
* Provided more informative exception message when remote unexpectedly hangs up
* Test the setting of SO_RCVTIMEO, SO_SNDTIMEO
* callback is soley invoked when connected
From the unit tests, I gleened that the callback should only be invoked
when Connection.connected is true
* Code style fixes
* Code style fixes
* Code style fixes
* Code fix, wrap line continuation differently
* Added three tests to improve coverage
Test to ensure connect is idempotent. Tests to confirm EOF behavior.
-rw-r--r-- | amqp/connection.py | 44 | ||||
-rw-r--r-- | amqp/transport.py | 65 | ||||
-rw-r--r-- | t/unit/test_connection.py | 21 | ||||
-rw-r--r-- | t/unit/test_transport.py | 60 |
4 files changed, 141 insertions, 49 deletions
diff --git a/amqp/connection.py b/amqp/connection.py index eb720a6..ea56a2d 100644 --- a/amqp/connection.py +++ b/amqp/connection.py @@ -294,18 +294,23 @@ class Connection(AbstractChannel): # if self.connected: return callback() if callback else None - self.transport = self.Transport( - self.host, self.connect_timeout, self.ssl, - self.read_timeout, self.write_timeout, - socket_settings=self.socket_settings, - ) - self.transport.connect() - self.on_inbound_frame = self.frame_handler_cls( - self, self.on_inbound_method) - self.frame_writer = self.frame_writer_cls(self, self.transport) - - while not self._handshake_complete: - self.drain_events(timeout=self.connect_timeout) + try: + self.transport = self.Transport( + self.host, self.connect_timeout, self.ssl, + self.read_timeout, self.write_timeout, + socket_settings=self.socket_settings, + ) + self.transport.connect() + self.on_inbound_frame = self.frame_handler_cls( + self, self.on_inbound_method) + self.frame_writer = self.frame_writer_cls(self, self.transport) + + while not self._handshake_complete: + self.drain_events(timeout=self.connect_timeout) + + except (OSError, IOError, SSLError): + self.collect() + raise def _warn_force_connect(self, attr): warnings.warn(AMQPDeprecationWarning( @@ -559,11 +564,16 @@ class Connection(AbstractChannel): # already closed return - return self.send_method( - spec.Connection.Close, argsig, - (reply_code, reply_text, method_sig[0], method_sig[1]), - wait=spec.Connection.CloseOk, - ) + try: + return self.send_method( + spec.Connection.Close, argsig, + (reply_code, reply_text, method_sig[0], method_sig[1]), + wait=spec.Connection.CloseOk, + ) + except (OSError, IOError, SSLError): + # close connection + self.collect() + raise def _on_close(self, reply_code, reply_text, class_id, method_id): """Request a connection close. diff --git a/amqp/transport.py b/amqp/transport.py index ca2a095..90dffd7 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -60,12 +60,10 @@ def to_host_port(host, default=AMQP_PORT): class _AbstractTransport(object): """Common superclass for TCP and SSL transports.""" - connected = False - def __init__(self, host, connect_timeout=None, read_timeout=None, write_timeout=None, socket_settings=None, raise_on_initial_eintr=True, **kwargs): - self.connected = True + self.connected = False self.sock = None self.raise_on_initial_eintr = raise_on_initial_eintr self._read_buffer = EMPTY_BUFFER @@ -76,10 +74,24 @@ class _AbstractTransport(object): self.socket_settings = socket_settings def connect(self): - self._connect(self.host, self.port, self.connect_timeout) - self._init_socket( - self.socket_settings, self.read_timeout, self.write_timeout, - ) + try: + # are we already connected? + if self.connected: + return + self._connect(self.host, self.port, self.connect_timeout) + self._init_socket( + self.socket_settings, self.read_timeout, self.write_timeout, + ) + # we've sent the banner; signal connect + # EINTR, EAGAIN, EWOULDBLOCK would signal that the banner + # has _not_ been sent + self.connected = True + except (OSError, IOError, SSLError): + # if not fully connected, close socket, and reraise error + if self.sock and not self.connected: + self.sock.close() + self.sock = None + raise @contextmanager def having_timeout(self, timeout): @@ -160,26 +172,21 @@ class _AbstractTransport(object): return def _init_socket(self, socket_settings, read_timeout, write_timeout): - try: - self.sock.settimeout(None) # set socket back to blocking mode - self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - self._set_socket_options(socket_settings) - - # set socket timeouts - for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout), - (socket.SO_RCVTIMEO, read_timeout)): - if interval is not None: - self.sock.setsockopt( - socket.SOL_SOCKET, timeout, - pack('ll', interval, 0), - ) - self._setup_transport() - - self._write(AMQP_PROTOCOL_HEADER) - except (OSError, IOError, socket.error) as exc: - if get_errno(exc) not in _UNAVAIL: - self.connected = False - raise + self.sock.settimeout(None) # set socket back to blocking mode + self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + self._set_socket_options(socket_settings) + + # set socket timeouts + for timeout, interval in ((socket.SO_SNDTIMEO, write_timeout), + (socket.SO_RCVTIMEO, read_timeout)): + if interval is not None: + self.sock.setsockopt( + socket.SOL_SOCKET, timeout, + pack('ll', interval, 0), + ) + self._setup_transport() + + self._write(AMQP_PROTOCOL_HEADER) def _get_tcp_socket_defaults(self, sock): tcp_opts = {} @@ -370,7 +377,7 @@ class SSLTransport(_AbstractTransport): continue raise if not s: - raise IOError('Socket closed') + raise IOError('Server unexpectedly closed connection') rbuf += s except: # noqa self._read_buffer = rbuf @@ -423,7 +430,7 @@ class TCPTransport(_AbstractTransport): continue raise if not s: - raise IOError('Socket closed') + raise IOError('Server unexpectedly closed connection') rbuf += s except: # noqa self._read_buffer = rbuf diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index a0de6d1..b34ba7b 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -4,7 +4,7 @@ import socket import warnings import pytest -from case import ContextMock, Mock, call +from case import ContextMock, Mock, call, patch from amqp import Connection, spec from amqp.connection import SSLError @@ -127,6 +127,14 @@ class test_Connection: assert self.conn.connect(callback) == callback.return_value callback.assert_called_with() + def test_connect__socket_error(self): + self.conn = Connection() + self.conn.collect = Mock(name='collect') + with patch('socket.socket', side_effect=socket.error): + with pytest.raises(socket.error): + self.conn.connect() + self.conn.collect.assert_called_with() + def test_on_start(self): self.conn._on_start(3, 4, {'foo': 'bar'}, b'x y z AMQPLAIN PLAIN', 'en_US en_GB') @@ -279,6 +287,7 @@ class test_Connection: for i, channel in items(channels): if i: channel.collect.assert_called_with() + assert self.conn._transport is None def test_collect__channel_raises_socket_error(self): self.conn.channels = self.conn.channels = {1: Mock(name='c1')} @@ -376,6 +385,7 @@ class test_Connection: ) def test_close(self): + self.conn.collect = Mock(name='collect') self.conn.close(reply_text='foo', method_sig=spec.Channel.Open) self.conn.send_method.assert_called_with( spec.Connection.Close, 'BsBB', @@ -387,6 +397,15 @@ class test_Connection: self.conn.transport = None self.conn.close() + def test_close__socket_error(self): + self.conn.send_method = Mock(name='send_method', + side_effect=socket.error) + self.conn.collect = Mock(name='collect') + with pytest.raises(socket.error): + self.conn.close() + self.conn.send_method.assert_called() + self.conn.collect.assert_called_with() + def test_on_close(self): self.conn._x_close_ok = Mock(name='_x_close_ok') with pytest.raises(NotFound): diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py index 3e2223a..9e4e27c 100644 --- a/t/unit/test_transport.py +++ b/t/unit/test_transport.py @@ -21,7 +21,10 @@ class MockSocket(object): self.sa = None def setsockopt(self, family, key, value): - if not isinstance(value, int): + if (family == socket.SOL_SOCKET and + key in (socket.SO_RCVTIMEO, socket.SO_SNDTIMEO)): + self.options[key] = value + elif not isinstance(value, int): raise socket.error() self.options[key] = value @@ -202,6 +205,18 @@ class test_socket_options: assert opts + def test_set_sockopt_opts_timeout(self): + # tests socket options SO_RCVTIMEO and SO_SNDTIMEO + # this test is soley for coverage as socket.settimeout + # is pythonic way to have timeouts + self.transp = transport.Transport( + self.host, self.connect_timeout, + ) + self.transp.read_timeout = 0xdead + self.transp.write_timeout = 0xbeef + with patch('socket.socket', return_value=MockSocket()): + self.transp.connect() + class test_AbstractTransport: @@ -242,8 +257,9 @@ class test_AbstractTransport: self.t.close() sock.shutdown.assert_called_with(socket.SHUT_RDWR) sock.close.assert_called_with() - assert self.t.sock is None + assert self.t.sock is None and self.t.connected is False self.t.close() + assert self.t.sock is None and self.t.connected is False def test_read_frame__timeout(self): self.t._read = Mock() @@ -299,6 +315,19 @@ class test_AbstractTransport: with pytest.raises(UnexpectedFrame): self.t.read_frame() + def transport_read_EOF(self): + for host, ssl in (('localhost:5672', False), + ('localhost:5671', True),): + self.t = transport.Transport(host, ssl) + self.t.sock = Mock(name='socket') + self.t.connected = True + self.t._quick_recv = Mock(name='recv', return_value='') + with pytest.raises( + IOError, + match=r'.*Server unexpectedly closed connection.*' + ): + self.t.read_frame() + def test_write__success(self): self.t._write = Mock() self.t.write('foo') @@ -350,6 +379,7 @@ class test_AbstractTransport_connect: with patch('socket.socket', side_effect=socket.error): with pytest.raises(socket.error): self.t.connect() + assert self.t.sock is None and self.t.connected is False def test_connect_socket_initialization_fails(self): with patch('socket.socket', side_effect=socket.error), \ @@ -362,6 +392,7 @@ class test_AbstractTransport_connect: ]): with pytest.raises(socket.error): self.t.connect() + assert self.t.sock is None and self.t.connected is False def test_connect_multiple_addr_entries_fails(self): with patch('socket.socket', return_value=MockSocket()) as sock_mock, \ @@ -452,6 +483,15 @@ class test_AbstractTransport_connect: self.t.connect() assert cloexec_mock.called + def test_connect_already_connected(self): + assert not self.t.connected + with patch('socket.socket', return_value=MockSocket()): + self.t.connect() + assert self.t.connected + sock_obj = self.t.sock + self.t.connect() + assert self.t.connected and self.t.sock is sock_obj + class test_SSLTransport: @@ -506,6 +546,14 @@ class test_SSLTransport: self.t._shutdown_transport() assert self.t.sock is sock.unwrap() + def test_read_EOF(self): + self.t.sock = Mock(name='SSLSocket') + self.t.connected = True + self.t._quick_recv = Mock(name='recv', return_value='') + with pytest.raises(IOError, + match=r'.*Server unexpectedly closed connection.*'): + self.t._read(64) + class test_TCPTransport: @@ -527,3 +575,11 @@ class test_TCPTransport: assert self.t._write is self.t.sock.sendall assert self.t._read_buffer is not None assert self.t._quick_recv is self.t.sock.recv + + def test_read_EOF(self): + self.t.sock = Mock(name='socket') + self.t.connected = True + self.t._quick_recv = Mock(name='recv', return_value='') + with pytest.raises(IOError, + match=r'.*Server unexpectedly closed connection.*'): + self.t._read(64) |