diff options
author | Paul Kehrer <paul.l.kehrer@gmail.com> | 2016-03-18 18:07:25 -0400 |
---|---|---|
committer | Paul Kehrer <paul.l.kehrer@gmail.com> | 2016-03-18 18:07:25 -0400 |
commit | 9dff5c4ce371a585f922f1ed9398cec06a804a9e (patch) | |
tree | 674c4177681fab26201885901c501f974a772d93 /tests | |
parent | 901636c512e1553b0a5070a8f0ab6940ae8f1d71 (diff) | |
parent | 5be951cb834b4ff00e6bffcd6e8268535f259b41 (diff) | |
download | pyopenssl-git-9dff5c4ce371a585f922f1ed9398cec06a804a9e.tar.gz |
Merge pull request #422 from hynek/set_session_id
Implement missing methods
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_ssl.py | 154 |
1 files changed, 133 insertions, 21 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py index 8040c97..ab316fc 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -5,6 +5,9 @@ Unit tests for :mod:`OpenSSL.SSL`. """ +import datetime +import uuid + from gc import collect, get_referrers from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN from sys import platform, getfilesystemencoding, version_info @@ -19,6 +22,14 @@ import pytest from six import PY3, text_type +from cryptography import x509 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import serialization +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography.x509.oid import NameOID + + from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM from OpenSSL.crypto import PKey, X509, X509Extension, X509Store from OpenSSL.crypto import dump_privatekey, load_privatekey @@ -348,6 +359,49 @@ class VersionTests(TestCase): @pytest.fixture +def ca_file(tmpdir): + """ + Create a valid PEM file with CA certificates and return the path. + """ + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + backend=default_backend() + ) + public_key = key.public_key() + + builder = x509.CertificateBuilder() + builder = builder.subject_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), + ])) + builder = builder.issuer_name(x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"pyopenssl.org"), + ])) + one_day = datetime.timedelta(1, 0, 0) + builder = builder.not_valid_before(datetime.datetime.today() - one_day) + builder = builder.not_valid_after(datetime.datetime.today() + one_day) + builder = builder.serial_number(int(uuid.uuid4())) + builder = builder.public_key(public_key) + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True, + ) + + certificate = builder.sign( + private_key=key, algorithm=hashes.SHA256(), + backend=default_backend() + ) + + ca_file = tmpdir.join("test.pem") + ca_file.write_binary( + certificate.public_bytes( + encoding=serialization.Encoding.PEM, + ) + ) + + return str(ca_file).encode("ascii") + + +@pytest.fixture def context(): """ A simple TLS 1.0 context. @@ -389,6 +443,59 @@ class TestContext(object): with pytest.raises(error): context.set_cipher_list(cipher_list) + def test_load_client_ca(self, context, ca_file): + """ + :meth:`Context.load_client_ca` works as far as we can tell. + """ + context.load_client_ca(ca_file) + + def test_load_client_ca_invalid(self, context, tmpdir): + """ + :meth:`Context.load_client_ca` raises an Error if the ca file is + invalid. + """ + ca_file = tmpdir.join("test.pem") + ca_file.write("") + + with pytest.raises(Error) as e: + context.load_client_ca(str(ca_file).encode("ascii")) + + assert "PEM routines" == e.value.args[0][0][0] + + def test_load_client_ca_unicode(self, context, ca_file): + """ + Passing the path as unicode raises a warning but works. + """ + pytest.deprecated_call( + context.load_client_ca, ca_file.decode("ascii") + ) + + def test_set_session_id(self, context): + """ + :meth:`Context.set_session_id` works as far as we can tell. + """ + context.set_session_id(b"abc") + + def test_set_session_id_fail(self, context): + """ + :meth:`Context.set_session_id` errors are propagated. + """ + with pytest.raises(Error) as e: + context.set_session_id(b"abc" * 1000) + + assert [ + ("SSL routines", + "SSL_CTX_set_session_id_context", + "ssl session id context too long") + ] == e.value.args[0] + + def test_set_session_id_unicode(self, context): + """ + :meth:`Context.set_session_id` raises a warning if a unicode string is + passed. + """ + pytest.deprecated_call(context.set_session_id, u"abc") + class ContextTests(TestCase, _LoopbackMixin): """ @@ -1210,9 +1317,10 @@ class ContextTests(TestCase, _LoopbackMixin): raise Exception("silly verify failure") clientContext.set_verify(VERIFY_PEER, verify_callback) - exc = self.assertRaises( - Exception, self._handshake_test, serverContext, clientContext) - self.assertEqual("silly verify failure", str(exc)) + with pytest.raises(Exception) as exc: + self._handshake_test(serverContext, clientContext) + + self.assertEqual("silly verify failure", str(exc.value)) def test_add_extra_chain_cert(self): """ @@ -1338,9 +1446,6 @@ class ContextTests(TestCase, _LoopbackMixin): Error, context.use_certificate_chain_file, self.mktemp() ) - # XXX load_client_ca - # XXX set_session_id - def test_get_verify_mode_wrong_args(self): """ :py:obj:`Context.get_verify_mode` raises :py:obj:`TypeError` if called @@ -2033,7 +2138,6 @@ class ConnectionTests(TestCase, _LoopbackMixin): # XXX connect_ex -> TypeError # XXX set_connect_state -> TypeError # XXX set_accept_state -> TypeError - # XXX renegotiate_pending # XXX do_handshake -> TypeError # XXX bio_read -> TypeError # XXX recv -> TypeError @@ -3136,24 +3240,32 @@ class ConnectionRenegotiateTests(TestCase, _LoopbackMixin): connection = Connection(Context(TLSv1_METHOD), None) self.assertEquals(connection.total_renegotiations(), 0) -# def test_renegotiate(self): -# """ -# """ -# server, client = self._loopback() + def test_renegotiate(self): + """ + Go through a complete renegotiation cycle. + """ + server, client = self._loopback() + + server.send(b"hello world") + + assert b"hello world" == client.recv(len(b"hello world")) -# server.send("hello world") -# self.assertEquals(client.recv(len("hello world")), "hello world") + assert 0 == server.total_renegotiations() + assert False is server.renegotiate_pending() -# self.assertEquals(server.total_renegotiations(), 0) -# self.assertTrue(server.renegotiate()) + assert True is server.renegotiate() -# server.setblocking(False) -# client.setblocking(False) -# while server.renegotiate_pending(): -# client.do_handshake() -# server.do_handshake() + assert True is server.renegotiate_pending() -# self.assertEquals(server.total_renegotiations(), 1) + server.setblocking(False) + client.setblocking(False) + + client.do_handshake() + server.do_handshake() + + assert 1 == server.total_renegotiations() + while False is server.renegotiate_pending(): + pass class ErrorTests(TestCase): |