diff options
author | Moisés Guimarães de Medeiros <moisesguimaraes@users.noreply.github.com> | 2020-12-22 00:05:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-12-22 00:05:09 +0100 |
commit | 343a00e828d9d2d33998ccaf96dca0b9417f04af (patch) | |
tree | c56549004993bc36c2dd92b52124cde52a19bbfc | |
parent | c903aaeb581cc60bd1258df823a5a2a2a4ecdbd7 (diff) | |
download | py-amqp-343a00e828d9d2d33998ccaf96dca0b9417f04af.tar.gz |
Fix _wrap_socket_sni (#347)
* Change the default value of ssl_version to None. When not set, the
proper value between ssl.PROTOCOL_TLS_CLIENT and ssl.PROTOCOL_TLS_SERVER
will be selected based on the param server_side in order to create
a TLS Context object with better defaults that fit the desired
connection side.
* Change the default value of cert_reqs to None. The default value
of ctx.verify_mode is ssl.CERT_NONE, but when ssl.PROTOCOL_TLS_CLIENT
is used, ctx.verify_mode defaults to ssl.CERT_REQUIRED.
* Fix context.check_hostname logic. Checking the hostname depends on
having support of the SNI TLS extension and being provided with a
server_hostname value. Another important thing to mention is that
enabling hostname checking automatically sets verify_mode from
ssl.CERT_NONE to ssl.CERT_REQUIRED in the stdlib ssl and it cannot
be set back to ssl.CERT_NONE as long as hostname checking is enabled.
* Refactor the SNI tests to test one thing at a time and removing some
tests that were being repeated over and over.
Signed-off-by: Moisés Guimarães de Medeiros <guimaraes@pm.me>
-rw-r--r-- | amqp/transport.py | 29 | ||||
-rw-r--r-- | t/certs/ca_certificate.pem | 20 | ||||
-rw-r--r-- | t/integration/test_rmq.py | 10 | ||||
-rw-r--r-- | t/unit/test_transport.py | 141 |
4 files changed, 105 insertions, 95 deletions
diff --git a/amqp/transport.py b/amqp/transport.py index 2a7c190..ec100e5 100644 --- a/amqp/transport.py +++ b/amqp/transport.py @@ -436,10 +436,10 @@ class SSLTransport(_AbstractTransport): return ctx.wrap_socket(sock, **sslopts) def _wrap_socket_sni(self, sock, keyfile=None, certfile=None, - server_side=False, cert_reqs=ssl.CERT_NONE, + server_side=False, cert_reqs=None, ca_certs=None, do_handshake_on_connect=False, suppress_ragged_eofs=True, server_hostname=None, - ciphers=None, ssl_version=ssl.PROTOCOL_TLS): + ciphers=None, ssl_version=None): """Socket wrap with SNI headers. stdlib :attr:`ssl.SSLContext.wrap_socket` method augmented with support @@ -510,20 +510,31 @@ class SSLTransport(_AbstractTransport): 'server_hostname': server_hostname, } + if ssl_version is None: + ssl_version = ( + ssl.PROTOCOL_TLS_SERVER + if server_side + else ssl.PROTOCOL_TLS_CLIENT + ) + context = ssl.SSLContext(ssl_version) + if certfile is not None: context.load_cert_chain(certfile, keyfile) if ca_certs is not None: context.load_verify_locations(ca_certs) - if ciphers: + if ciphers is not None: context.set_ciphers(ciphers) - if cert_reqs != ssl.CERT_NONE: - context.check_hostname = True - # Set SNI headers if supported - if (server_hostname is not None) and ( - hasattr(ssl, 'HAS_SNI') and ssl.HAS_SNI) and ( - hasattr(ssl, 'SSLContext')): + if cert_reqs is not None: context.verify_mode = cert_reqs + # Set SNI headers if supported + try: + context.check_hostname = ( + ssl.HAS_SNI and server_hostname is not None + ) + except AttributeError: + pass # ask forgiveness not permission + sock = context.wrap_socket(**opts) return sock diff --git a/t/certs/ca_certificate.pem b/t/certs/ca_certificate.pem new file mode 100644 index 0000000..009936d --- /dev/null +++ b/t/certs/ca_certificate.pem @@ -0,0 +1,20 @@ +-----BEGIN CERTIFICATE----- +MIIDRzCCAi+gAwIBAgIJAMa1mrcNQtapMA0GCSqGSIb3DQEBCwUAMDExIDAeBgNV +BAMMF1RMU0dlblNlbGZTaWduZWR0Um9vdENBMQ0wCwYDVQQHDAQkJCQkMCAXDTIw +MDEwMzEyMDE0MFoYDzIxMTkxMjEwMTIwMTQwWjAxMSAwHgYDVQQDDBdUTFNHZW5T +ZWxmU2lnbmVkdFJvb3RDQTENMAsGA1UEBwwEJCQkJDCCASIwDQYJKoZIhvcNAQEB +BQADggEPADCCAQoCggEBAKdmOg5vtuZ5vNZmceToiVBlcFg9Y/xKNyCPBij6Wm5p +mXbnsjO1PhjGr97r2cMLq5QMvGt+FBEIjeeULtWVCBY7vMc4ATEZ1S2PmmKnOSXJ +MLMDIutznopZkyqt3gqWgXZDxxHIlIzJl0HirQmfeLm6eTOYyFoyFZV3CE2IeW4Y +n1zYhgZgIrU7Yo3I7wY9Js5yLk4p3etByN5tlLL2sdCOjRRXWGbOh/kb8uiyotEd +cxNThk0RQDugoEzaGYBU3bzDhKkm4v/v/xp/JxGLDl/e3heRMUbcw9d/0ujflouy +OQ66SNYGLWFQpmhtyHjalKzL5UbTcof4BQltoo/W7xECAwEAAaNgMF4wCwYDVR0P +BAQDAgEGMB0GA1UdDgQWBBTKOnbaptqaUCAiwtnwLcRTcbuRejAfBgNVHSMEGDAW +gBTKOnbaptqaUCAiwtnwLcRTcbuRejAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3 +DQEBCwUAA4IBAQB1tJUR9zoQ98bOz1es91PxgIt8VYR8/r6uIRtYWTBi7fgDRaaR +Glm6ZqOSXNlkacB6kjUzIyKJwGWnD9zU/06CH+ME1U497SVVhvtUEbdJb1COU+/5 +KavEHVINfc3tHD5Z5LJR3okEILAzBYkEcjYUECzBNYVi4l6PBSMSC2+RBKGqHkY7 +ApmD5batRghH5YtadiyF4h6bba/XSUqxzFcLKjKSyyds4ndvA1/yfl/7CrRtiZf0 +jw1pFl33/PTOhgi66MHa4uaKlL/hIjIlh4kJgJajqCN+TVU4Q6JNmSuIsq6rksSw +Rd5baBZrik2NHALr/ZN2Wy0nXiQJ3p+F20+X +-----END CERTIFICATE----- diff --git a/t/integration/test_rmq.py b/t/integration/test_rmq.py index f6b26d1..746aa65 100644 --- a/t/integration/test_rmq.py +++ b/t/integration/test_rmq.py @@ -8,12 +8,15 @@ import amqp def get_connection( - hostname, port, vhost, use_tls=False, keyfile=None, certfile=None): + hostname, port, vhost, use_tls=False, + keyfile=None, certfile=None, ca_certs=None +): host = f'{hostname}:{port}' if use_tls: return amqp.Connection(host=host, vhost=vhost, ssl={ 'keyfile': keyfile, - 'certfile': certfile + 'certfile': certfile, + 'ca_certs': ca_certs, } ) else: @@ -40,7 +43,8 @@ def connection(request): ).get("slaveid", None), use_tls=True, keyfile='t/certs/client_key.pem', - certfile='t/certs/client_certificate.pem' + certfile='t/certs/client_certificate.pem', + ca_certs='t/certs/ca_certificate.pem', ) diff --git a/t/unit/test_transport.py b/t/unit/test_transport.py index ad2750e..7f8d78c 100644 --- a/t/unit/test_transport.py +++ b/t/unit/test_transport.py @@ -640,108 +640,87 @@ class test_SSLTransport: def test_wrap_socket_sni(self): # testing default values of _wrap_socket_sni() sock = Mock() - with patch( - 'ssl.SSLContext.wrap_socket', - return_value=sentinel.WRAPPED_SOCKET) as mock_ssl_wrap: + with patch('ssl.SSLContext') as mock_ssl_context_class: + wrap_socket_method_mock = mock_ssl_context_class().wrap_socket + wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET ret = self.t._wrap_socket_sni(sock) - mock_ssl_wrap.assert_called_with(sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=None) - + mock_ssl_context_class.load_cert_chain.assert_not_called() + mock_ssl_context_class.load_verify_locations.assert_not_called() + mock_ssl_context_class.set_ciphers.assert_not_called() + mock_ssl_context_class.verify_mode.assert_not_called() + wrap_socket_method_mock.assert_called_with( + sock=sock, + server_side=False, + do_handshake_on_connect=False, + suppress_ragged_eofs=True, + server_hostname=None + ) assert ret == sentinel.WRAPPED_SOCKET def test_wrap_socket_sni_certfile(self): # testing _wrap_socket_sni() with parameters certfile and keyfile - sock = Mock() - with patch( - 'ssl.SSLContext.wrap_socket', - return_value=sentinel.WRAPPED_SOCKET - ) as mock_ssl_wrap, patch( - 'ssl.SSLContext.load_cert_chain' - ) as mock_load_cert_chain: - ret = self.t._wrap_socket_sni( - sock, keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE) - - mock_load_cert_chain.assert_called_with( - sentinel.CERTFILE, sentinel.KEYFILE) - mock_ssl_wrap.assert_called_with(sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=None) + with patch('ssl.SSLContext') as mock_ssl_context_class: + load_cert_chain_method_mock = \ + mock_ssl_context_class().load_cert_chain + self.t._wrap_socket_sni( + Mock(), keyfile=sentinel.KEYFILE, certfile=sentinel.CERTFILE + ) - assert ret == sentinel.WRAPPED_SOCKET + load_cert_chain_method_mock.assert_called_with( + sentinel.CERTFILE, sentinel.KEYFILE + ) def test_wrap_socket_ca_certs(self): # testing _wrap_socket_sni() with parameter ca_certs - sock = Mock() - with patch( - 'ssl.SSLContext.wrap_socket', - return_value=sentinel.WRAPPED_SOCKET - ) as mock_ssl_wrap, patch( - 'ssl.SSLContext.load_verify_locations' - ) as mock_load_verify_locations: - ret = self.t._wrap_socket_sni(sock, ca_certs=sentinel.CA_CERTS) - - mock_load_verify_locations.assert_called_with(sentinel.CA_CERTS) - mock_ssl_wrap.assert_called_with(sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=None) + with patch('ssl.SSLContext') as mock_ssl_context_class: + load_verify_locations_method_mock = \ + mock_ssl_context_class().load_verify_locations + self.t._wrap_socket_sni(Mock(), ca_certs=sentinel.CA_CERTS) - assert ret == sentinel.WRAPPED_SOCKET + load_verify_locations_method_mock.assert_called_with(sentinel.CA_CERTS) def test_wrap_socket_ciphers(self): # testing _wrap_socket_sni() with parameter ciphers - sock = Mock() - with patch( - 'ssl.SSLContext.wrap_socket', - return_value=sentinel.WRAPPED_SOCKET) as mock_ssl_wrap, \ - patch('ssl.SSLContext.set_ciphers') as mock_set_ciphers: - ret = self.t._wrap_socket_sni(sock, ciphers=sentinel.CIPHERS) - - mock_set_ciphers.assert_called_with(sentinel.CIPHERS) - mock_ssl_wrap.assert_called_with(sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=None) - assert ret == sentinel.WRAPPED_SOCKET + with patch('ssl.SSLContext') as mock_ssl_context_class: + set_ciphers_method_mock = mock_ssl_context_class().set_ciphers + self.t._wrap_socket_sni(Mock(), ciphers=sentinel.CIPHERS) + + set_ciphers_method_mock.assert_called_with(sentinel.CIPHERS) def test_wrap_socket_sni_cert_reqs(self): # testing _wrap_socket_sni() with parameter cert_reqs - sock = Mock() with patch('ssl.SSLContext') as mock_ssl_context_class: - wrap_socket_method_mock = mock_ssl_context_class().wrap_socket - wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET - ret = self.t._wrap_socket_sni(sock, cert_reqs=sentinel.CERT_REQS) + self.t._wrap_socket_sni(Mock(), cert_reqs=sentinel.CERT_REQS) - wrap_socket_method_mock.assert_called_with( - sock=sock, - server_side=False, - do_handshake_on_connect=False, - suppress_ragged_eofs=True, - server_hostname=None - ) - assert mock_ssl_context_class().check_hostname is True - assert ret == sentinel.WRAPPED_SOCKET + assert mock_ssl_context_class().verify_mode == sentinel.CERT_REQS def test_wrap_socket_sni_setting_sni_header(self): - # testing _wrap_socket_sni() with setting SNI header + # testing _wrap_socket_sni() without parameter server_hostname + # SSL module supports SNI + with patch('ssl.SSLContext') as mock_ssl_context_class, \ + patch('ssl.HAS_SNI', new=True): + self.t._wrap_socket_sni(Mock()) + + assert mock_ssl_context_class().check_hostname is False + + # SSL module does not support SNI + with patch('ssl.SSLContext') as mock_ssl_context_class, \ + patch('ssl.HAS_SNI', new=False): + self.t._wrap_socket_sni(Mock()) + + assert mock_ssl_context_class().check_hostname is False + + # testing _wrap_socket_sni() with parameter server_hostname sock = Mock() with patch('ssl.SSLContext') as mock_ssl_context_class, \ patch('ssl.HAS_SNI', new=True): # SSL module supports SNI wrap_socket_method_mock = mock_ssl_context_class().wrap_socket - wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET - ret = self.t._wrap_socket_sni( - sock, cert_reqs=sentinel.CERT_REQS, - server_hostname=sentinel.SERVER_HOSTNAME + self.t._wrap_socket_sni( + sock, server_hostname=sentinel.SERVER_HOSTNAME ) + wrap_socket_method_mock.assert_called_with( sock=sock, server_side=False, @@ -749,17 +728,14 @@ class test_SSLTransport: suppress_ragged_eofs=True, server_hostname=sentinel.SERVER_HOSTNAME ) - assert mock_ssl_context_class().verify_mode == sentinel.CERT_REQS - assert ret == sentinel.WRAPPED_SOCKET + assert mock_ssl_context_class().check_hostname is True with patch('ssl.SSLContext') as mock_ssl_context_class, \ patch('ssl.HAS_SNI', new=False): # SSL module does not support SNI wrap_socket_method_mock = mock_ssl_context_class().wrap_socket - wrap_socket_method_mock.return_value = sentinel.WRAPPED_SOCKET - ret = self.t._wrap_socket_sni( - sock, cert_reqs=sentinel.CERT_REQS, - server_hostname=sentinel.SERVER_HOSTNAME + self.t._wrap_socket_sni( + sock, server_hostname=sentinel.SERVER_HOSTNAME ) wrap_socket_method_mock.assert_called_with( sock=sock, @@ -768,8 +744,7 @@ class test_SSLTransport: suppress_ragged_eofs=True, server_hostname=sentinel.SERVER_HOSTNAME ) - assert mock_ssl_context_class().verify_mode != sentinel.CERT_REQS - assert ret == sentinel.WRAPPED_SOCKET + assert mock_ssl_context_class().check_hostname is False def test_shutdown_transport(self): self.t.sock = None |