summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJean-Paul Calderone <exarkun@divmod.com>2011-05-26 18:47:00 -0400
committerJean-Paul Calderone <exarkun@divmod.com>2011-05-26 18:47:00 -0400
commitc4cb658c516e20cad3e8707ba66dab78ce0bd1e8 (patch)
tree4495cb872f64d0d667d5f571a6f3efc23e3fe319
parent95613b7e9461a34db446501ce565c41feb7f5c6d (diff)
downloadpyopenssl-c4cb658c516e20cad3e8707ba66dab78ce0bd1e8.tar.gz
And SSL_get_servername, SSL_set_tlsext_host_name, and SSL_CTX_set_tlsext_servername_callback
-rwxr-xr-xOpenSSL/ssl/connection.c49
-rw-r--r--OpenSSL/ssl/context.c71
-rw-r--r--OpenSSL/ssl/context.h1
-rw-r--r--OpenSSL/test/test_ssl.py127
-rw-r--r--doc/pyOpenSSL.tex16
5 files changed, 261 insertions, 3 deletions
diff --git a/OpenSSL/ssl/connection.c b/OpenSSL/ssl/connection.c
index a8dfa58..68dbe7e 100755
--- a/OpenSSL/ssl/connection.c
+++ b/OpenSSL/ssl/connection.c
@@ -301,6 +301,53 @@ ssl_Connection_set_context(ssl_ConnectionObj *self, PyObject *args) {
return Py_None;
}
+static char ssl_Connection_get_servername_doc[] = "\n\
+Retrieve the servername extension value if provided in the client hello\n\
+message, or None if there wasn't one.\n\
+\n\
+@return: A byte string giving the server name or C{None}.\n\
+\n\
+";
+static PyObject *
+ssl_Connection_get_servername(ssl_ConnectionObj *self, PyObject *args) {
+ int type = TLSEXT_NAMETYPE_host_name;
+ const char *name;
+
+ /* XXX Argument parsing */
+
+ name = SSL_get_servername(self->ssl, type);
+
+ if (name == NULL) {
+ Py_INCREF(Py_None);
+ return Py_None;
+ } else {
+ return PyBytes_FromString(name);
+ }
+}
+
+
+static char ssl_Connection_set_tlsext_host_name_doc[] = "\n\
+Set the value of the servername extension to send in the client hello.\n\
+\n\
+@param name: A byte string giving the name.\n\
+\n\
+";
+static PyObject *
+ssl_Connection_set_tlsext_host_name(ssl_ConnectionObj *self, PyObject *args) {
+ char *buf;
+
+ if (!PyArg_ParseTuple(args, BYTESTRING_FMT ":set_tlsext_host_name", &buf)) {
+ return NULL;
+ }
+
+ /* XXX I guess this can fail sometimes? */
+ SSL_set_tlsext_host_name(self->ssl, buf);
+
+ Py_INCREF(Py_None);
+ return Py_None;
+}
+
+
static char ssl_Connection_pending_doc[] = "\n\
Get the number of bytes that can be safely read from the connection\n\
@@ -1221,6 +1268,8 @@ static PyMethodDef ssl_Connection_methods[] =
{
ADD_METHOD(get_context),
ADD_METHOD(set_context),
+ ADD_METHOD(get_servername),
+ ADD_METHOD(set_tlsext_host_name),
ADD_METHOD(pending),
ADD_METHOD(send),
ADD_ALIAS (write, send),
diff --git a/OpenSSL/ssl/context.c b/OpenSSL/ssl/context.c
index f178eec..c2bdcab 100644
--- a/OpenSSL/ssl/context.c
+++ b/OpenSSL/ssl/context.c
@@ -238,6 +238,45 @@ global_info_callback(const SSL *ssl, int where, int _ret)
}
/*
+ * Globally defined TLS extension server name callback. This is called from
+ * OpenSSL internally. The GIL will not be held when this function is invoked.
+ * It must not be held when the function returns.
+ *
+ * ssl represents the connection this callback is for
+ *
+ * alert is a pointer to the alert value which maybe will be emitted to the
+ * client if there is an error handling the client hello (which contains the
+ * server name). This is an out parameter, maybe.
+ *
+ * arg is an arbitrary pointer specified by SSL_CTX_set_tlsext_servername_arg.
+ * It will be NULL for all pyOpenSSL uses.
+ */
+static int
+global_tlsext_servername_callback(const SSL *ssl, int *alert, void *arg) {
+ int result = 0;
+ PyObject *argv, *ret;
+ ssl_ConnectionObj *conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl);
+
+ /*
+ * GIL isn't held yet. First things first - acquire it, or any Python API
+ * we invoke might segfault or blow up the sun. The reverse will be done
+ * before returning.
+ */
+ MY_END_ALLOW_THREADS(conn->tstate);
+
+ argv = Py_BuildValue("(O)", (PyObject *)conn);
+ ret = PyEval_CallObject(conn->context->tlsext_servername_callback, argv);
+ Py_DECREF(argv);
+ Py_DECREF(ret);
+
+ /*
+ * This function is returning into OpenSSL. Release the GIL again.
+ */
+ MY_BEGIN_ALLOW_THREADS(conn->tstate);
+ return result;
+}
+
+/*
* More recent builds of OpenSSL may have SSLv2 completely disabled.
*/
#ifdef OPENSSL_NO_SSL2
@@ -1069,6 +1108,34 @@ ssl_Context_set_options(ssl_ContextObj *self, PyObject *args)
return PyLong_FromLong(SSL_CTX_set_options(self->ctx, options));
}
+static char ssl_Context_set_tlsext_servername_callback_doc[] = "\n\
+Specify a callback function to be called when clients specify a server name.\n\
+\n\
+@param callback: The callback function. It will be invoked with one\n\
+ argument, the Connection instance.\n\
+\n\
+";
+static PyObject *
+ssl_Context_set_tlsext_servername_callback(ssl_ContextObj *self, PyObject *args) {
+ PyObject *callback;
+ PyObject *old;
+
+ if (!PyArg_ParseTuple(args, "O:set_tlsext_servername_callback", &callback)) {
+ return NULL;
+ }
+
+ Py_INCREF(callback);
+ old = self->tlsext_servername_callback;
+ self->tlsext_servername_callback = callback;
+ Py_DECREF(old);
+
+ SSL_CTX_set_tlsext_servername_callback(self->ctx, global_tlsext_servername_callback);
+ SSL_CTX_set_tlsext_servername_arg(self->ctx, NULL);
+
+ Py_INCREF(Py_None);
+ return Py_None;
+}
+
/*
* Member methods in the Context object
@@ -1107,6 +1174,7 @@ static PyMethodDef ssl_Context_methods[] = {
ADD_METHOD(set_app_data),
ADD_METHOD(get_cert_store),
ADD_METHOD(set_options),
+ ADD_METHOD(set_tlsext_servername_callback),
{ NULL, NULL }
};
#undef ADD_METHOD
@@ -1155,6 +1223,9 @@ ssl_Context_init(ssl_ContextObj *self, int i_method) {
self->info_callback = Py_None;
Py_INCREF(Py_None);
+ self->tlsext_servername_callback = Py_None;
+
+ Py_INCREF(Py_None);
self->passphrase_userdata = Py_None;
Py_INCREF(Py_None);
diff --git a/OpenSSL/ssl/context.h b/OpenSSL/ssl/context.h
index 21407f3..19b5e9e 100644
--- a/OpenSSL/ssl/context.h
+++ b/OpenSSL/ssl/context.h
@@ -29,6 +29,7 @@ typedef struct {
*passphrase_userdata,
*verify_callback,
*info_callback,
+ *tlsext_servername_callback,
*app_data;
PyThreadState *tstate;
} ssl_ContextObj;
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 24a08b0..52ff818 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -5,12 +5,14 @@
Unit tests for L{OpenSSL.SSL}.
"""
+from gc import collect
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK
-from sys import platform
+from sys import platform, version_info
from socket import error, socket
from os import makedirs
from os.path import join
from unittest import main
+from weakref import ref
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM
from OpenSSL.crypto import PKey, X509, X509Extension
@@ -873,6 +875,106 @@ class ContextTests(TestCase, _LoopbackMixin):
+class ServerNameCallbackTests(TestCase, _LoopbackMixin):
+ """
+ Tests for L{Context.set_tlsext_servername_callback} and its interaction with
+ L{Connection}.
+ """
+ def test_wrong_args(self):
+ """
+ L{Context.set_tlsext_servername_callback} raises L{TypeError} if called
+ with other than one argument.
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.set_tlsext_servername_callback)
+ self.assertRaises(
+ TypeError, context.set_tlsext_servername_callback, 1, 2)
+
+ def test_old_callback_forgotten(self):
+ """
+ If L{Context.set_tlsext_servername_callback} is used to specify a new
+ callback, the one it replaces is dereferenced.
+ """
+ def callback(connection):
+ pass
+
+ def replacement(connection):
+ pass
+
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(callback)
+
+ tracker = ref(callback)
+ del callback
+
+ context.set_tlsext_servername_callback(replacement)
+ collect()
+ self.assertIdentical(None, tracker())
+
+
+ def test_no_servername(self):
+ """
+ When a client specifies no server name, the callback passed to
+ L{Context.set_tlsext_servername_callback} is invoked and the result of
+ L{Connection.get_servername} is C{None}.
+ """
+ args = []
+ def servername(conn):
+ args.append((conn, conn.get_servername()))
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(servername)
+
+ # Lose our reference to it. The Context is responsible for keeping it
+ # alive now.
+ del servername
+ collect()
+
+ # Necessary to actually accept the connection
+ context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(context, None)
+ server.set_accept_state()
+
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, None)], args)
+
+
+ def test_servername(self):
+ """
+ When a client specifies a server name in its hello message, the callback
+ passed to L{Contexts.set_tlsext_servername_callback} is invoked and the
+ result of L{Connection.get_servername} is that server name.
+ """
+ args = []
+ def servername(conn):
+ args.append((conn, conn.get_servername()))
+ context = Context(TLSv1_METHOD)
+ context.set_tlsext_servername_callback(servername)
+
+ # Necessary to actually accept the connection
+ context.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ context.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+
+ # Do a little connection to trigger the logic
+ server = Connection(context, None)
+ server.set_accept_state()
+
+ client = Connection(Context(TLSv1_METHOD), None)
+ client.set_connect_state()
+ client.set_tlsext_host_name(b("foo1.example.com"))
+
+ self._interactInMemory(server, client)
+
+ self.assertEqual([(server, b("foo1.example.com"))], args)
+
+
+
class ConnectionTests(TestCase, _LoopbackMixin):
"""
Unit tests for L{OpenSSL.SSL.Connection}.
@@ -955,8 +1057,27 @@ class ConnectionTests(TestCase, _LoopbackMixin):
# Lose our references to the contexts, just in case the Connection isn't
# properly managing its own contributions to their reference counts.
del original, replacement
- import gc
- gc.collect()
+ collect()
+
+
+ def test_set_tlsext_host_name_wrong_args(self):
+ """
+ If L{Connection.set_tlsext_host_name} is called with a non-byte string
+ argument or a byte string with an embedded NUL or other than one
+ argument, L{TypeError} is raised.
+ """
+ conn = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, conn.set_tlsext_host_name)
+ self.assertRaises(TypeError, conn.set_tlsext_host_name, object())
+ self.assertRaises(TypeError, conn.set_tlsext_host_name, 123, 456)
+ self.assertRaises(
+ TypeError, conn.set_tlsext_host_name, b("with\0null"))
+
+ if version_info >= (3,):
+ # On Python 3.x, don't accidentally implicitly convert from text.
+ self.assertRaises(
+ TypeError,
+ conn.set_tlsext_host_name, b("example.com").decode("ascii"))
def test_pending(self):
diff --git a/doc/pyOpenSSL.tex b/doc/pyOpenSSL.tex
index 8ea37c4..4e00c14 100644
--- a/doc/pyOpenSSL.tex
+++ b/doc/pyOpenSSL.tex
@@ -1122,6 +1122,12 @@ format specified by \var{format}, which is either \constant{FILETYPE_PEM} or
\constant{FILETYPE_ASN1}. The default is \constant{FILETYPE_PEM}.
\end{methoddesc}
+\begin{methoddesc}[Context]{set_tlsext_servername_callback}{callback}
+Specify a one-argument callable to use as the TLS extension server name
+callback. When a connection using the server name extension is made using this
+context, the callback will be invoked with the \code{Connection} instance.
+\versionadded{0.13}
+\end{methoddesc}
\subsubsection{Connection objects \label{openssl-connection}}
@@ -1338,6 +1344,16 @@ Checks if there is data to write to the transport layer to complete an
operation.
\end{methoddesc}
+\begin{methoddesc}[Connection]{set_tlsext_host_name}{name}
+Specify the byte string to send as the server name in the client hello message.
+\versionadded{0.13}
+\end{methoddesc}
+
+\begin{methoddesc}[Connection]{get_servername}{}
+Get the value of the server name received in the client hello message.
+\versionadded{0.13}
+\end{methoddesc}
+
\section{Internals \label{internals}}