diff options
-rwxr-xr-x | OpenSSL/ssl/connection.c | 49 | ||||
-rw-r--r-- | OpenSSL/ssl/context.c | 71 | ||||
-rw-r--r-- | OpenSSL/ssl/context.h | 1 | ||||
-rw-r--r-- | OpenSSL/test/test_ssl.py | 127 | ||||
-rw-r--r-- | doc/pyOpenSSL.tex | 16 |
5 files changed, 261 insertions, 3 deletions
diff --git a/OpenSSL/ssl/connection.c b/OpenSSL/ssl/connection.c index a8dfa58..68dbe7e 100755 --- a/OpenSSL/ssl/connection.c +++ b/OpenSSL/ssl/connection.c @@ -301,6 +301,53 @@ ssl_Connection_set_context(ssl_ConnectionObj *self, PyObject *args) { return Py_None; } +static char ssl_Connection_get_servername_doc[] = "\n\ +Retrieve the servername extension value if provided in the client hello\n\ +message, or None if there wasn't one.\n\ +\n\ +@return: A byte string giving the server name or C{None}.\n\ +\n\ +"; +static PyObject * +ssl_Connection_get_servername(ssl_ConnectionObj *self, PyObject *args) { + int type = TLSEXT_NAMETYPE_host_name; + const char *name; + + /* XXX Argument parsing */ + + name = SSL_get_servername(self->ssl, type); + + if (name == NULL) { + Py_INCREF(Py_None); + return Py_None; + } else { + return PyBytes_FromString(name); + } +} + + +static char ssl_Connection_set_tlsext_host_name_doc[] = "\n\ +Set the value of the servername extension to send in the client hello.\n\ +\n\ +@param name: A byte string giving the name.\n\ +\n\ +"; +static PyObject * +ssl_Connection_set_tlsext_host_name(ssl_ConnectionObj *self, PyObject *args) { + char *buf; + + if (!PyArg_ParseTuple(args, BYTESTRING_FMT ":set_tlsext_host_name", &buf)) { + return NULL; + } + + /* XXX I guess this can fail sometimes? */ + SSL_set_tlsext_host_name(self->ssl, buf); + + Py_INCREF(Py_None); + return Py_None; +} + + static char ssl_Connection_pending_doc[] = "\n\ Get the number of bytes that can be safely read from the connection\n\ @@ -1221,6 +1268,8 @@ static PyMethodDef ssl_Connection_methods[] = { ADD_METHOD(get_context), ADD_METHOD(set_context), + ADD_METHOD(get_servername), + ADD_METHOD(set_tlsext_host_name), ADD_METHOD(pending), ADD_METHOD(send), ADD_ALIAS (write, send), diff --git a/OpenSSL/ssl/context.c b/OpenSSL/ssl/context.c index f178eec..c2bdcab 100644 --- a/OpenSSL/ssl/context.c +++ b/OpenSSL/ssl/context.c @@ -238,6 +238,45 @@ global_info_callback(const SSL *ssl, int where, int _ret) } /* + * Globally defined TLS extension server name callback. This is called from + * OpenSSL internally. The GIL will not be held when this function is invoked. + * It must not be held when the function returns. + * + * ssl represents the connection this callback is for + * + * alert is a pointer to the alert value which maybe will be emitted to the + * client if there is an error handling the client hello (which contains the + * server name). This is an out parameter, maybe. + * + * arg is an arbitrary pointer specified by SSL_CTX_set_tlsext_servername_arg. + * It will be NULL for all pyOpenSSL uses. + */ +static int +global_tlsext_servername_callback(const SSL *ssl, int *alert, void *arg) { + int result = 0; + PyObject *argv, *ret; + ssl_ConnectionObj *conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl); + + /* + * GIL isn't held yet. First things first - acquire it, or any Python API + * we invoke might segfault or blow up the sun. The reverse will be done + * before returning. + */ + MY_END_ALLOW_THREADS(conn->tstate); + + argv = Py_BuildValue("(O)", (PyObject *)conn); + ret = PyEval_CallObject(conn->context->tlsext_servername_callback, argv); + Py_DECREF(argv); + Py_DECREF(ret); + + /* + * This function is returning into OpenSSL. Release the GIL again. + */ + MY_BEGIN_ALLOW_THREADS(conn->tstate); + return result; +} + +/* * More recent builds of OpenSSL may have SSLv2 completely disabled. */ #ifdef OPENSSL_NO_SSL2 @@ -1069,6 +1108,34 @@ ssl_Context_set_options(ssl_ContextObj *self, PyObject *args) return PyLong_FromLong(SSL_CTX_set_options(self->ctx, options)); } +static char ssl_Context_set_tlsext_servername_callback_doc[] = "\n\ +Specify a callback function to be called when clients specify a server name.\n\ +\n\ +@param callback: The callback function. It will be invoked with one\n\ + argument, the Connection instance.\n\ +\n\ +"; +static PyObject * +ssl_Context_set_tlsext_servername_callback(ssl_ContextObj *self, PyObject *args) { + PyObject *callback; + PyObject *old; + + if (!PyArg_ParseTuple(args, "O:set_tlsext_servername_callback", &callback)) { + return NULL; + } + + Py_INCREF(callback); + old = self->tlsext_servername_callback; + self->tlsext_servername_callback = callback; + Py_DECREF(old); + + SSL_CTX_set_tlsext_servername_callback(self->ctx, global_tlsext_servername_callback); + SSL_CTX_set_tlsext_servername_arg(self->ctx, NULL); + + Py_INCREF(Py_None); + return Py_None; +} + /* * Member methods in the Context object @@ -1107,6 +1174,7 @@ static PyMethodDef ssl_Context_methods[] = { ADD_METHOD(set_app_data), ADD_METHOD(get_cert_store), ADD_METHOD(set_options), + ADD_METHOD(set_tlsext_servername_callback), { NULL, NULL } }; #undef ADD_METHOD @@ -1155,6 +1223,9 @@ ssl_Context_init(ssl_ContextObj *self, int i_method) { self->info_callback = Py_None; Py_INCREF(Py_None); + self->tlsext_servername_callback = Py_None; + + Py_INCREF(Py_None); self->passphrase_userdata = Py_None; Py_INCREF(Py_None); diff --git a/OpenSSL/ssl/context.h b/OpenSSL/ssl/context.h index 21407f3..19b5e9e 100644 --- a/OpenSSL/ssl/context.h +++ b/OpenSSL/ssl/context.h @@ -29,6 +29,7 @@ typedef struct { *passphrase_userdata, *verify_callback, *info_callback, + *tlsext_servername_callback, *app_data; PyThreadState *tstate; } ssl_ContextObj; diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index 24a08b0..52ff818 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -5,12 +5,14 @@ Unit tests for L{OpenSSL.SSL}. """ +from gc import collect from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK -from sys import platform +from sys import platform, version_info from socket import error, socket from os import makedirs from os.path import join from unittest import main +from weakref import ref from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM from OpenSSL.crypto import PKey, X509, X509Extension @@ -873,6 +875,106 @@ class ContextTests(TestCase, _LoopbackMixin): +class ServerNameCallbackTests(TestCase, _LoopbackMixin): + """ + Tests for L{Context.set_tlsext_servername_callback} and its interaction with + L{Connection}. + """ + def test_wrong_args(self): + """ + L{Context.set_tlsext_servername_callback} raises L{TypeError} if called + with other than one argument. + """ + context = Context(TLSv1_METHOD) + self.assertRaises(TypeError, context.set_tlsext_servername_callback) + self.assertRaises( + TypeError, context.set_tlsext_servername_callback, 1, 2) + + def test_old_callback_forgotten(self): + """ + If L{Context.set_tlsext_servername_callback} is used to specify a new + callback, the one it replaces is dereferenced. + """ + def callback(connection): + pass + + def replacement(connection): + pass + + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(callback) + + tracker = ref(callback) + del callback + + context.set_tlsext_servername_callback(replacement) + collect() + self.assertIdentical(None, tracker()) + + + def test_no_servername(self): + """ + When a client specifies no server name, the callback passed to + L{Context.set_tlsext_servername_callback} is invoked and the result of + L{Connection.get_servername} is C{None}. + """ + args = [] + def servername(conn): + args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(servername) + + # Lose our reference to it. The Context is responsible for keeping it + # alive now. + del servername + collect() + + # Necessary to actually accept the connection + context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(context, None) + server.set_accept_state() + + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + + self._interactInMemory(server, client) + + self.assertEqual([(server, None)], args) + + + def test_servername(self): + """ + When a client specifies a server name in its hello message, the callback + passed to L{Contexts.set_tlsext_servername_callback} is invoked and the + result of L{Connection.get_servername} is that server name. + """ + args = [] + def servername(conn): + args.append((conn, conn.get_servername())) + context = Context(TLSv1_METHOD) + context.set_tlsext_servername_callback(servername) + + # Necessary to actually accept the connection + context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem)) + context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem)) + + # Do a little connection to trigger the logic + server = Connection(context, None) + server.set_accept_state() + + client = Connection(Context(TLSv1_METHOD), None) + client.set_connect_state() + client.set_tlsext_host_name(b("foo1.example.com")) + + self._interactInMemory(server, client) + + self.assertEqual([(server, b("foo1.example.com"))], args) + + + class ConnectionTests(TestCase, _LoopbackMixin): """ Unit tests for L{OpenSSL.SSL.Connection}. @@ -955,8 +1057,27 @@ class ConnectionTests(TestCase, _LoopbackMixin): # Lose our references to the contexts, just in case the Connection isn't # properly managing its own contributions to their reference counts. del original, replacement - import gc - gc.collect() + collect() + + + def test_set_tlsext_host_name_wrong_args(self): + """ + If L{Connection.set_tlsext_host_name} is called with a non-byte string + argument or a byte string with an embedded NUL or other than one + argument, L{TypeError} is raised. + """ + conn = Connection(Context(TLSv1_METHOD), None) + self.assertRaises(TypeError, conn.set_tlsext_host_name) + self.assertRaises(TypeError, conn.set_tlsext_host_name, object()) + self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456) + self.assertRaises( + TypeError, conn.set_tlsext_host_name, b("with\0null")) + + if version_info >= (3,): + # On Python 3.x, don't accidentally implicitly convert from text. + self.assertRaises( + TypeError, + conn.set_tlsext_host_name, b("example.com").decode("ascii")) def test_pending(self): diff --git a/doc/pyOpenSSL.tex b/doc/pyOpenSSL.tex index 8ea37c4..4e00c14 100644 --- a/doc/pyOpenSSL.tex +++ b/doc/pyOpenSSL.tex @@ -1122,6 +1122,12 @@ format specified by \var{format}, which is either \constant{FILETYPE_PEM} or \constant{FILETYPE_ASN1}. The default is \constant{FILETYPE_PEM}. \end{methoddesc} +\begin{methoddesc}[Context]{set_tlsext_servername_callback}{callback} +Specify a one-argument callable to use as the TLS extension server name +callback. When a connection using the server name extension is made using this +context, the callback will be invoked with the \code{Connection} instance. +\versionadded{0.13} +\end{methoddesc} \subsubsection{Connection objects \label{openssl-connection}} @@ -1338,6 +1344,16 @@ Checks if there is data to write to the transport layer to complete an operation. \end{methoddesc} +\begin{methoddesc}[Connection]{set_tlsext_host_name}{name} +Specify the byte string to send as the server name in the client hello message. +\versionadded{0.13} +\end{methoddesc} + +\begin{methoddesc}[Connection]{get_servername}{} +Get the value of the server name received in the client hello message. +\versionadded{0.13} +\end{methoddesc} + \section{Internals \label{internals}} |