summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMoisés Guimarães de Medeiros <moisesguimaraes@users.noreply.github.com>2020-12-22 00:05:09 +0100
committerGitHub <noreply@github.com>2020-12-22 00:05:09 +0100
commit343a00e828d9d2d33998ccaf96dca0b9417f04af (patch)
treec56549004993bc36c2dd92b52124cde52a19bbfc
parentc903aaeb581cc60bd1258df823a5a2a2a4ecdbd7 (diff)
downloadpy-amqp-343a00e828d9d2d33998ccaf96dca0b9417f04af.tar.gz
Fix _wrap_socket_sni (#347)
* Change the default value of ssl_version to None. When not set, the proper value between ssl.PROTOCOL_TLS_CLIENT and ssl.PROTOCOL_TLS_SERVER will be selected based on the param server_side in order to create a TLS Context object with better defaults that fit the desired connection side. * Change the default value of cert_reqs to None. The default value of ctx.verify_mode is ssl.CERT_NONE, but when ssl.PROTOCOL_TLS_CLIENT is used, ctx.verify_mode defaults to ssl.CERT_REQUIRED. * Fix context.check_hostname logic. Checking the hostname depends on having support of the SNI TLS extension and being provided with a server_hostname value. Another important thing to mention is that enabling hostname checking automatically sets verify_mode from ssl.CERT_NONE to ssl.CERT_REQUIRED in the stdlib ssl and it cannot be set back to ssl.CERT_NONE as long as hostname checking is enabled. * Refactor the SNI tests to test one thing at a time and removing some tests that were being repeated over and over. Signed-off-by: Moisés Guimarães de Medeiros <guimaraes@pm.me>
-rw-r--r--amqp/transport.py29
-rw-r--r--t/certs/ca_certificate.pem20
-rw-r--r--t/integration/test_rmq.py10
-rw-r--r--t/unit/test_transport.py141
4 files changed, 105 insertions, 95 deletions
diff --git a/amqp/transport.py b/amqp/transport.py
index 2a7c190..ec100e5 100644
--- a/amqp/transport.py
+++ b/amqp/transport.py
@@ -436,10 +436,10 @@ class SSLTransport(_AbstractTransport):
return ctx.wrap_socket(sock, **sslopts)
def _wrap_socket_sni(self, sock, keyfile=None, certfile=None,
- server_side=False, cert_reqs=ssl.CERT_NONE,
+ server_side=False, cert_reqs=None,
ca_certs=None, do_handshake_on_connect=False,
suppress_ragged_eofs=True, server_hostname=None,
- ciphers=None, ssl_version=ssl.PROTOCOL_TLS):
+ ciphers=None, ssl_version=None):
"""Socket wrap with SNI headers.
stdlib :attr:`ssl.SSLContext.wrap_socket` method augmented with support
@@ -510,20 +510,31 @@ class SSLTransport(_AbstractTransport):
'server_hostname': server_hostname,
}
+ if ssl_version is None:
+ ssl_version = (
+ ssl.PROTOCOL_TLS_SERVER
+ if server_side
+ else ssl.PROTOCOL_TLS_CLIENT
+ )
+
context = ssl.SSLContext(ssl_version)
+
if certfile is not None:
context.load_cert_chain(certfile, keyfile)
if ca_certs is not None:
context.load_verify_locations(ca_certs)
- if ciphers:
+ if ciphers is not None:
context.set_ciphers(ciphers)
- if cert_reqs != ssl.CERT_NONE:
- context.check_hostname = True
- # Set SNI headers if supported
- if (server_hostname is not None) and (
- hasattr(ssl, 'HAS_SNI') and ssl.HAS_SNI) and (
- hasattr(ssl, 'SSLContext')):
+ if cert_reqs is not None:
context.verify_mode = cert_reqs
+ # Set SNI headers if supported
+ try:
+ context.check_hostname = (
+ ssl.HAS_SNI and server_hostname is not None
+ )
+ except AttributeError:
+ pass # ask forgiveness not permission
+
sock = context.wrap_socket(**opts)
return sock
diff --git a/t/certs/ca_certificate.pem b/t/certs/ca_certificate.pem
new file mode 100644
index 0000000..009936d
--- /dev/null
+++ b/t/certs/ca_certificate.pem
@@ -0,0 +1,20 @@
+-----BEGIN CERTIFICATE-----
+MIIDRzCCAi+gAwIBAgIJAMa1mrcNQtapMA0GCSqGSIb3DQEBCwUAMDExIDAeBgNV
+BAMMF1RMU0dlblNlbGZTaWduZWR0Um9vdENBMQ0wCwYDVQQHDAQkJCQkMCAXDTIw
+MDEwMzEyMDE0MFoYDzIxMTkxMjEwMTIwMTQwWjAxMSAwHgYDVQQDDBdUTFNHZW5T
+ZWxmU2lnbmVkdFJvb3RDQTENMAsGA1UEBwwEJCQkJDCCASIwDQYJKoZIhvcNAQEB
+BQADggEPADCCAQoCggEBAKdmOg5vtuZ5vNZmceToiVBlcFg9Y/xKNyCPBij6Wm5p
+mXbnsjO1PhjGr97r2cMLq5QMvGt+FBEIjeeULtWVCBY7vMc4ATEZ1S2PmmKnOSXJ
+MLMDIutznopZkyqt3gqWgXZDxxHIlIzJl0HirQmfeLm6eTOYyFoyFZV3CE2IeW4Y
+n1zYhgZgIrU7Yo3I7wY9Js5yLk4p3etByN5tlLL2sdCOjRRXWGbOh/kb8uiyotEd
+cxNThk0RQDugoEzaGYBU3bzDhKkm4v/v/xp/JxGLDl/e3heRMUbcw9d/0ujflouy
+OQ66SNYGLWFQpmhtyHjalKzL5UbTcof4BQltoo/W7xECAwEAAaNgMF4wCwYDVR0P
+BAQDAgEGMB0GA1UdDgQWBBTKOnbaptqaUCAiwtnwLcRTcbuRejAfBgNVHSMEGDAW
+gBTKOnbaptqaUCAiwtnwLcRTcbuRejAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3
+DQEBCwUAA4IBAQB1tJUR9zoQ98bOz1es91PxgIt8VYR8/r6uIRtYWTBi7fgDRaaR
+Glm6ZqOSXNlkacB6kjUzIyKJwGWnD9zU/06CH+ME1U497SVVhvtUEbdJb1COU+/5
+KavEHVINfc3tHD5Z5LJR3okEILAzBYkEcjYUECzBNYVi4l6PBSMSC2+RBKGqHkY7
+ApmD5batRghH5YtadiyF4h6bba/XSUqxzFcLKjKSyyds4ndvA1/yfl/7CrRtiZf0
+jw1pFl33/PTOhgi66MHa4uaKlL/hIjIlh4kJgJajqCN+TVU4Q6JNmSuIsq6rksSw
+Rd5baBZrik2NHALr/ZN2Wy0nXiQJ3p+F20+X
+-----END CERTIFICATE-----
diff --git a/t/integration/test_rmq.py b/t/integration/test_rmq.py
index f6b26d1..746aa65 100644
--- a/t/integration/test_rmq.py
+++ b/t/integration/test_rmq.py
@@ -8,12 +8,15 @@ import amqp
def get_connection(
- hostname, port, vhost, use_tls=False, keyfile=None, certfile=None):
+ hostname, port, vhost, use_tls=False,
+ keyfile=None, certfile=None, ca_certs=None
+):
host = f'{hostname}:{port}'
if use_tls:
return amqp.Connection(host=host, vhost=vhost, ssl={
'keyfile': keyfile,
- 'certfile': certfile
+ 'certfile': certfile,
+ 'ca_certs': ca_certs,
}
)
else:
@@ -40,7 +43,8 @@ def connection(request):
).get("slaveid", None),
use_tls=True,
keyfile='t/certs/client_key.pem',
- certfile='t/certs/client_certificate.pem'
+ certfile='t/certs/client_certificate.pem',
+ ca_certs='t/certs/ca_certificate.pem',
)
diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py
index ad2750e..7f8d78c 100644
--- a/t/unit/test_transport.py
+++ b/t/unit/test_transport.py
@@ -640,108 +640,87 @@ class test_SSLTransport:
def test_wrap_socket_sni(self):
# testing default values of _wrap_socket_sni()
sock = Mock()
- with patch(
- 'ssl.SSLContext.wrap_socket',
- return_value=sentinel.WRAPPED_SOCKET) as mock_ssl_wrap:
+ with patch('ssl.SSLContext') as mock_ssl_context_class:
+ wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
+ wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET
ret = self.t._wrap_socket_sni(sock)
- mock_ssl_wrap.assert_called_with(sock=sock,
- server_side=False,
- do_handshake_on_connect=False,
- suppress_ragged_eofs=True,
- server_hostname=None)
-
+ mock_ssl_context_class.load_cert_chain.assert_not_called()
+ mock_ssl_context_class.load_verify_locations.assert_not_called()
+ mock_ssl_context_class.set_ciphers.assert_not_called()
+ mock_ssl_context_class.verify_mode.assert_not_called()
+ wrap_socket_method_mock.assert_called_with(
+ sock=sock,
+ server_side=False,
+ do_handshake_on_connect=False,
+ suppress_ragged_eofs=True,
+ server_hostname=None
+ )
assert ret == sentinel.WRAPPED_SOCKET
def test_wrap_socket_sni_certfile(self):
# testing _wrap_socket_sni() with parameters certfile and keyfile
- sock = Mock()
- with patch(
- 'ssl.SSLContext.wrap_socket',
- return_value=sentinel.WRAPPED_SOCKET
- ) as mock_ssl_wrap, patch(
- 'ssl.SSLContext.load_cert_chain'
- ) as mock_load_cert_chain:
- ret = self.t._wrap_socket_sni(
- sock, keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE)
-
- mock_load_cert_chain.assert_called_with(
- sentinel.CERTFILE, sentinel.KEYFILE)
- mock_ssl_wrap.assert_called_with(sock=sock,
- server_side=False,
- do_handshake_on_connect=False,
- suppress_ragged_eofs=True,
- server_hostname=None)
+ with patch('ssl.SSLContext') as mock_ssl_context_class:
+ load_cert_chain_method_mock = \
+ mock_ssl_context_class().load_cert_chain
+ self.t._wrap_socket_sni(
+ Mock(), keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE
+ )
- assert ret == sentinel.WRAPPED_SOCKET
+ load_cert_chain_method_mock.assert_called_with(
+ sentinel.CERTFILE, sentinel.KEYFILE
+ )
def test_wrap_socket_ca_certs(self):
# testing _wrap_socket_sni() with parameter ca_certs
- sock = Mock()
- with patch(
- 'ssl.SSLContext.wrap_socket',
- return_value=sentinel.WRAPPED_SOCKET
- ) as mock_ssl_wrap, patch(
- 'ssl.SSLContext.load_verify_locations'
- ) as mock_load_verify_locations:
- ret = self.t._wrap_socket_sni(sock, ca_certs=sentinel.CA_CERTS)
-
- mock_load_verify_locations.assert_called_with(sentinel.CA_CERTS)
- mock_ssl_wrap.assert_called_with(sock=sock,
- server_side=False,
- do_handshake_on_connect=False,
- suppress_ragged_eofs=True,
- server_hostname=None)
+ with patch('ssl.SSLContext') as mock_ssl_context_class:
+ load_verify_locations_method_mock = \
+ mock_ssl_context_class().load_verify_locations
+ self.t._wrap_socket_sni(Mock(), ca_certs=sentinel.CA_CERTS)
- assert ret == sentinel.WRAPPED_SOCKET
+ load_verify_locations_method_mock.assert_called_with(sentinel.CA_CERTS)
def test_wrap_socket_ciphers(self):
# testing _wrap_socket_sni() with parameter ciphers
- sock = Mock()
- with patch(
- 'ssl.SSLContext.wrap_socket',
- return_value=sentinel.WRAPPED_SOCKET) as mock_ssl_wrap, \
- patch('ssl.SSLContext.set_ciphers') as mock_set_ciphers:
- ret = self.t._wrap_socket_sni(sock, ciphers=sentinel.CIPHERS)
-
- mock_set_ciphers.assert_called_with(sentinel.CIPHERS)
- mock_ssl_wrap.assert_called_with(sock=sock,
- server_side=False,
- do_handshake_on_connect=False,
- suppress_ragged_eofs=True,
- server_hostname=None)
- assert ret == sentinel.WRAPPED_SOCKET
+ with patch('ssl.SSLContext') as mock_ssl_context_class:
+ set_ciphers_method_mock = mock_ssl_context_class().set_ciphers
+ self.t._wrap_socket_sni(Mock(), ciphers=sentinel.CIPHERS)
+
+ set_ciphers_method_mock.assert_called_with(sentinel.CIPHERS)
def test_wrap_socket_sni_cert_reqs(self):
# testing _wrap_socket_sni() with parameter cert_reqs
- sock = Mock()
with patch('ssl.SSLContext') as mock_ssl_context_class:
- wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
- wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET
- ret = self.t._wrap_socket_sni(sock, cert_reqs=sentinel.CERT_REQS)
+ self.t._wrap_socket_sni(Mock(), cert_reqs=sentinel.CERT_REQS)
- wrap_socket_method_mock.assert_called_with(
- sock=sock,
- server_side=False,
- do_handshake_on_connect=False,
- suppress_ragged_eofs=True,
- server_hostname=None
- )
- assert mock_ssl_context_class().check_hostname is True
- assert ret == sentinel.WRAPPED_SOCKET
+ assert mock_ssl_context_class().verify_mode == sentinel.CERT_REQS
def test_wrap_socket_sni_setting_sni_header(self):
- # testing _wrap_socket_sni() with setting SNI header
+ # testing _wrap_socket_sni() without parameter server_hostname
+ # SSL module supports SNI
+ with patch('ssl.SSLContext') as mock_ssl_context_class, \
+ patch('ssl.HAS_SNI', new=True):
+ self.t._wrap_socket_sni(Mock())
+
+ assert mock_ssl_context_class().check_hostname is False
+
+ # SSL module does not support SNI
+ with patch('ssl.SSLContext') as mock_ssl_context_class, \
+ patch('ssl.HAS_SNI', new=False):
+ self.t._wrap_socket_sni(Mock())
+
+ assert mock_ssl_context_class().check_hostname is False
+
+ # testing _wrap_socket_sni() with parameter server_hostname
sock = Mock()
with patch('ssl.SSLContext') as mock_ssl_context_class, \
patch('ssl.HAS_SNI', new=True):
# SSL module supports SNI
wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
- wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET
- ret = self.t._wrap_socket_sni(
- sock, cert_reqs=sentinel.CERT_REQS,
- server_hostname=sentinel.SERVER_HOSTNAME
+ self.t._wrap_socket_sni(
+ sock, server_hostname=sentinel.SERVER_HOSTNAME
)
+
wrap_socket_method_mock.assert_called_with(
sock=sock,
server_side=False,
@@ -749,17 +728,14 @@ class test_SSLTransport:
suppress_ragged_eofs=True,
server_hostname=sentinel.SERVER_HOSTNAME
)
- assert mock_ssl_context_class().verify_mode == sentinel.CERT_REQS
- assert ret == sentinel.WRAPPED_SOCKET
+ assert mock_ssl_context_class().check_hostname is True
with patch('ssl.SSLContext') as mock_ssl_context_class, \
patch('ssl.HAS_SNI', new=False):
# SSL module does not support SNI
wrap_socket_method_mock = mock_ssl_context_class().wrap_socket
- wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET
- ret = self.t._wrap_socket_sni(
- sock, cert_reqs=sentinel.CERT_REQS,
- server_hostname=sentinel.SERVER_HOSTNAME
+ self.t._wrap_socket_sni(
+ sock, server_hostname=sentinel.SERVER_HOSTNAME
)
wrap_socket_method_mock.assert_called_with(
sock=sock,
@@ -768,8 +744,7 @@ class test_SSLTransport:
suppress_ragged_eofs=True,
server_hostname=sentinel.SERVER_HOSTNAME
)
- assert mock_ssl_context_class().verify_mode != sentinel.CERT_REQS
- assert ret == sentinel.WRAPPED_SOCKET
+ assert mock_ssl_context_class().check_hostname is False
def test_shutdown_transport(self):
self.t.sock = None