summaryrefslogtreecommitdiff
path: root/Lib/test
diff options
context:
space:
mode:
authorChristian Heimes <christian@python.org>2019-05-31 11:44:05 +0200
committerGitHub <noreply@github.com>2019-05-31 11:44:05 +0200
commitc7f7069e77c58e83b847c0bfe4d5aadf6add2e68 (patch)
tree306bee26619ebc132be4b98fd60d0daf79964cf0 /Lib/test
parente9b51c0ad81da1da11ae65840ac8b50a8521373c (diff)
downloadcpython-git-c7f7069e77c58e83b847c0bfe4d5aadf6add2e68.tar.gz
bpo-34271: Add ssl debugging helpers (GH-10031)
The ssl module now can dump key material to a keylog file and trace TLS protocol messages with a tracing callback. The default and stdlib contexts also support SSLKEYLOGFILE env var. The msg_callback and related enums are private members. The feature is designed for internal debugging and not for end users. Signed-off-by: Christian Heimes <christian@python.org>
Diffstat (limited to 'Lib/test')
-rw-r--r--Lib/test/test_ssl.py168
1 files changed, 167 insertions, 1 deletions
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index d48d6e5569..f368906c8a 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -2,6 +2,7 @@
import sys
import unittest
+import unittest.mock
from test import support
import socket
import select
@@ -25,6 +26,7 @@ except ImportError:
ssl = support.import_module("ssl")
+from ssl import TLSVersion, _TLSContentType, _TLSMessageType, _TLSAlertType
PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
HOST = support.HOST
@@ -4405,6 +4407,170 @@ class TestPostHandshakeAuth(unittest.TestCase):
self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
+HAS_KEYLOG = hasattr(ssl.SSLContext, 'keylog_filename')
+requires_keylog = unittest.skipUnless(
+ HAS_KEYLOG, 'test requires OpenSSL 1.1.1 with keylog callback')
+
+class TestSSLDebug(unittest.TestCase):
+
+ def keylog_lines(self, fname=support.TESTFN):
+ with open(fname) as f:
+ return len(list(f))
+
+ @requires_keylog
+ def test_keylog_defaults(self):
+ self.addCleanup(support.unlink, support.TESTFN)
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ self.assertEqual(ctx.keylog_filename, None)
+
+ self.assertFalse(os.path.isfile(support.TESTFN))
+ ctx.keylog_filename = support.TESTFN
+ self.assertEqual(ctx.keylog_filename, support.TESTFN)
+ self.assertTrue(os.path.isfile(support.TESTFN))
+ self.assertEqual(self.keylog_lines(), 1)
+
+ ctx.keylog_filename = None
+ self.assertEqual(ctx.keylog_filename, None)
+
+ with self.assertRaises((IsADirectoryError, PermissionError)):
+ # Windows raises PermissionError
+ ctx.keylog_filename = os.path.dirname(
+ os.path.abspath(support.TESTFN))
+
+ with self.assertRaises(TypeError):
+ ctx.keylog_filename = 1
+
+ @requires_keylog
+ def test_keylog_filename(self):
+ self.addCleanup(support.unlink, support.TESTFN)
+ client_context, server_context, hostname = testing_context()
+
+ client_context.keylog_filename = support.TESTFN
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ # header, 5 lines for TLS 1.3
+ self.assertEqual(self.keylog_lines(), 6)
+
+ client_context.keylog_filename = None
+ server_context.keylog_filename = support.TESTFN
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ self.assertGreaterEqual(self.keylog_lines(), 11)
+
+ client_context.keylog_filename = support.TESTFN
+ server_context.keylog_filename = support.TESTFN
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+ self.assertGreaterEqual(self.keylog_lines(), 21)
+
+ client_context.keylog_filename = None
+ server_context.keylog_filename = None
+
+ @requires_keylog
+ @unittest.skipIf(sys.flags.ignore_environment,
+ "test is not compatible with ignore_environment")
+ def test_keylog_env(self):
+ self.addCleanup(support.unlink, support.TESTFN)
+ with unittest.mock.patch.dict(os.environ):
+ os.environ['SSLKEYLOGFILE'] = support.TESTFN
+ self.assertEqual(os.environ['SSLKEYLOGFILE'], support.TESTFN)
+
+ ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
+ self.assertEqual(ctx.keylog_filename, None)
+
+ ctx = ssl.create_default_context()
+ self.assertEqual(ctx.keylog_filename, support.TESTFN)
+
+ ctx = ssl._create_stdlib_context()
+ self.assertEqual(ctx.keylog_filename, support.TESTFN)
+
+ def test_msg_callback(self):
+ client_context, server_context, hostname = testing_context()
+
+ def msg_cb(conn, direction, version, content_type, msg_type, data):
+ pass
+
+ self.assertIs(client_context._msg_callback, None)
+ client_context._msg_callback = msg_cb
+ self.assertIs(client_context._msg_callback, msg_cb)
+ with self.assertRaises(TypeError):
+ client_context._msg_callback = object()
+
+ def test_msg_callback_tls12(self):
+ client_context, server_context, hostname = testing_context()
+ client_context.options |= ssl.OP_NO_TLSv1_3
+
+ msg = []
+
+ def msg_cb(conn, direction, version, content_type, msg_type, data):
+ self.assertIsInstance(conn, ssl.SSLSocket)
+ self.assertIsInstance(data, bytes)
+ self.assertIn(direction, {'read', 'write'})
+ msg.append((direction, version, content_type, msg_type))
+
+ client_context._msg_callback = msg_cb
+
+ server = ThreadedEchoServer(context=server_context, chatty=False)
+ with server:
+ with client_context.wrap_socket(socket.socket(),
+ server_hostname=hostname) as s:
+ s.connect((HOST, server.port))
+
+ self.assertEqual(msg, [
+ ("write", TLSVersion.TLSv1, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.CLIENT_HELLO),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.SERVER_HELLO),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.CERTIFICATE),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.SERVER_KEY_EXCHANGE),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.SERVER_DONE),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.CLIENT_KEY_EXCHANGE),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.FINISHED),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.CHANGE_CIPHER_SPEC,
+ _TLSMessageType.CHANGE_CIPHER_SPEC),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("write", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.FINISHED),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.NEWSESSION_TICKET),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.FINISHED),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HEADER,
+ _TLSMessageType.CERTIFICATE_STATUS),
+ ("read", TLSVersion.TLSv1_2, _TLSContentType.HANDSHAKE,
+ _TLSMessageType.FINISHED),
+ ])
+
+
def test_main(verbose=False):
if support.verbose:
import warnings
@@ -4440,7 +4606,7 @@ def test_main(verbose=False):
tests = [
ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
- TestPostHandshakeAuth
+ TestPostHandshakeAuth, TestSSLDebug
]
if support.is_resource_enabled('network'):