From d4c879f03e84e00bad9bd54cb8bf3c8b18bcd0f6 Mon Sep 17 00:00:00 2001 From: VinayGValsaraj Date: Fri, 10 Dec 2021 17:15:41 -0600 Subject: Adding two tests, for Connection.collect and Transport.close. --- amqp/connection.py | 32 ++++++++++++++------------------ amqp/transport.py | 11 +++++++---- t/unit/test_connection.py | 13 ++++++++++--- t/unit/test_transport.py | 11 +++++++++-- 4 files changed, 40 insertions(+), 27 deletions(-) diff --git a/amqp/connection.py b/amqp/connection.py index 5b3a4d1..9917ec7 100644 --- a/amqp/connection.py +++ b/amqp/connection.py @@ -466,24 +466,20 @@ class Connection(AbstractChannel): return self._transport and self._transport.connected def collect(self): - try: - if self._transport: - self._transport.close() - - if self.channels: - # Copy all the channels except self since the channels - # dictionary changes during the collection process. - channels = [ - ch for ch in self.channels.values() - if ch is not self - ] - - for ch in channels: - ch.collect() - except OSError: - pass # connection already closed on the other end - finally: - self._transport = self.connection = self.channels = None + if self._transport: + self._transport.close() + + if self.channels: + # Copy all the channels except self since the channels + # dictionary changes during the collection process. + channels = [ + ch for ch in self.channels.values() + if ch is not self + ] + + for ch in channels: + ch.collect() + self._transport = self.connection = self.channels = None def _get_free_channel_id(self): try: diff --git a/amqp/transport.py b/amqp/transport.py index 701c34c..177fb22 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -276,7 +276,10 @@ class _AbstractTransport: # Call shutdown first to make sure that pending messages # reach the AMQP broker if the program exits after # calling this method. - self.sock.shutdown(socket.SHUT_RDWR) + try: + self.sock.shutdown(socket.SHUT_RDWR) + except OSError: + pass self.sock.close() self.sock = None self.connected = False @@ -525,8 +528,8 @@ class SSLTransport(_AbstractTransport): context.load_verify_locations(ca_certs) if ciphers is not None: context.set_ciphers(ciphers) - # Set SNI headers if supported. - # Must set context.check_hostname before setting context.verify_mode + # Set SNI headers if supported. + # Must set context.check_hostname before setting context.verify_mode # to avoid setting context.verify_mode=ssl.CERT_NONE while # context.check_hostname is still True (the default value in context # if client-side) which results in the following exception: @@ -539,7 +542,7 @@ class SSLTransport(_AbstractTransport): except AttributeError: pass # ask forgiveness not permission - # See note above re: ordering for context.check_hostname and + # See note above re: ordering for context.check_hostname and # context.verify_mode assignments. if cert_reqs is not None: context.verify_mode = cert_reqs diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py index 21faebd..a2997e6 100644 --- a/t/unit/test_connection.py +++ b/t/unit/test_connection.py @@ -323,10 +323,17 @@ class test_Connection: channel.collect.assert_called_with() assert self.conn._transport is None - def test_collect__channel_raises_socket_error(self): - self.conn.channels = self.conn.channels = {1: Mock(name='c1')} - self.conn.channels[1].collect.side_effect = socket.error() + def test_collect__transport_socket_raises_os_error(self): + self.conn.transport = TCPTransport('localhost:5672') + sock = self.conn.transport.sock = Mock(name='sock') + channel = Mock(name='c1') + self.conn.channels = {1: channel} + sock.shutdown.side_effect = OSError self.conn.collect() + channel.collect.assert_called_with() + sock.close.assert_called_with() + assert self.conn._transport is None + assert self.conn.channels is None def test_collect_no_transport(self): self.conn = Connection() diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py index d93116a..b111497 100644 --- a/t/unit/test_transport.py +++ b/t/unit/test_transport.py @@ -282,6 +282,13 @@ class test_AbstractTransport: self.t.close() assert self.t.sock is None and self.t.connected is False + def test_close_os_error(self): + sock = self.t.sock = Mock() + sock.shutdown.side_effect = OSError + self.t.close() + sock.close.assert_called_with() + assert self.t.sock is None and self.t.connected is False + def test_read_frame__timeout(self): self.t._read = Mock() self.t._read.side_effect = socket.timeout() @@ -719,7 +726,7 @@ class test_SSLTransport: ) assert context.verify_mode == sentinel.CERT_REQS - # testing context creation inside _wrap_socket_sni() with parameter + # testing context creation inside _wrap_socket_sni() with parameter # cert_reqs == ssl.CERT_NONE. Previously raised ValueError because # code path attempted to set context.verify_mode=ssl.CERT_NONE before # setting context.check_hostname = False which raised a ValueError @@ -740,7 +747,7 @@ class test_SSLTransport: ) mock_load_default_certs.assert_not_called() mock_wrap_socket.assert_called_once() - + with patch('ssl.SSLContext.wrap_socket') as mock_wrap_socket: with patch('ssl.SSLContext.load_default_certs') as mock_load_default_certs: sock = Mock() -- cgit v1.2.1