summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorVinayGValsaraj <vinaygvalsaraj@gmail.com>2021-12-10 17:15:41 -0600
committerAsif Saif Uddin <auvipy@gmail.com>2021-12-12 10:42:09 +0600
commitd4c879f03e84e00bad9bd54cb8bf3c8b18bcd0f6 (patch)
tree0f367981e3a2b555b2a97c5e1e124aacc0508ed8
parent1cf468ce3ff0da2dc835daff269110032c16310c (diff)
downloadpy-amqp-d4c879f03e84e00bad9bd54cb8bf3c8b18bcd0f6.tar.gz
Adding two tests, for Connection.collect and Transport.close.
-rw-r--r--amqp/connection.py32
-rw-r--r--amqp/transport.py11
-rw-r--r--t/unit/test_connection.py13
-rw-r--r--t/unit/test_transport.py11
4 files changed, 40 insertions, 27 deletions
diff --git a/amqp/connection.py b/amqp/connection.py
index 5b3a4d1..9917ec7 100644
--- a/amqp/connection.py
+++ b/amqp/connection.py
@@ -466,24 +466,20 @@ class Connection(AbstractChannel):
return self._transport and self._transport.connected
def collect(self):
- try:
- if self._transport:
- self._transport.close()
-
- if self.channels:
- # Copy all the channels except self since the channels
- # dictionary changes during the collection process.
- channels = [
- ch for ch in self.channels.values()
- if ch is not self
- ]
-
- for ch in channels:
- ch.collect()
- except OSError:
- pass # connection already closed on the other end
- finally:
- self._transport = self.connection = self.channels = None
+ if self._transport:
+ self._transport.close()
+
+ if self.channels:
+ # Copy all the channels except self since the channels
+ # dictionary changes during the collection process.
+ channels = [
+ ch for ch in self.channels.values()
+ if ch is not self
+ ]
+
+ for ch in channels:
+ ch.collect()
+ self._transport = self.connection = self.channels = None
def _get_free_channel_id(self):
try:
diff --git a/amqp/transport.py b/amqp/transport.py
index 701c34c..177fb22 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -276,7 +276,10 @@ class _AbstractTransport:
# Call shutdown first to make sure that pending messages
# reach the AMQP broker if the program exits after
# calling this method.
- self.sock.shutdown(socket.SHUT_RDWR)
+ try:
+ self.sock.shutdown(socket.SHUT_RDWR)
+ except OSError:
+ pass
self.sock.close()
self.sock = None
self.connected = False
@@ -525,8 +528,8 @@ class SSLTransport(_AbstractTransport):
context.load_verify_locations(ca_certs)
if ciphers is not None:
context.set_ciphers(ciphers)
- # Set SNI headers if supported.
- # Must set context.check_hostname before setting context.verify_mode
+ # Set SNI headers if supported.
+ # Must set context.check_hostname before setting context.verify_mode
# to avoid setting context.verify_mode=ssl.CERT_NONE while
# context.check_hostname is still True (the default value in context
# if client-side) which results in the following exception:
@@ -539,7 +542,7 @@ class SSLTransport(_AbstractTransport):
except AttributeError:
pass # ask forgiveness not permission
- # See note above re: ordering for context.check_hostname and
+ # See note above re: ordering for context.check_hostname and
# context.verify_mode assignments.
if cert_reqs is not None:
context.verify_mode = cert_reqs
diff --git a/t/unit/test_connection.py b/t/unit/test_connection.py
index 21faebd..a2997e6 100644
--- a/t/unit/test_connection.py
+++ b/t/unit/test_connection.py
@@ -323,10 +323,17 @@ class test_Connection:
channel.collect.assert_called_with()
assert self.conn._transport is None
- def test_collect__channel_raises_socket_error(self):
- self.conn.channels = self.conn.channels = {1: Mock(name='c1')}
- self.conn.channels[1].collect.side_effect = socket.error()
+ def test_collect__transport_socket_raises_os_error(self):
+ self.conn.transport = TCPTransport('localhost:5672')
+ sock = self.conn.transport.sock = Mock(name='sock')
+ channel = Mock(name='c1')
+ self.conn.channels = {1: channel}
+ sock.shutdown.side_effect = OSError
self.conn.collect()
+ channel.collect.assert_called_with()
+ sock.close.assert_called_with()
+ assert self.conn._transport is None
+ assert self.conn.channels is None
def test_collect_no_transport(self):
self.conn = Connection()
diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py
index d93116a..b111497 100644
--- a/t/unit/test_transport.py
+++ b/t/unit/test_transport.py
@@ -282,6 +282,13 @@ class test_AbstractTransport:
self.t.close()
assert self.t.sock is None and self.t.connected is False
+ def test_close_os_error(self):
+ sock = self.t.sock = Mock()
+ sock.shutdown.side_effect = OSError
+ self.t.close()
+ sock.close.assert_called_with()
+ assert self.t.sock is None and self.t.connected is False
+
def test_read_frame__timeout(self):
self.t._read = Mock()
self.t._read.side_effect = socket.timeout()
@@ -719,7 +726,7 @@ class test_SSLTransport:
)
assert context.verify_mode == sentinel.CERT_REQS
- # testing context creation inside _wrap_socket_sni() with parameter
+ # testing context creation inside _wrap_socket_sni() with parameter
# cert_reqs == ssl.CERT_NONE. Previously raised ValueError because
# code path attempted to set context.verify_mode=ssl.CERT_NONE before
# setting context.check_hostname = False which raised a ValueError
@@ -740,7 +747,7 @@ class test_SSLTransport:
)
mock_load_default_certs.assert_not_called()
mock_wrap_socket.assert_called_once()
-
+
with patch('ssl.SSLContext.wrap_socket') as mock_wrap_socket:
with patch('ssl.SSLContext.load_default_certs') as mock_load_default_certs:
sock = Mock()