summaryrefslogtreecommitdiff
path: root/lib/ext/psk_ke_modes.c
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ext/psk_ke_modes.c')
-rw-r--r--lib/ext/psk_ke_modes.c226
1 files changed, 131 insertions, 95 deletions
diff --git a/lib/ext/psk_ke_modes.c b/lib/ext/psk_ke_modes.c
index 56b382e1f5..1c22c74c07 100644
--- a/lib/ext/psk_ke_modes.c
+++ b/lib/ext/psk_ke_modes.c
@@ -25,50 +25,71 @@
#include "ext/pre_shared_key.h"
#include <assert.h>
-#define PSK_KE 0
-#define PSK_DHE_KE 1
+static void get_server_kx_algo_order(gnutls_session_t session, int *psk_pos, int *dhpsk_pos)
+{
+ unsigned i;
-/*
- * We only support ECDHE-authenticated PSKs.
- * The client just sends a "psk_key_exchange_modes" extension
- * with the value one.
- */
-static int
-psk_ke_modes_send_params(gnutls_session_t session,
- gnutls_buffer_t extdata)
+ *psk_pos = INT_MAX;
+ *dhpsk_pos = INT_MAX;
+
+ for (i = 0; i < session->internals.priorities->_kx.algorithms; i++) {
+ if (session->internals.priorities->_kx.priority[i] == GNUTLS_KX_PSK &&
+ *psk_pos == INT_MAX)
+ *psk_pos = i;
+ else if ((session->internals.priorities->_kx.priority[i] == GNUTLS_KX_DHE_PSK ||
+ session->internals.priorities->_kx.priority[i] == GNUTLS_KX_ECDHE_PSK) &&
+ *dhpsk_pos == INT_MAX)
+ *dhpsk_pos = i;
+
+ if (*psk_pos != INT_MAX && *dhpsk_pos != INT_MAX)
+ break;
+ }
+}
+
+static void get_client_kx_algo_order(gnutls_session_t session, int *psk_pos, int *dhpsk_pos)
{
- int ret;
- gnutls_psk_client_credentials_t cred;
- const version_entry_st *vers;
- uint8_t data[2];
- unsigned pos, i;
- unsigned have_dhpsk = 0;
- unsigned have_psk = 0;
+ *dhpsk_pos = INT_MAX;
+ *psk_pos = INT_MAX;
+
+ for (unsigned i = 0; i < session->internals.psk_ke_modes_size; i++) {
+ if (session->internals.psk_ke_modes[i] == PSK_DHE_KE)
+ *dhpsk_pos = i;
+ else if (session->internals.psk_ke_modes[i] == PSK_KE)
+ *psk_pos = i;
+ }
+}
- /* Server doesn't send psk_key_exchange_modes */
- if (session->security_parameters.entity == GNUTLS_SERVER)
- return 0;
+static int check_ke_modes(gnutls_session_t session)
+{
+ int mode = -1;
+ int psk_pos, dhpsk_pos;
+ int cli_psk_pos, cli_dhpsk_pos;
- cred = (gnutls_psk_client_credentials_t)
- _gnutls_get_cred(session, GNUTLS_CRD_PSK);
- if (cred == NULL || _gnutls_have_psk_credentials(cred) == 0) {
- /*
- * No out-of-band PSKs - do we have a session ticket?
- * We're not interested in the ticket itself.
- */
- if (session->internals.tls13_ticket.ticket.data == NULL)
- return 0;
+ get_server_kx_algo_order(session, &psk_pos, &dhpsk_pos);
+ get_client_kx_algo_order(session, &cli_psk_pos, &cli_dhpsk_pos);
+
+ if (session->internals.priorities->server_precedence) {
+ if (dhpsk_pos != INT_MAX && cli_dhpsk_pos != INT_MAX && dhpsk_pos < psk_pos)
+ mode = PSK_DHE_KE;
+ else if (psk_pos != INT_MAX && cli_psk_pos != INT_MAX && psk_pos < dhpsk_pos)
+ mode = PSK_KE;
} else {
- if (!session->internals.priorities->have_psk)
- return 0;
+ if (dhpsk_pos != INT_MAX && cli_dhpsk_pos != INT_MAX && cli_dhpsk_pos < cli_psk_pos)
+ mode = PSK_DHE_KE;
+ else if (psk_pos != INT_MAX && cli_psk_pos != INT_MAX && cli_psk_pos < cli_dhpsk_pos)
+ mode = PSK_KE;
}
- vers = _gnutls_version_max(session);
- if (!vers || !vers->tls13_sem)
- return 0;
+ return mode;
+}
- pos = 0;
- for (i=0;i<session->internals.priorities->_kx.algorithms;i++) {
+static int make_data(gnutls_session_t session, uint8_t data[2])
+{
+ int pos = 0;
+ unsigned have_dhpsk = 0;
+ unsigned have_psk = 0;
+
+ for (unsigned i = 0; i < session->internals.priorities->_kx.algorithms; i++) {
if (session->internals.priorities->_kx.priority[i] == GNUTLS_KX_PSK && !have_psk) {
assert(pos <= 1);
data[pos++] = PSK_KE;
@@ -86,22 +107,50 @@ psk_ke_modes_send_params(gnutls_session_t session,
break;
}
- ret = _gnutls_buffer_append_data_prefix(extdata, 8, data, pos);
- if (ret < 0)
- return gnutls_assert_val(ret);
+ return pos;
+}
+
+static int
+psk_ke_modes_send_params(gnutls_session_t session,
+ gnutls_buffer_t extdata)
+{
+ int ret;
+ gnutls_psk_client_credentials_t cred;
+ const version_entry_st *vers;
+ uint8_t data[2];
+ unsigned pos;
+
+ /* Server doesn't send psk_key_exchange_modes */
+ if (session->security_parameters.entity == GNUTLS_SERVER)
+ return 0;
- session->internals.hsk_flags |= HSK_PSK_KE_MODES_SENT;
+ vers = _gnutls_version_max(session);
+ if (!vers || !vers->tls13_sem)
+ return 0;
+
+ cred = (gnutls_psk_client_credentials_t)
+ _gnutls_get_cred(session, GNUTLS_CRD_PSK);
+ if (cred == NULL || _gnutls_have_psk_credentials(cred) == 0) {
+ /* No out-of-band PSKs - do we have a session ticket? */
+ if (!session->internals.session_ticket_enable ||
+ session->internals.tls13_ticket.ticket.data == NULL)
+ return 0;
+ } else {
+ if (!session->internals.priorities->have_psk)
+ return 0;
+ }
+
+ pos = make_data(session, data);
+ if (pos > 0) {
+ ret = _gnutls_buffer_append_data_prefix(extdata, 8, data, pos);
+ if (ret < 0)
+ return gnutls_assert_val(ret);
+ session->internals.hsk_flags |= HSK_PSK_KE_MODES_SENT;
+ }
return 0;
}
-#define MAX_POS INT_MAX
-
-/*
- * Since we only support ECDHE-authenticated PSKs, the server
- * just verifies that a "psk_key_exchange_modes" extension was received,
- * and that it contains the value one.
- */
static int
psk_ke_modes_recv_params(gnutls_session_t session,
const unsigned char *data, size_t _len)
@@ -109,12 +158,10 @@ psk_ke_modes_recv_params(gnutls_session_t session,
uint8_t ke_modes_len;
ssize_t len = _len;
const version_entry_st *vers = get_version(session);
- gnutls_psk_server_credentials_t cred;
- int dhpsk_pos = MAX_POS;
- int psk_pos = MAX_POS;
- int cli_psk_pos = MAX_POS;
- int cli_dhpsk_pos = MAX_POS;
- unsigned i;
+ int selected_ke_modes[2];
+ unsigned have_dhpsk = 0;
+ unsigned have_psk = 0;
+ unsigned i, pos = 0;
/* Server doesn't send psk_key_exchange_modes */
if (session->security_parameters.entity == GNUTLS_CLIENT)
@@ -123,59 +170,48 @@ psk_ke_modes_recv_params(gnutls_session_t session,
if (!vers || !vers->tls13_sem)
return 0;
- cred = (gnutls_psk_server_credentials_t)_gnutls_get_cred(session, GNUTLS_CRD_PSK);
- if (cred == NULL && !session->internals.session_ticket_enable)
- return 0;
-
DECR_LEN(len, 1);
ke_modes_len = *(data++);
- for (i=0;i<session->internals.priorities->_kx.algorithms;i++) {
- if (session->internals.priorities->_kx.priority[i] == GNUTLS_KX_PSK && psk_pos == MAX_POS) {
- psk_pos = i;
- } else if ((session->internals.priorities->_kx.priority[i] == GNUTLS_KX_DHE_PSK ||
- session->internals.priorities->_kx.priority[i] == GNUTLS_KX_ECDHE_PSK) &&
- dhpsk_pos == MAX_POS) {
- dhpsk_pos = i;
- }
+ for (i = 0; i < ke_modes_len; i++) {
+ DECR_LEN(len, 1);
- if (dhpsk_pos != MAX_POS && psk_pos != MAX_POS)
+ switch (data[i]) {
+ case PSK_DHE_KE:
+ selected_ke_modes[pos++] = PSK_DHE_KE;
+ have_dhpsk = 1;
break;
- }
-
- if (session->internals.priorities->groups.size == 0 && psk_pos == MAX_POS)
- return gnutls_assert_val(0);
-
- for (i=0;i<ke_modes_len;i++) {
- DECR_LEN(len, 1);
- if (data[i] == PSK_DHE_KE)
- cli_dhpsk_pos = i;
- else if (data[i] == PSK_KE)
- cli_psk_pos = i;
+ case PSK_KE:
+ selected_ke_modes[pos++] = PSK_KE;
+ have_psk = 1;
+ break;
+ default:
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_INVALID;
+ return 0;
+ }
- if (cli_psk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS)
+ if (have_dhpsk && have_psk)
break;
}
- if (session->internals.priorities->server_precedence) {
- if (dhpsk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS && dhpsk_pos < psk_pos)
- session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
- else if (psk_pos != MAX_POS && cli_psk_pos != MAX_POS && psk_pos < dhpsk_pos)
- session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
- } else {
- if (dhpsk_pos != MAX_POS && cli_dhpsk_pos != MAX_POS && cli_dhpsk_pos < cli_psk_pos)
- session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
- else if (psk_pos != MAX_POS && cli_psk_pos != MAX_POS && cli_psk_pos < cli_dhpsk_pos)
- session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
+ for (i = 0; i < pos; i++)
+ session->internals.psk_ke_modes[i] = selected_ke_modes[i];
+ session->internals.psk_ke_modes_size = pos;
+
+ switch (check_ke_modes(session)) {
+ case PSK_DHE_KE:
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_DHE_PSK;
+ _gnutls_handshake_log("EXT[%p]: Selected DHE-PSK mode\n", session);
+ break;
+ case PSK_KE:
+ session->internals.hsk_flags |= HSK_PSK_KE_MODE_PSK;
+ _gnutls_handshake_log("EXT[%p]: Selected PSK mode\n", session);
+ break;
+ default:
+ return gnutls_assert_val(GNUTLS_E_RECEIVED_ILLEGAL_PARAMETER);
}
- if ((session->internals.hsk_flags & HSK_PSK_KE_MODE_PSK) ||
- (session->internals.hsk_flags & HSK_PSK_KE_MODE_DHE_PSK)) {
- return 0;
- } else {
- session->internals.hsk_flags |= HSK_PSK_KE_MODE_INVALID;
- return 0;
- }
+ return 0;
}
const hello_ext_entry_st ext_psk_ke_modes = {