diff options
author | Alex Chan <alex@alexwlchan.net> | 2017-01-30 07:13:30 +0000 |
---|---|---|
committer | Hynek Schlawack <hs@ox.cx> | 2017-01-30 08:13:30 +0100 |
commit | 1c0cb66f81d747b6349f0e132e369f52a0024efe (patch) | |
tree | 362766c8947dd4c210ba387002ecc4347d251288 /tests/test_ssl.py | |
parent | 7f3914b478e8b4fcd6ed0e68a272649bbb1c627d (diff) | |
download | pyopenssl-git-1c0cb66f81d747b6349f0e132e369f52a0024efe.tar.gz |
Convert the rest of TestConnection to be pytest-style (#594)
Diffstat (limited to 'tests/test_ssl.py')
-rw-r--r-- | tests/test_ssl.py | 663 |
1 files changed, 286 insertions, 377 deletions
diff --git a/tests/test_ssl.py b/tests/test_ssl.py index e0a720b..14b2310 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -229,33 +229,13 @@ class _LoopbackMixin(object): BIOs. """ def _loopbackClientFactory(self, socket): - client = Connection(Context(TLSv1_METHOD), socket) - client.set_connect_state() - return client + return loopback_client_factory(socket) def _loopbackServerFactory(self, socket): - ctx = Context(TLSv1_METHOD) - ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) - ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) - server = Connection(ctx, socket) - server.set_accept_state() - return server + return loopback_server_factory(socket) def _loopback(self, serverFactory=None, clientFactory=None): - if serverFactory is None: - serverFactory = self._loopbackServerFactory - if clientFactory is None: - clientFactory = self._loopbackClientFactory - - (server, client) = socket_pair() - server = serverFactory(server) - client = clientFactory(client) - - handshake(client, server) - - server.setblocking(True) - client.setblocking(True) - return server, client + return loopback(serverFactory, clientFactory) def _interactInMemory(self, client_conn, server_conn): return interact_in_memory(client_conn, server_conn) @@ -264,6 +244,42 @@ class _LoopbackMixin(object): return handshake_in_memory(client_conn, server_conn) +def loopback_client_factory(socket): + client = Connection(Context(TLSv1_METHOD), socket) + client.set_connect_state() + return client + + +def loopback_server_factory(socket): + ctx = Context(TLSv1_METHOD) + ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + server = Connection(ctx, socket) + server.set_accept_state() + return server + + +def loopback(server_factory=None, client_factory=None): + """ + Create a connected socket pair and force two connected SSL sockets + to talk to each other via memory BIOs. + """ + if server_factory is None: + server_factory = loopback_server_factory + if client_factory is None: + client_factory = loopback_client_factory + + (server, client) = socket_pair() + server = server_factory(server) + client = client_factory(client) + + handshake(client, server) + + server.setblocking(True) + client.setblocking(True) + return server, client + + def interact_in_memory(client_conn, server_conn): """ Try to read application bytes from each of the two `Connection` objects. @@ -1956,9 +1972,9 @@ class TestSession(object): assert isinstance(new_session, Session) -class ConnectionTests(TestCase, _LoopbackMixin): +class TestConnection(object): """ - Unit tests for :class:`OpenSSL.SSL.Connection`. + Unit tests for `OpenSSL.SSL.Connection`. """ # XXX get_peer_certificate -> None # XXX sock_shutdown @@ -1976,57 +1992,47 @@ class ConnectionTests(TestCase, _LoopbackMixin): def test_type(self): """ - :py:obj:`Connection` and :py:obj:`ConnectionType` refer to the same - type object and can be used to create instances of that type. + `Connection` and `ConnectionType` refer to the same type object and + can be used to create instances of that type. """ - self.assertIdentical(Connection, ConnectionType) + assert Connection is ConnectionType ctx = Context(TLSv1_METHOD) - self.assertConsistentType(Connection, 'Connection', ctx, None) + assert is_consistent_type(Connection, 'Connection', ctx, None) def test_get_context(self): """ - :py:obj:`Connection.get_context` returns the :py:obj:`Context` instance - used to construct the :py:obj:`Connection` instance. + `Connection.get_context` returns the `Context` instance used to + construct the `Connection` instance. """ context = Context(TLSv1_METHOD) connection = Connection(context, None) - self.assertIdentical(connection.get_context(), context) - - def test_get_context_wrong_args(self): - """ - :py:obj:`Connection.get_context` raises :py:obj:`TypeError` if called - with any arguments. - """ - connection = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, connection.get_context, None) + assert connection.get_context() is context def test_set_context_wrong_args(self): """ - :py:obj:`Connection.set_context` raises :py:obj:`TypeError` if called - with a non-:py:obj:`Context` instance argument or with any number of - arguments other than 1. + `Connection.set_context` raises `TypeError` if called with a + non-`Context` instance argument, """ ctx = Context(TLSv1_METHOD) connection = Connection(ctx, None) - self.assertRaises(TypeError, connection.set_context) - self.assertRaises(TypeError, connection.set_context, object()) - self.assertRaises(TypeError, connection.set_context, "hello") - self.assertRaises(TypeError, connection.set_context, 1) - self.assertRaises(TypeError, connection.set_context, 1, 2) - self.assertRaises( - TypeError, connection.set_context, Context(TLSv1_METHOD), 2) - self.assertIdentical(ctx, connection.get_context()) + with pytest.raises(TypeError): + connection.set_context(object()) + with pytest.raises(TypeError): + connection.set_context("hello") + with pytest.raises(TypeError): + connection.set_context(1) + assert ctx is connection.get_context() def test_set_context(self): """ - :py:obj:`Connection.set_context` specifies a new :py:obj:`Context` - instance to be used for the connection. + `Connection.set_context` specifies a new `Context` instance to be + used for the connection. """ original = Context(SSLv23_METHOD) replacement = Context(TLSv1_METHOD) connection = Connection(original, None) connection.set_context(replacement) - self.assertIdentical(replacement, connection.get_context()) + assert replacement is connection.get_context() # Lose our references to the contexts, just in case the Connection # isn't properly managing its own contributions to their reference # counts. @@ -2035,88 +2041,52 @@ class ConnectionTests(TestCase, _LoopbackMixin): def test_set_tlsext_host_name_wrong_args(self): """ - If :py:obj:`Connection.set_tlsext_host_name` is called with a non-byte - string argument or a byte string with an embedded NUL or other than one - argument, :py:obj:`TypeError` is raised. + If `Connection.set_tlsext_host_name` is called with a non-byte string + argument or a byte string with an embedded NUL, `TypeError` is raised. """ conn = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, conn.set_tlsext_host_name) - self.assertRaises(TypeError, conn.set_tlsext_host_name, object()) - self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456) - self.assertRaises( - TypeError, conn.set_tlsext_host_name, b"with\0null") + with pytest.raises(TypeError): + conn.set_tlsext_host_name(object()) + with pytest.raises(TypeError): + conn.set_tlsext_host_name(b"with\0null") if PY3: # On Python 3.x, don't accidentally implicitly convert from text. - self.assertRaises( - TypeError, - conn.set_tlsext_host_name, b"example.com".decode("ascii")) - - def test_get_servername_wrong_args(self): - """ - :py:obj:`Connection.get_servername` raises :py:obj:`TypeError` if - called with any arguments. - """ - connection = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, connection.get_servername, object()) - self.assertRaises(TypeError, connection.get_servername, 1) - self.assertRaises(TypeError, connection.get_servername, "hello") + with pytest.raises(TypeError): + conn.set_tlsext_host_name(b"example.com".decode("ascii")) def test_pending(self): """ - :py:obj:`Connection.pending` returns the number of bytes available for + `Connection.pending` returns the number of bytes available for immediate read. """ connection = Connection(Context(TLSv1_METHOD), None) - self.assertEquals(connection.pending(), 0) - - def test_pending_wrong_args(self): - """ - :py:obj:`Connection.pending` raises :py:obj:`TypeError` if called with - any arguments. - """ - connection = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, connection.pending, None) + assert connection.pending() == 0 def test_peek(self): """ - :py:obj:`Connection.recv` peeks into the connection if - :py:obj:`socket.MSG_PEEK` is passed. + `Connection.recv` peeks into the connection if `socket.MSG_PEEK` + is passed. """ - server, client = self._loopback() + server, client = loopback() server.send(b'xy') - self.assertEqual(client.recv(2, MSG_PEEK), b'xy') - self.assertEqual(client.recv(2, MSG_PEEK), b'xy') - self.assertEqual(client.recv(2), b'xy') + assert client.recv(2, MSG_PEEK) == b'xy' + assert client.recv(2, MSG_PEEK) == b'xy' + assert client.recv(2) == b'xy' def test_connect_wrong_args(self): """ - :py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with - a non-address argument or with the wrong number of arguments. + `Connection.connect` raises `TypeError` if called with + a non-address argument. """ connection = Connection(Context(TLSv1_METHOD), socket()) - self.assertRaises(TypeError, connection.connect, None) - self.assertRaises(TypeError, connection.connect) - self.assertRaises( - TypeError, connection.connect, ("127.0.0.1", 1), None - ) - - def test_connection_undefined_attr(self): - """ - :py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with - a non-address argument or with the wrong number of arguments. - """ - - def attr_access_test(connection): - return connection.an_attribute_which_is_not_defined - - connection = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(AttributeError, attr_access_test, connection) + with pytest.raises(TypeError): + connection.connect(None) def test_connect_refused(self): """ - :py:obj:`Connection.connect` raises :py:obj:`socket.error` if the - underlying socket connect method raises it. + `Connection.connect` raises `socket.error` if the underlying socket + connect method raises it. """ client = socket() context = Context(TLSv1_METHOD) @@ -2131,8 +2101,7 @@ class ConnectionTests(TestCase, _LoopbackMixin): def test_connect(self): """ - :py:obj:`Connection.connect` establishes a connection to the specified - address. + `Connection.connect` establishes a connection to the specified address. """ port = socket() port.bind(('', 0)) @@ -2148,8 +2117,8 @@ class ConnectionTests(TestCase, _LoopbackMixin): ) def test_connect_ex(self): """ - If there is a connection error, :py:obj:`Connection.connect_ex` - returns the errno instead of raising an exception. + If there is a connection error, `Connection.connect_ex` returns the + errno instead of raising an exception. """ port = socket() port.bind(('', 0)) @@ -2159,22 +2128,13 @@ class ConnectionTests(TestCase, _LoopbackMixin): clientSSL.setblocking(False) result = clientSSL.connect_ex(port.getsockname()) expected = (EINPROGRESS, EWOULDBLOCK) - self.assertTrue( - result in expected, "%r not in %r" % (result, expected)) - - def test_accept_wrong_args(self): - """ - :py:obj:`Connection.accept` raises :py:obj:`TypeError` if called with - any arguments. - """ - connection = Connection(Context(TLSv1_METHOD), socket()) - self.assertRaises(TypeError, connection.accept, None) + assert result in expected def test_accept(self): """ - :py:obj:`Connection.accept` accepts a pending connection attempt and - returns a tuple of a new :py:obj:`Connection` (the accepted client) and - the address the connection originated from. + `Connection.accept` accepts a pending connection attempt and returns a + tuple of a new `Connection` (the accepted client) and the address the + connection originated from. """ ctx = Context(TLSv1_METHOD) ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) @@ -2192,58 +2152,53 @@ class ConnectionTests(TestCase, _LoopbackMixin): serverSSL, address = portSSL.accept() - self.assertTrue(isinstance(serverSSL, Connection)) - self.assertIdentical(serverSSL.get_context(), ctx) - self.assertEquals(address, clientSSL.getsockname()) + assert isinstance(serverSSL, Connection) + assert serverSSL.get_context() is ctx + assert address == clientSSL.getsockname() def test_shutdown_wrong_args(self): """ - :py:obj:`Connection.shutdown` raises :py:obj:`TypeError` if called with - the wrong number of arguments or with arguments other than integers. + `Connection.set_shutdown` raises `TypeError` if called with arguments + other than integers. """ connection = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, connection.shutdown, None) - self.assertRaises(TypeError, connection.get_shutdown, None) - self.assertRaises(TypeError, connection.set_shutdown) - self.assertRaises(TypeError, connection.set_shutdown, None) - self.assertRaises(TypeError, connection.set_shutdown, 0, 1) + with pytest.raises(TypeError): + connection.set_shutdown(None) def test_shutdown(self): """ - :py:obj:`Connection.shutdown` performs an SSL-level connection - shutdown. + `Connection.shutdown` performs an SSL-level connection shutdown. """ - server, client = self._loopback() - self.assertFalse(server.shutdown()) - self.assertEquals(server.get_shutdown(), SENT_SHUTDOWN) - self.assertRaises(ZeroReturnError, client.recv, 1024) - self.assertEquals(client.get_shutdown(), RECEIVED_SHUTDOWN) + server, client = loopback() + assert not server.shutdown() + assert server.get_shutdown() == SENT_SHUTDOWN + with pytest.raises(ZeroReturnError): + client.recv(1024) + assert client.get_shutdown() == RECEIVED_SHUTDOWN client.shutdown() - self.assertEquals( - client.get_shutdown(), SENT_SHUTDOWN | RECEIVED_SHUTDOWN - ) - self.assertRaises(ZeroReturnError, server.recv, 1024) - self.assertEquals( - server.get_shutdown(), SENT_SHUTDOWN | RECEIVED_SHUTDOWN - ) + assert client.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN) + with pytest.raises(ZeroReturnError): + server.recv(1024) + assert server.get_shutdown() == (SENT_SHUTDOWN | RECEIVED_SHUTDOWN) def test_shutdown_closed(self): """ - If the underlying socket is closed, :py:obj:`Connection.shutdown` - propagates the write error from the low level write call. + If the underlying socket is closed, `Connection.shutdown` propagates + the write error from the low level write call. """ - server, client = self._loopback() + server, client = loopback() server.sock_shutdown(2) - exc = self.assertRaises(SysCallError, server.shutdown) - if platform == "win32": - self.assertEqual(exc.args[0], ESHUTDOWN) - else: - self.assertEqual(exc.args[0], EPIPE) + with pytest.raises(SysCallError) as exc: + server.shutdown() + if platform == "win32": + assert exc.value.args[0] == ESHUTDOWN + else: + assert exc.value.args[0] == EPIPE def test_shutdown_truncated(self): """ - If the underlying connection is truncated, :obj:`Connection.shutdown` - raises an :obj:`Error`. + If the underlying connection is truncated, `Connection.shutdown` + raises an `Error`. """ server_ctx = Context(TLSv1_METHOD) client_ctx = Context(TLSv1_METHOD) @@ -2253,39 +2208,41 @@ class ConnectionTests(TestCase, _LoopbackMixin): load_certificate(FILETYPE_PEM, server_cert_pem)) server = Connection(server_ctx, None) client = Connection(client_ctx, None) - self._handshakeInMemory(client, server) - self.assertEqual(server.shutdown(), False) - self.assertRaises(WantReadError, server.shutdown) + handshake_in_memory(client, server) + assert not server.shutdown() + with pytest.raises(WantReadError): + server.shutdown() server.bio_shutdown() - self.assertRaises(Error, server.shutdown) + with pytest.raises(Error): + server.shutdown() def test_set_shutdown(self): """ - :py:obj:`Connection.set_shutdown` sets the state of the SSL connection + `Connection.set_shutdown` sets the state of the SSL connection shutdown process. """ connection = Connection(Context(TLSv1_METHOD), socket()) connection.set_shutdown(RECEIVED_SHUTDOWN) - self.assertEquals(connection.get_shutdown(), RECEIVED_SHUTDOWN) + assert connection.get_shutdown() == RECEIVED_SHUTDOWN @skip_if_py3 def test_set_shutdown_long(self): """ - On Python 2 :py:obj:`Connection.set_shutdown` accepts an argument - of type :py:obj:`long` as well as :py:obj:`int`. + On Python 2 `Connection.set_shutdown` accepts an argument + of type `long` as well as `int`. """ connection = Connection(Context(TLSv1_METHOD), socket()) connection.set_shutdown(long(RECEIVED_SHUTDOWN)) - self.assertEquals(connection.get_shutdown(), RECEIVED_SHUTDOWN) + assert connection.get_shutdown() == RECEIVED_SHUTDOWN def test_state_string(self): """ - :meth:`Connection.state_string` verbosely describes the current - state of the :class:`Connection`. + `Connection.state_string` verbosely describes the current state of + the `Connection`. """ server, client = socket_pair() - server = self._loopbackServerFactory(server) - client = self._loopbackClientFactory(client) + server = loopback_server_factory(server) + client = loopback_client_factory(client) assert server.get_state_string() in [ b"before/accept initialization", b"before SSL initialization" @@ -2294,22 +2251,11 @@ class ConnectionTests(TestCase, _LoopbackMixin): b"before/connect initialization", b"before SSL initialization" ] - def test_app_data_wrong_args(self): - """ - :py:obj:`Connection.set_app_data` raises :py:obj:`TypeError` if called - with other than one argument. :py:obj:`Connection.get_app_data` raises - :py:obj:`TypeError` if called with any arguments. - """ - conn = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, conn.get_app_data, None) - self.assertRaises(TypeError, conn.set_app_data) - self.assertRaises(TypeError, conn.set_app_data, None, None) - def test_app_data(self): """ Any object can be set as app data by passing it to - :py:obj:`Connection.set_app_data` and later retrieved with - :py:obj:`Connection.get_app_data`. + `Connection.set_app_data` and later retrieved with + `Connection.get_app_data`. """ conn = Connection(Context(TLSv1_METHOD), None) assert None is conn.get_app_data() @@ -2319,26 +2265,16 @@ class ConnectionTests(TestCase, _LoopbackMixin): def test_makefile(self): """ - :py:obj:`Connection.makefile` is not implemented and calling that - method raises :py:obj:`NotImplementedError`. + `Connection.makefile` is not implemented and calling that + method raises `NotImplementedError`. """ conn = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(NotImplementedError, conn.makefile) - - def test_get_peer_cert_chain_wrong_args(self): - """ - :py:obj:`Connection.get_peer_cert_chain` raises :py:obj:`TypeError` if - called with any arguments. - """ - conn = Connection(Context(TLSv1_METHOD), None) - self.assertRaises(TypeError, conn.get_peer_cert_chain, 1) - self.assertRaises(TypeError, conn.get_peer_cert_chain, "foo") - self.assertRaises(TypeError, conn.get_peer_cert_chain, object()) - self.assertRaises(TypeError, conn.get_peer_cert_chain, []) + with pytest.raises(NotImplementedError): + conn.makefile() def test_get_peer_cert_chain(self): """ - :py:obj:`Connection.get_peer_cert_chain` returns a list of certificates + `Connection.get_peer_cert_chain` returns a list of certificates which the connected server returned for the certification verification. """ chain = _create_certificate_chain() @@ -2358,21 +2294,18 @@ class ConnectionTests(TestCase, _LoopbackMixin): client = Connection(clientContext, None) client.set_connect_state() - self._interactInMemory(client, server) + interact_in_memory(client, server) chain = client.get_peer_cert_chain() - self.assertEqual(len(chain), 3) - self.assertEqual( - "Server Certificate", chain[0].get_subject().CN) - self.assertEqual( - "Intermediate Certificate", chain[1].get_subject().CN) - self.assertEqual( - "Authority Certificate", chain[2].get_subject().CN) + assert len(chain) == 3 + assert "Server Certificate" == chain[0].get_subject().CN + assert "Intermediate Certificate" == chain[1].get_subject().CN + assert "Authority Certificate" == chain[2].get_subject().CN def test_get_peer_cert_chain_none(self): """ - :py:obj:`Connection.get_peer_cert_chain` returns :py:obj:`None` if the - peer sends no certificate chain. + `Connection.get_peer_cert_chain` returns `None` if the peer sends + no certificate chain. """ ctx = Context(TLSv1_METHOD) ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) @@ -2381,71 +2314,57 @@ class ConnectionTests(TestCase, _LoopbackMixin): server.set_accept_state() client = Connection(Context(TLSv1_METHOD), None) client.set_connect_state() - self._interactInMemory(client, server) - self.assertIdentical(None, server.get_peer_cert_chain()) - - def test_get_session_wrong_args(self): - """ - :py:obj:`Connection.get_session` raises :py:obj:`TypeError` if called - with any arguments. - """ - ctx = Context(TLSv1_METHOD) - server = Connection(ctx, None) - self.assertRaises(TypeError, server.get_session, 123) - self.assertRaises(TypeError, server.get_session, "hello") - self.assertRaises(TypeError, server.get_session, object()) + interact_in_memory(client, server) + assert None is server.get_peer_cert_chain() def test_get_session_unconnected(self): """ - :py:obj:`Connection.get_session` returns :py:obj:`None` when used with - an object which has not been connected. + `Connection.get_session` returns `None` when used with an object + which has not been connected. """ ctx = Context(TLSv1_METHOD) server = Connection(ctx, None) session = server.get_session() - self.assertIdentical(None, session) + assert None is session def test_server_get_session(self): """ - On the server side of a connection, :py:obj:`Connection.get_session` - returns a :py:class:`Session` instance representing the SSL session for - that connection. + On the server side of a connection, `Connection.get_session` returns a + `Session` instance representing the SSL session for that connection. """ - server, client = self._loopback() + server, client = loopback() session = server.get_session() - self.assertIsInstance(session, Session) + assert isinstance(session, Session) def test_client_get_session(self): """ - On the client side of a connection, :py:obj:`Connection.get_session` - returns a :py:class:`Session` instance representing the SSL session for + On the client side of a connection, `Connection.get_session` + returns a `Session` instance representing the SSL session for that connection. """ - server, client = self._loopback() + server, client = loopback() session = client.get_session() - self.assertIsInstance(session, Session) + assert isinstance(session, Session) def test_set_session_wrong_args(self): """ - If called with an object that is not an instance of - :py:class:`Session`, or with other than one argument, - :py:obj:`Connection.set_session` raises :py:obj:`TypeError`. + `Connection.set_session` raises `TypeError` if called with an object + that is not an instance of `Session`. """ ctx = Context(TLSv1_METHOD) connection = Connection(ctx, None) - self.assertRaises(TypeError, connection.set_session) - self.assertRaises(TypeError, connection.set_session, 123) - self.assertRaises(TypeError, connection.set_session, "hello") - self.assertRaises(TypeError, connection.set_session, object()) - self.assertRaises( - TypeError, connection.set_session, Session(), Session()) + with pytest.raises(TypeError): + connection.set_session(123) + with pytest.raises(TypeError): + connection.set_session("hello") + with pytest.raises(TypeError): + connection.set_session(object()) def test_client_set_session(self): """ - :py:obj:`Connection.set_session`, when used prior to a connection being - established, accepts a :py:class:`Session` instance and causes an - attempt to re-use the session it represents when the SSL handshake is - performed. + `Connection.set_session`, when used prior to a connection being + established, accepts a `Session` instance and causes an attempt to + re-use the session it represents when the SSL handshake is performed. """ key = load_privatekey(FILETYPE_PEM, server_key_pem) cert = load_certificate(FILETYPE_PEM, server_cert_pem) @@ -2459,17 +2378,17 @@ class ConnectionTests(TestCase, _LoopbackMixin): server.set_accept_state() return server - originalServer, originalClient = self._loopback( - serverFactory=makeServer) + originalServer, originalClient = loopback( + server_factory=makeServer) originalSession = originalClient.get_session() def makeClient(socket): - client = self._loopbackClientFactory(socket) + client = loopback_client_factory(socket) client.set_session(originalSession) return client - resumedServer, resumedClient = self._loopback( - serverFactory=makeServer, - clientFactory=makeClient) + resumedServer, resumedClient = loopback( + server_factory=makeServer, + client_factory=makeClient) # This is a proxy: in general, we have no access to any unique # identifier for the session (new enough versions of OpenSSL expose @@ -2477,15 +2396,13 @@ class ConnectionTests(TestCase, _LoopbackMixin): # Instead, exploit the fact that the master key is re-used if the # session is re-used. As long as the master key for the two # connections is the same, the session was re-used! - self.assertEqual( - originalServer.master_key(), resumedServer.master_key()) + assert originalServer.master_key() == resumedServer.master_key() def test_set_session_wrong_method(self): """ - If :py:obj:`Connection.set_session` is passed a :py:class:`Session` - instance associated with a context using a different SSL method than - the :py:obj:`Connection` is using, a :py:class:`OpenSSL.SSL.Error` is - raised. + If `Connection.set_session` is passed a `Session` instance associated + with a context using a different SSL method than the `Connection` + is using, a `OpenSSL.SSL.Error` is raised. """ # Make this work on both OpenSSL 1.0.0, which doesn't support TLSv1.2 # and also on OpenSSL 1.1.0 which doesn't support SSLv3. (SSL_ST_INIT @@ -2514,8 +2431,8 @@ class ConnectionTests(TestCase, _LoopbackMixin): client.set_connect_state() return client - originalServer, originalClient = self._loopback( - serverFactory=makeServer, clientFactory=makeOriginalClient) + originalServer, originalClient = loopback( + server_factory=makeServer, client_factory=makeOriginalClient) originalSession = originalClient.get_session() def makeClient(socket): @@ -2525,14 +2442,13 @@ class ConnectionTests(TestCase, _LoopbackMixin): client.set_session(originalSession) return client - self.assertRaises( - Error, - self._loopback, clientFactory=makeClient, serverFactory=makeServer) + with pytest.raises(Error): + loopback(client_factory=makeClient, server_factory=makeServer) def test_wantWriteError(self): """ - :py:obj:`Connection` methods which generate output raise - :py:obj:`OpenSSL.SSL.WantWriteError` if writing to the connection's BIO + `Connection` methods which generate output raise + `OpenSSL.SSL.WantWriteError` if writing to the connection's BIO fail indicating a should-write state. """ client_socket, server_socket = socket_pair() @@ -2551,57 +2467,57 @@ class ConnectionTests(TestCase, _LoopbackMixin): break raise else: - self.fail( + pytest.fail( "Failed to fill socket buffer, cannot test BIO want write") ctx = Context(TLSv1_METHOD) conn = Connection(ctx, client_socket) # Client's speak first, so make it an SSL client conn.set_connect_state() - self.assertRaises(WantWriteError, conn.do_handshake) + with pytest.raises(WantWriteError): + conn.do_handshake() # XXX want_read def test_get_finished_before_connect(self): """ - :py:obj:`Connection.get_finished` returns :py:obj:`None` before TLS - handshake is completed. + `Connection.get_finished` returns `None` before TLS handshake + is completed. """ ctx = Context(TLSv1_METHOD) connection = Connection(ctx, None) - self.assertEqual(connection.get_finished(), None) + assert connection.get_finished() is None def test_get_peer_finished_before_connect(self): """ - :py:obj:`Connection.get_peer_finished` returns :py:obj:`None` before - TLS handshake is completed. + `Connection.get_peer_finished` returns `None` before TLS handshake + is completed. """ ctx = Context(TLSv1_METHOD) connection = Connection(ctx, None) - self.assertEqual(connection.get_peer_finished(), None) + assert connection.get_peer_finished() is None def test_get_finished(self): """ - :py:obj:`Connection.get_finished` method returns the TLS Finished - message send from client, or server. Finished messages are send during + `Connection.get_finished` method returns the TLS Finished message send + from client, or server. Finished messages are send during TLS handshake. """ + server, client = loopback() - server, client = self._loopback() - - self.assertNotEqual(server.get_finished(), None) - self.assertTrue(len(server.get_finished()) > 0) + assert server.get_finished() is not None + assert len(server.get_finished()) > 0 def test_get_peer_finished(self): """ - :py:obj:`Connection.get_peer_finished` method returns the TLS Finished + `Connection.get_peer_finished` method returns the TLS Finished message received from client, or server. Finished messages are send during TLS handshake. """ - server, client = self._loopback() + server, client = loopback() - self.assertNotEqual(server.get_peer_finished(), None) - self.assertTrue(len(server.get_peer_finished()) > 0) + assert server.get_peer_finished() is not None + assert len(server.get_peer_finished()) > 0 def test_tls_finished_message_symmetry(self): """ @@ -2611,109 +2527,148 @@ class ConnectionTests(TestCase, _LoopbackMixin): The TLS Finished message send by client must be the TLS Finished message received by server. """ - server, client = self._loopback() + server, client = loopback() - self.assertEqual(server.get_finished(), client.get_peer_finished()) - self.assertEqual(client.get_finished(), server.get_peer_finished()) + assert server.get_finished() == client.get_peer_finished() + assert client.get_finished() == server.get_peer_finished() def test_get_cipher_name_before_connect(self): """ - :py:obj:`Connection.get_cipher_name` returns :py:obj:`None` if no - connection has been established. + `Connection.get_cipher_name` returns `None` if no connection + has been established. """ ctx = Context(TLSv1_METHOD) conn = Connection(ctx, None) - self.assertIdentical(conn.get_cipher_name(), None) + assert conn.get_cipher_name() is None def test_get_cipher_name(self): """ - :py:obj:`Connection.get_cipher_name` returns a :py:class:`unicode` - string giving the name of the currently used cipher. + `Connection.get_cipher_name` returns a `unicode` string giving the + name of the currently used cipher. """ - server, client = self._loopback() + server, client = loopback() server_cipher_name, client_cipher_name = \ server.get_cipher_name(), client.get_cipher_name() - self.assertIsInstance(server_cipher_name, text_type) - self.assertIsInstance(client_cipher_name, text_type) + assert isinstance(server_cipher_name, text_type) + assert isinstance(client_cipher_name, text_type) - self.assertEqual(server_cipher_name, client_cipher_name) + assert server_cipher_name == client_cipher_name def test_get_cipher_version_before_connect(self): """ - :py:obj:`Connection.get_cipher_version` returns :py:obj:`None` if no - connection has been established. + `Connection.get_cipher_version` returns `None` if no connection + has been established. """ ctx = Context(TLSv1_METHOD) conn = Connection(ctx, None) - self.assertIdentical(conn.get_cipher_version(), None) + assert conn.get_cipher_version() is None def test_get_cipher_version(self): """ - :py:obj:`Connection.get_cipher_version` returns a :py:class:`unicode` - string giving the protocol name of the currently used cipher. + `Connection.get_cipher_version` returns a `unicode` string giving + the protocol name of the currently used cipher. """ - server, client = self._loopback() + server, client = loopback() server_cipher_version, client_cipher_version = \ server.get_cipher_version(), client.get_cipher_version() - self.assertIsInstance(server_cipher_version, text_type) - self.assertIsInstance(client_cipher_version, text_type) + assert isinstance(server_cipher_version, text_type) + assert isinstance(client_cipher_version, text_type) - self.assertEqual(server_cipher_version, client_cipher_version) + assert server_cipher_version == client_cipher_version def test_get_cipher_bits_before_connect(self): """ - :py:obj:`Connection.get_cipher_bits` returns :py:obj:`None` if no - connection has been established. + `Connection.get_cipher_bits` returns `None` if no connection has + been established. """ ctx = Context(TLSv1_METHOD) conn = Connection(ctx, None) - self.assertIdentical(conn.get_cipher_bits(), None) + assert conn.get_cipher_bits() is None def test_get_cipher_bits(self): """ - :py:obj:`Connection.get_cipher_bits` returns the number of secret bits + `Connection.get_cipher_bits` returns the number of secret bits of the currently used cipher. """ - server, client = self._loopback() + server, client = loopback() server_cipher_bits, client_cipher_bits = \ server.get_cipher_bits(), client.get_cipher_bits() - self.assertIsInstance(server_cipher_bits, int) - self.assertIsInstance(client_cipher_bits, int) + assert isinstance(server_cipher_bits, int) + assert isinstance(client_cipher_bits, int) - self.assertEqual(server_cipher_bits, client_cipher_bits) + assert server_cipher_bits == client_cipher_bits def test_get_protocol_version_name(self): """ - :py:obj:`Connection.get_protocol_version_name()` returns a string - giving the protocol version of the current connection. + `Connection.get_protocol_version_name()` returns a string giving the + protocol version of the current connection. """ - server, client = self._loopback() + server, client = loopback() client_protocol_version_name = client.get_protocol_version_name() server_protocol_version_name = server.get_protocol_version_name() - self.assertIsInstance(server_protocol_version_name, text_type) - self.assertIsInstance(client_protocol_version_name, text_type) + assert isinstance(server_protocol_version_name, text_type) + assert isinstance(client_protocol_version_name, text_type) - self.assertEqual( - server_protocol_version_name, client_protocol_version_name - ) + assert server_protocol_version_name == client_protocol_version_name def test_get_protocol_version(self): """ - :py:obj:`Connection.get_protocol_version()` returns an integer + `Connection.get_protocol_version()` returns an integer giving the protocol version of the current connection. """ - server, client = self._loopback() + server, client = loopback() client_protocol_version = client.get_protocol_version() server_protocol_version = server.get_protocol_version() - self.assertIsInstance(server_protocol_version, int) - self.assertIsInstance(client_protocol_version, int) + assert isinstance(server_protocol_version, int) + assert isinstance(client_protocol_version, int) - self.assertEqual(server_protocol_version, client_protocol_version) + assert server_protocol_version == client_protocol_version + + def test_wantReadError(self): + """ + `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are + no bytes available to be read from the BIO. + """ + ctx = Context(TLSv1_METHOD) + conn = Connection(ctx, None) + with pytest.raises(WantReadError): + conn.bio_read(1024) + + def test_buffer_size(self): + """ + `Connection.bio_read` accepts an integer giving the maximum number + of bytes to read and return. + """ + ctx = Context(TLSv1_METHOD) + conn = Connection(ctx, None) + conn.set_connect_state() + try: + conn.do_handshake() + except WantReadError: + pass + data = conn.bio_read(2) + assert 2 == len(data) + + @skip_if_py3 + def test_buffer_size_long(self): + """ + On Python 2 `Connection.bio_read` accepts values of type `long` as + well as `int`. + """ + ctx = Context(TLSv1_METHOD) + conn = Connection(ctx, None) + conn.set_connect_state() + try: + conn.do_handshake() + except WantReadError: + pass + data = conn.bio_read(long(2)) + assert 2 == len(data) class ConnectionGetCipherListTests(TestCase): @@ -3618,52 +3573,6 @@ class MemoryBIOTests(TestCase, _LoopbackMixin): self._check_client_ca_list(set_replaces_add_ca) -class TestConnection(object): - """ - Tests for `Connection.bio_read` and `Connection.bio_write`. - """ - def test_wantReadError(self): - """ - `Connection.bio_read` raises `OpenSSL.SSL.WantReadError` if there are - no bytes available to be read from the BIO. - """ - ctx = Context(TLSv1_METHOD) - conn = Connection(ctx, None) - with pytest.raises(WantReadError): - conn.bio_read(1024) - - def test_buffer_size(self): - """ - `Connection.bio_read` accepts an integer giving the maximum number - of bytes to read and return. - """ - ctx = Context(TLSv1_METHOD) - conn = Connection(ctx, None) - conn.set_connect_state() - try: - conn.do_handshake() - except WantReadError: - pass - data = conn.bio_read(2) - assert 2 == len(data) - - @skip_if_py3 - def test_buffer_size_long(self): - """ - On Python 2 `Connection.bio_read` accepts values of type `long` as - well as `int`. - """ - ctx = Context(TLSv1_METHOD) - conn = Connection(ctx, None) - conn.set_connect_state() - try: - conn.do_handshake() - except WantReadError: - pass - data = conn.bio_read(long(2)) - assert 2 == len(data) - - class InfoConstantTests(TestCase): """ Tests for assorted constants exposed for use in info callbacks. |