summaryrefslogtreecommitdiff
path: root/lib/ext
diff options
context:
space:
mode:
authorAnder Juaristi <a@juaristi.eus>2018-03-22 08:59:56 +0100
committerNikos Mavrogiannopoulos <nmav@redhat.com>2018-04-06 13:28:55 +0200
commit921cee23b4c7ee5d4e4537431e7fb1e9411be2d6 (patch)
tree3b1b423ea33220f41c49d7d5322fd505c4dfb55d /lib/ext
parenta046665a384a728253ad94122dfcbd25a52478c2 (diff)
downloadgnutls-921cee23b4c7ee5d4e4537431e7fb1e9411be2d6.tar.gz
Added support for out-of-band Pre-shared keys under TLS1.3
That adds support for pre-shared keys with and without Diffie-Hellman key exchange. That's a modified version of initial Ander's patch. Resolves #414 Resolves #125 Signed-off-by: Ander Juaristi <a@juaristi.eus> Signed-off-by: Nikos Mavrogiannopoulos <nmav@redhat.org>
Diffstat (limited to 'lib/ext')
-rw-r--r--lib/ext/Makefile.am3
-rw-r--r--lib/ext/key_share.c19
-rw-r--r--lib/ext/pre_shared_key.c470
-rw-r--r--lib/ext/pre_shared_key.h18
-rw-r--r--lib/ext/psk_ke_modes.c180
-rw-r--r--lib/ext/psk_ke_modes.h8
6 files changed, 695 insertions, 3 deletions
diff --git a/lib/ext/Makefile.am b/lib/ext/Makefile.am
index 63d94760bb..89d2389be9 100644
--- a/lib/ext/Makefile.am
+++ b/lib/ext/Makefile.am
@@ -43,7 +43,8 @@ libgnutls_ext_la_SOURCES = max_record.c \
ext_master_secret.c ext_master_secret.h etm.h etm.c \
supported_versions.c supported_versions.h \
post_handshake.c post_handshake.h key_share.c key_share.h \
- cookie.c cookie.h
+ cookie.c cookie.h \
+ psk_ke_modes.c psk_ke_modes.h pre_shared_key.c pre_shared_key.h
if ENABLE_ALPN
libgnutls_ext_la_SOURCES += alpn.c alpn.h
diff --git a/lib/ext/key_share.c b/lib/ext/key_share.c
index f9403df838..871ff08ceb 100644
--- a/lib/ext/key_share.c
+++ b/lib/ext/key_share.c
@@ -506,6 +506,13 @@ key_share_recv_params(gnutls_session_t session,
if (data_size != size)
return gnutls_assert_val(GNUTLS_E_UNEXPECTED_PACKET_LENGTH);
+ /* if we do PSK without DH ignore that share */
+ if ((session->internals.hsk_flags & HSK_PSK_SELECTED) &&
+ (session->internals.hsk_flags & HSK_PSK_KE_MODE_PSK)) {
+ reset_cand_groups(session);
+ return 0;
+ }
+
while(data_size > 0) {
DECR_LEN(data_size, 2);
gid = _gnutls_read_uint16(data);
@@ -554,8 +561,9 @@ key_share_recv_params(gnutls_session_t session,
* In cases (2,3) the error is translated to illegal
* parameter alert.
*/
- if (used_share == 0)
+ if (used_share == 0) {
return gnutls_assert_val(GNUTLS_E_NO_COMMON_KEY_SHARE);
+ }
} else { /* Client */
ver = get_version(session);
@@ -611,6 +619,7 @@ key_share_recv_params(gnutls_session_t session,
}
_gnutls_session_group_set(session, group);
+ session->internals.hsk_flags |= HSK_KEY_SHARE_RECEIVED;
ret = client_use_key_share(session, group, data, size);
if (ret < 0)
@@ -718,6 +727,11 @@ key_share_send_params(gnutls_session_t session,
if (ret < 0)
return gnutls_assert_val(ret);
} else {
+ /* if we are negotiating PSK without DH, do not send a key share */
+ if ((session->internals.hsk_flags & HSK_PSK_SELECTED) &&
+ (session->internals.hsk_flags & HSK_PSK_KE_MODE_PSK))
+ return gnutls_assert_val(0);
+
group = get_group(session);
if (unlikely(group == NULL))
return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_PARAMETER);
@@ -726,8 +740,9 @@ key_share_send_params(gnutls_session_t session,
if (ret < 0)
return gnutls_assert_val(ret);
}
+
+ session->internals.hsk_flags |= HSK_KEY_SHARE_SENT;
}
return 0;
}
-
diff --git a/lib/ext/pre_shared_key.c b/lib/ext/pre_shared_key.c
new file mode 100644
index 0000000000..02c2288528
--- /dev/null
+++ b/lib/ext/pre_shared_key.c
@@ -0,0 +1,470 @@
+/*
+ * Copyright (C) 2017-2018 Free Software Foundation, Inc.
+ * Copyright (C) 2018 Red Hat, Inc.
+ *
+ * Author: Ander Juaristi, Nikos Mavrogiannopoulos
+ *
+ * This file is part of GnuTLS.
+ *
+ * The GnuTLS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1 of
+ * the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>
+ *
+ */
+
+#include "gnutls_int.h"
+#include "auth/psk.h"
+#include "secrets.h"
+#include "tls13/psk_ext_parser.h"
+#include "tls13/finished.h"
+#include "auth/psk_passwd.h"
+#include <ext/pre_shared_key.h>
+#include <assert.h>
+
+typedef struct {
+ uint16_t selected_identity;
+} psk_ext_st;
+
+static int
+compute_binder_key(const mac_entry_st *prf,
+ const uint8_t *key, size_t keylen,
+ void *out)
+{
+ int ret;
+ char label[] = "ext_binder";
+ size_t label_len = sizeof(label) - 1;
+ uint8_t tmp_key[MAX_HASH_SIZE];
+
+ /* Compute HKDF-Extract(0, psk) */
+ ret = _tls13_init_secret2(prf, key, keylen, tmp_key);
+ if (ret < 0)
+ return ret;
+
+ /* Compute Derive-Secret(secret, label, transcript_hash) */
+ ret = _tls13_derive_secret2(prf,
+ label, label_len,
+ NULL, 0,
+ tmp_key,
+ out);
+ if (ret < 0)
+ return ret;
+
+ return 0;
+}
+
+static int
+compute_psk_binder(unsigned entity,
+ const mac_entry_st *prf, unsigned binders_length, unsigned hash_size,
+ int exts_length, int ext_offset,
+ const gnutls_datum_t *psk, const gnutls_datum_t *client_hello,
+ void *out)
+{
+ int ret;
+ unsigned extensions_len_pos;
+ gnutls_buffer_st handshake_buf;
+ uint8_t binder_key[MAX_HASH_SIZE];
+
+ _gnutls_buffer_init(&handshake_buf);
+
+ if (entity == GNUTLS_CLIENT) {
+ ret = gnutls_buffer_append_data(&handshake_buf,
+ (const void *) client_hello->data,
+ client_hello->size);
+ if (ret < 0) {
+ gnutls_assert();
+ goto error;
+ }
+
+ /* This is a ClientHello message */
+ handshake_buf.data[0] = GNUTLS_HANDSHAKE_CLIENT_HELLO;
+
+ /*
+ * At this point we have not yet added the binders to the ClientHello,
+ * but we have to overwrite the size field, pretending as if binders
+ * of the correct length were present.
+ */
+ _gnutls_write_uint24(handshake_buf.length + binders_length - 2, &handshake_buf.data[1]);
+ _gnutls_write_uint16(handshake_buf.length + binders_length - ext_offset,
+ &handshake_buf.data[ext_offset]);
+
+ extensions_len_pos = handshake_buf.length - exts_length - 2;
+ _gnutls_write_uint16(exts_length + binders_length + 2,
+ &handshake_buf.data[extensions_len_pos]);
+ } else {
+ gnutls_buffer_append_data(&handshake_buf,
+ (const void *) client_hello->data,
+ client_hello->size - binders_length - 3);
+ }
+
+ ret = compute_binder_key(prf,
+ psk->data, psk->size,
+ binder_key);
+ if (ret < 0) {
+ gnutls_assert();
+ goto error;
+ }
+
+ ret = _gnutls13_compute_finished(prf,
+ binder_key, hash_size,
+ &handshake_buf,
+ out);
+ if (ret < 0) {
+ gnutls_assert();
+ goto error;
+ }
+
+ ret = 0;
+error:
+ _gnutls_buffer_clear(&handshake_buf);
+ return ret;
+}
+
+static int
+client_send_params(gnutls_session_t session,
+ gnutls_buffer_t extdata,
+ const gnutls_psk_client_credentials_t cred)
+{
+ int ret, ext_offset = 0;
+ uint8_t binder_value[MAX_HASH_SIZE];
+ size_t length, pos;
+ gnutls_datum_t username = {NULL, 0}, key = {NULL, 0}, client_hello;
+ const mac_entry_st *prf = cred->binder_algo;
+ unsigned hash_size = _gnutls_mac_get_algo_len(prf);
+ int free_data;
+
+ if (prf == NULL || hash_size == 0 || hash_size > 255)
+ return gnutls_assert_val(GNUTLS_E_INSUFFICIENT_CREDENTIALS);
+
+ /* Credentials but no username set - this extension is not applicable */
+ if (!_gnutls_have_psk_credentials(cred))
+ return 0;
+
+ ret = _gnutls_find_psk_key(session, cred, &username, &key, &free_data);
+ if (ret < 0)
+ return gnutls_assert_val(ret);
+
+ if (username.size == 0 || username.size > UINT16_MAX) {
+ ret = gnutls_assert_val(GNUTLS_E_INVALID_PASSWORD);
+ goto cleanup;
+ }
+
+ /* placeholder to be filled later */
+ pos = extdata->length;
+ ret = _gnutls_buffer_append_prefix(extdata, 16, 0);
+ if (ret < 0) {
+ gnutls_assert_val(ret);
+ goto cleanup;
+ }
+
+ if ((ret = _gnutls_buffer_append_data_prefix(extdata, 16,
+ username.data, username.size)) < 0) {
+ gnutls_assert();
+ goto cleanup;
+ }
+
+ /* Now append the ticket age, which is always zero for out-of-band PSKs */
+ if ((ret = _gnutls_buffer_append_prefix(extdata, 32, 0)) < 0) {
+ gnutls_assert();
+ goto cleanup;
+ }
+ /* Total length appended is the length of the data, plus six octets */
+ length = (username.size + 6);
+
+ _gnutls_write_uint16(length, &extdata->data[pos]);
+
+ ext_offset = _gnutls_ext_get_extensions_offset(session);
+
+ /* Compute the binders. extdata->data points to the start
+ * of this client hello. */
+ assert(extdata->length >= sizeof(mbuffer_st));
+ assert(ext_offset >= (ssize_t)sizeof(mbuffer_st));
+ ext_offset -= sizeof(mbuffer_st);
+ client_hello.data = extdata->data+sizeof(mbuffer_st);
+ client_hello.size = extdata->length-sizeof(mbuffer_st);
+
+ ret = compute_psk_binder(GNUTLS_CLIENT, prf,
+ hash_size+1, hash_size, extdata->length-pos,
+ ext_offset, &key, &client_hello,
+ binder_value);
+ if (ret < 0) {
+ gnutls_assert();
+ goto cleanup;
+ }
+
+ /* Associate the selected pre-shared key with the session */
+ session->key.psk.data = key.data;
+ session->key.psk.size = key.size;
+ session->key.psk_needs_free = free_data;
+ key.data = NULL;
+ session->key.proto.tls13.binder_prf = prf;
+
+ /* Now append the binders */
+ ret = _gnutls_buffer_append_prefix(extdata, 16, hash_size+1);
+ if (ret < 0) {
+ gnutls_assert();
+ goto cleanup;
+ }
+
+ /* Add the size of the binder (we only have one) */
+ ret = _gnutls_buffer_append_data_prefix(extdata, 8, binder_value, hash_size);
+ if (ret < 0) {
+ gnutls_assert();
+ goto cleanup;
+ }
+
+ ret = 0;
+
+cleanup:
+ if (free_data) {
+ _gnutls_free_datum(&username);
+ _gnutls_free_temp_key_datum(&key);
+ }
+ return ret;
+}
+
+static int
+server_send_params(gnutls_session_t session, gnutls_buffer_t extdata)
+{
+ int ret;
+
+ if (!(session->internals.hsk_flags & HSK_PSK_SELECTED))
+ return 0;
+
+ ret = _gnutls_buffer_append_prefix(extdata, 16,
+ session->key.proto.tls13.psk_index);
+ if (ret < 0)
+ return gnutls_assert_val(ret);
+
+ return 2;
+}
+
+static int server_recv_params(gnutls_session_t session,
+ const unsigned char *data, size_t len,
+ const gnutls_psk_server_credentials_t pskcred)
+{
+ int ret;
+ const mac_entry_st *prf;
+ gnutls_datum_t full_client_hello;
+ uint8_t binder_value[MAX_HASH_SIZE];
+ int psk_index = -1;
+ gnutls_datum_t binder_recvd = { NULL, 0 };
+ gnutls_datum_t key;
+ unsigned hash_size;
+ psk_ext_parser_st psk_parser;
+ struct psk_st psk;
+
+ ret = _gnutls13_psk_ext_parser_init(&psk_parser, data, len);
+ if (ret == 0) {
+ /* No PSKs advertised by client */
+ return 0;
+ } else if (ret < 0) {
+ return gnutls_assert_val(ret);
+ }
+
+ while ((ret = _gnutls13_psk_ext_parser_next_psk(&psk_parser, &psk)) >= 0) {
+ if (psk.ob_ticket_age == 0) {
+ /* _gnutls_psk_pwd_find_entry() expects 0-terminated identities */
+ if (psk.identity.size > 0 && psk.identity.size <= MAX_USERNAME_SIZE) {
+ char identity_str[psk.identity.size + 1];
+
+ memcpy(identity_str, psk.identity.data, psk.identity.size);
+ identity_str[psk.identity.size] = 0;
+
+ ret = _gnutls_psk_pwd_find_entry(session, identity_str, &key);
+ if (ret == 0)
+ psk_index = ret;
+ }
+ }
+ }
+
+ if (psk_index < 0)
+ return 0;
+
+ ret = _gnutls13_psk_ext_parser_find_binder(&psk_parser, psk_index,
+ &binder_recvd);
+ if (ret < 0)
+ return gnutls_assert_val(ret);
+
+ /* Get full ClientHello */
+ if (!_gnutls_ext_get_full_client_hello(session, &full_client_hello)) {
+ ret = 0;
+ goto cleanup;
+ }
+
+ /* Compute the binder value for this PSK */
+ prf = pskcred->binder_algo;
+ hash_size = prf->output_size;
+ ret = compute_psk_binder(GNUTLS_SERVER, prf, hash_size, hash_size, 0, 0,
+ &key, &full_client_hello,
+ binder_value);
+ if (ret < 0) {
+ gnutls_assert();
+ goto cleanup;
+ }
+
+ if (_gnutls_mac_get_algo_len(prf) != binder_recvd.size ||
+ safe_memcmp(binder_value, binder_recvd.data, binder_recvd.size)) {
+ gnutls_free(key.data);
+ ret = gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_PARAMETER);
+ goto cleanup;
+ }
+
+ if (session->internals.hsk_flags & HSK_PSK_KE_MODE_DHE_PSK)
+ _gnutls_handshake_log("EXT[%p]: Selected DHE-PSK mode\n", session);
+ else {
+ reset_cand_groups(session);
+ _gnutls_handshake_log("EXT[%p]: Selected PSK mode\n", session);
+ }
+
+ session->internals.hsk_flags |= HSK_PSK_SELECTED;
+
+ /* Reference the selected pre-shared key */
+ session->key.psk.data = key.data;
+ session->key.psk.size = key.size;
+ session->key.psk_needs_free = 1;
+
+ session->key.proto.tls13.psk_index = psk_index;
+ session->key.proto.tls13.binder_prf = prf;
+
+ ret = 0;
+ cleanup:
+ _gnutls_free_datum(&binder_recvd);
+
+ return ret;
+}
+
+/*
+ * Return values for this function:
+ * - 0 : Not applicable.
+ * - >0 : Ok. Return size of extension data.
+ * - GNUTLS_E_INT_RET_0 : Size of extension data is zero.
+ * - <0 : There's been an error.
+ *
+ * In the client, generates the PskIdentity and PskBinderEntry messages.
+ *
+ * PskIdentity identities<7..2^16-1>;
+ * PskBinderEntry binders<33..2^16-1>;
+ *
+ * struct {
+ * opaque identity<1..2^16-1>;
+ * uint32 obfuscated_ticket_age;
+ * } PskIdentity;
+ *
+ * opaque PskBinderEntry<32..255>;
+ *
+ * The server sends the selected identity, which is a zero-based index
+ * of the PSKs offered by the client:
+ *
+ * struct {
+ * uint16 selected_identity;
+ * } PreSharedKeyExtension;
+ */
+static int _gnutls_psk_send_params(gnutls_session_t session,
+ gnutls_buffer_t extdata)
+{
+ gnutls_psk_client_credentials_t cred = NULL;
+ const version_entry_st *vers;
+
+ if (session->security_parameters.entity == GNUTLS_CLIENT) {
+ vers = _gnutls_version_max(session);
+
+ if (!vers || !vers->tls13_sem)
+ return 0;
+
+ if (session->internals.hsk_flags & HSK_PSK_KE_MODES_SENT) {
+ cred = (gnutls_psk_client_credentials_t)
+ _gnutls_get_cred(session, GNUTLS_CRD_PSK);
+ /* If there are no PSK credentials, this extension is not applicable,
+ * so we return zero. */
+ if (cred == NULL || !session->internals.priorities->have_psk)
+ return 0;
+
+ return client_send_params(session, extdata, cred);
+ } else {
+ return 0;
+ }
+ } else {
+ vers = get_version(session);
+
+ if (!vers || !vers->tls13_sem)
+ return 0;
+
+ cred = (gnutls_psk_client_credentials_t)
+ _gnutls_get_cred(session, GNUTLS_CRD_PSK);
+ if (cred == NULL || !session->internals.priorities->have_psk)
+ return 0;
+
+ if (session->internals.hsk_flags & HSK_PSK_KE_MODES_RECEIVED)
+ return server_send_params(session, extdata);
+ else
+ return 0;
+ }
+}
+
+/*
+ * Return values for this function:
+ * - 0 : Not applicable.
+ * - >0 : Ok. Return size of extension data.
+ * - <0 : There's been an error.
+ */
+static int _gnutls_psk_recv_params(gnutls_session_t session,
+ const unsigned char *data, size_t len)
+{
+ gnutls_psk_server_credentials_t pskcred;
+ const version_entry_st *vers = get_version(session);
+
+ if (!vers || !vers->tls13_sem)
+ return 0;
+
+ if (session->security_parameters.entity == GNUTLS_CLIENT) {
+ if (session->internals.hsk_flags & HSK_PSK_KE_MODES_SENT) {
+ uint16_t selected_identity = _gnutls_read_uint16(data);
+
+ if (selected_identity == 0) {
+ _gnutls_handshake_log("EXT[%p]: Selected PSK mode\n", session);
+ session->internals.hsk_flags |= HSK_PSK_SELECTED;
+ }
+ return 0;
+ } else {
+ return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_EXTENSION);
+ }
+ } else {
+ if (session->internals.hsk_flags & HSK_PSK_KE_MODES_RECEIVED) {
+ if (session->internals.hsk_flags & HSK_PSK_KE_MODE_INVALID) {
+ /* We received a "psk_ke_modes" extension, but with a value we don't support */
+ return 0;
+ }
+
+ pskcred = (gnutls_psk_server_credentials_t)
+ _gnutls_get_cred(session, GNUTLS_CRD_PSK);
+
+ /* If there are no PSK credentials, this extension is not applicable,
+ * so we return zero. */
+ if (pskcred == NULL)
+ return 0;
+
+ return server_recv_params(session, data, len, pskcred);
+ } else {
+ return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_EXTENSION);
+ }
+ }
+}
+
+const hello_ext_entry_st ext_pre_shared_key = {
+ .name = "Pre Shared Key",
+ .tls_id = 41,
+ .gid = GNUTLS_EXTENSION_PRE_SHARED_KEY,
+ .parse_type = GNUTLS_EXT_TLS,
+ .validity = GNUTLS_EXT_FLAG_CLIENT_HELLO | GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO,
+ .send_func = _gnutls_psk_send_params,
+ .recv_func = _gnutls_psk_recv_params
+};
diff --git a/lib/ext/pre_shared_key.h b/lib/ext/pre_shared_key.h
new file mode 100644
index 0000000000..25dd159f6e
--- /dev/null
+++ b/lib/ext/pre_shared_key.h
@@ -0,0 +1,18 @@
+#ifndef EXT_PRE_SHARED_KEY_H
+#define EXT_PRE_SHARED_KEY_H
+
+#include "auth/psk.h"
+#include <hello_ext.h>
+
+extern const hello_ext_entry_st ext_pre_shared_key;
+
+inline static
+unsigned _gnutls_have_psk_credentials(const gnutls_psk_client_credentials_t cred)
+{
+ if (cred->get_function || cred->username.data)
+ return 1;
+ else
+ return 0;
+}
+
+#endif
diff --git a/lib/ext/psk_ke_modes.c b/lib/ext/psk_ke_modes.c
new file mode 100644
index 0000000000..c6aef3bda8
--- /dev/null
+++ b/lib/ext/psk_ke_modes.c
@@ -0,0 +1,180 @@
+/*
+ * Copyright (C) 2017 Free Software Foundation, Inc.
+ *
+ * Author: Ander Juaristi
+ *
+ * This file is part of GnuTLS.
+ *
+ * The GnuTLS is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public License
+ * as published by the Free Software Foundation; either version 2.1 of
+ * the License, or (at your option) any later version.
+ *
+ * This library is distributed in the hope that it will be useful, but
+ * WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>
+ *
+ */
+
+#include "gnutls_int.h"
+#include "ext/psk_ke_modes.h"
+#include "ext/pre_shared_key.h"
+#include <assert.h>
+
+#define PSK_KE 0
+#define PSK_DHE_KE 1
+
+/*
+ * We only support ECDHE-authenticated PSKs.
+ * The client just sends a "psk_key_exchange_modes" extension
+ * with the value one.
+ */
+static int
+psk_ke_modes_send_params(gnutls_session_t session,
+ gnutls_buffer_t extdata)
+{
+ int ret;
+ gnutls_psk_client_credentials_t cred;
+ const version_entry_st *vers;
+ uint8_t data[2];
+ unsigned pos, i;
+ unsigned have_dhpsk = 0;
+ unsigned have_psk = 0;
+
+ /* Server doesn't send psk_key_exchange_modes */
+ if (session->security_parameters.entity == GNUTLS_SERVER ||
+ !session->internals.priorities->have_psk)
+ return 0;
+
+ cred = (gnutls_psk_client_credentials_t)
+ _gnutls_get_cred(session, GNUTLS_CRD_PSK);
+ if (cred == NULL || _gnutls_have_psk_credentials(cred) == 0)
+ return 0;
+
+ vers = _gnutls_version_max(session);
+ if (!vers || !vers->tls13_sem)
+ return 0;
+
+ pos = 0;
+ for (i=0;i<session->internals.priorities->_kx.algorithms;i++) {
+ if (session->internals.priorities->_kx.priority[i] == GNUTLS_KX_PSK && !have_psk) {
+ assert(pos <= 1);
+ data[pos++] = PSK_KE;
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
+ have_psk = 1;
+ } else if ((session->internals.priorities->_kx.priority[i] == GNUTLS_KX_DHE_PSK ||
+ session->internals.priorities->_kx.priority[i] == GNUTLS_KX_ECDHE_PSK) && !have_dhpsk) {
+ assert(pos <= 1);
+ data[pos++] = PSK_DHE_KE;
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
+ have_dhpsk = 1;
+ }
+
+ if (have_psk && have_dhpsk)
+ break;
+ }
+
+ ret = _gnutls_buffer_append_data_prefix(extdata, 8, data, pos);
+ if (ret < 0)
+ return gnutls_assert_val(ret);
+
+ session->internals.hsk_flags |= HSK_PSK_KE_MODES_SENT;
+
+ return 0;
+}
+
+#define MAX_POS INT_MAX
+
+/*
+ * Since we only support ECDHE-authenticated PSKs, the server
+ * just verifies that a "psk_key_exchange_modes" extension was received,
+ * and that it contains the value one.
+ */
+static int
+psk_ke_modes_recv_params(gnutls_session_t session,
+ const unsigned char *data, size_t _len)
+{
+ uint8_t ke_modes_len;
+ ssize_t len = _len;
+ const version_entry_st *vers = get_version(session);
+ gnutls_psk_server_credentials_t cred;
+ int dhpsk_pos = MAX_POS;
+ int psk_pos = MAX_POS;
+ int cli_psk_pos = MAX_POS;
+ int cli_dhpsk_pos = MAX_POS;
+ unsigned i;
+
+ /* Server doesn't send psk_key_exchange_modes */
+ if (session->security_parameters.entity == GNUTLS_CLIENT)
+ return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_EXTENSION);
+
+ if (!vers || !vers->tls13_sem)
+ return 0;
+
+ cred = (gnutls_psk_server_credentials_t)_gnutls_get_cred(session, GNUTLS_CRD_PSK);
+ if (cred == NULL)
+ return 0;
+
+ DECR_LEN(len, 1);
+ ke_modes_len = *(data++);
+
+ for (i=0;i<session->internals.priorities->_kx.algorithms;i++) {
+ if (session->internals.priorities->_kx.priority[i] == GNUTLS_KX_PSK && psk_pos == MAX_POS) {
+ psk_pos = i;
+ } else if ((session->internals.priorities->_kx.priority[i] == GNUTLS_KX_DHE_PSK ||
+ session->internals.priorities->_kx.priority[i] == GNUTLS_KX_ECDHE_PSK) &&
+ dhpsk_pos == MAX_POS) {
+ dhpsk_pos = i;
+ }
+
+ if (dhpsk_pos != MAX_POS && psk_pos != MAX_POS)
+ break;
+ }
+
+ if (session->internals.priorities->groups.size == 0 && psk_pos == MAX_POS)
+ return gnutls_assert_val(0);
+
+ for (i=0;i<ke_modes_len;i++) {
+ if (data[i] == PSK_DHE_KE)
+ cli_dhpsk_pos = i;
+ if (data[i] == PSK_KE)
+ cli_psk_pos = i;
+
+ if (cli_psk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS)
+ break;
+ }
+
+ if (session->internals.priorities->server_precedence) {
+ if (dhpsk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS && dhpsk_pos < psk_pos)
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
+ else if (psk_pos != MAX_POS && cli_psk_pos != MAX_POS && psk_pos < dhpsk_pos)
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
+ } else {
+ if (dhpsk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS && cli_dhpsk_pos < cli_psk_pos)
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
+ else if (psk_pos != MAX_POS && cli_psk_pos != MAX_POS && cli_psk_pos < cli_dhpsk_pos)
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
+ }
+
+ if ((session->internals.hsk_flags & HSK_PSK_KE_MODE_PSK) ||
+ (session->internals.hsk_flags & HSK_PSK_KE_MODE_DHE_PSK)) {
+ return 0;
+ } else {
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_INVALID;
+ return 0;
+ }
+}
+
+const hello_ext_entry_st ext_psk_ke_modes = {
+ .name = "PSK Key Exchange Modes",
+ .tls_id = 45,
+ .gid = GNUTLS_EXTENSION_PSK_KE_MODES,
+ .parse_type = GNUTLS_EXT_TLS,
+ .validity = GNUTLS_EXT_FLAG_CLIENT_HELLO | GNUTLS_EXT_FLAG_TLS13_SERVER_HELLO,
+ .send_func = psk_ke_modes_send_params,
+ .recv_func = psk_ke_modes_recv_params
+};
diff --git a/lib/ext/psk_ke_modes.h b/lib/ext/psk_ke_modes.h
new file mode 100644
index 0000000000..bd06139ff5
--- /dev/null
+++ b/lib/ext/psk_ke_modes.h
@@ -0,0 +1,8 @@
+#ifndef EXT_PSK_KE_MODES_H
+#define EXT_PSK_KE_MODES_H
+
+#include <hello_ext.h>
+
+extern const hello_ext_entry_st ext_psk_ke_modes;
+
+#endif