summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/libssl_internals.c
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/libssl_internals.c')
-rw-r--r--nss/gtests/ssl_gtest/libssl_internals.c340
1 files changed, 340 insertions, 0 deletions
diff --git a/nss/gtests/ssl_gtest/libssl_internals.c b/nss/gtests/ssl_gtest/libssl_internals.c
new file mode 100644
index 0000000..5136ee8
--- /dev/null
+++ b/nss/gtests/ssl_gtest/libssl_internals.c
@@ -0,0 +1,340 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/* This file contains functions for frobbing the internals of libssl */
+#include "libssl_internals.h"
+
+#include "nss.h"
+#include "pk11pub.h"
+#include "seccomon.h"
+#include "ssl.h"
+#include "sslimpl.h"
+
+SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ++ss->clientHelloVersion;
+
+ return SECSuccess;
+}
+
+/* Use this function to update the ClientRandom of a client's handshake state
+ * after replacing its ClientHello message. We for example need to do this
+ * when replacing an SSLv3 ClientHello with its SSLv2 equivalent. */
+SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
+ size_t rnd_len, uint8_t *msg,
+ size_t msg_len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ SECStatus rv = ssl3_InitState(ss);
+ if (rv != SECSuccess) {
+ return rv;
+ }
+
+ rv = ssl3_RestartHandshakeHashes(ss);
+ if (rv != SECSuccess) {
+ return rv;
+ }
+
+ // Ensure we don't overrun hs.client_random.
+ rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
+
+ // Zero the client_random struct.
+ PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
+
+ // Copy over the challenge bytes.
+ size_t offset = SSL3_RANDOM_LENGTH - rnd_len;
+ PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len);
+
+ // Rehash the SSLv2 client hello message.
+ return ssl3_UpdateHandshakeHashes(ss, msg, msg_len);
+}
+
+PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext));
+}
+
+void SSLInt_ClearSessionTicketKey() {
+ ssl3_SessionTicketShutdown(NULL, NULL);
+ NSS_UnregisterShutdown(ssl3_SessionTicketShutdown, NULL);
+}
+
+SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (ss) {
+ ss->ssl3.mtu = mtu;
+ return SECSuccess;
+ }
+ return SECFailure;
+}
+
+PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
+ PRCList *cur_p;
+ PRInt32 ct = 0;
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return -1;
+ }
+
+ for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
+ cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
+ ++ct;
+ }
+ return ct;
+}
+
+void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) {
+ PRCList *cur_p;
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return;
+ }
+
+ fprintf(stderr, "Cipher specs\n");
+ for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
+ cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
+ ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p;
+ fprintf(stderr, " %s\n", spec->phase);
+ }
+}
+
+/* Force a timer expiry by backdating when the timer was started.
+ * We could set the remaining time to 0 but then backoff would not
+ * work properly if we decide to test it. */
+void SSLInt_ForceTimerExpiry(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return;
+ }
+
+ if (!ss->ssl3.hs.rtTimerCb) return;
+
+ ss->ssl3.hs.rtTimerStarted =
+ PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1);
+}
+
+#define CHECK_SECRET(secret) \
+ if (ss->ssl3.hs.secret) { \
+ fprintf(stderr, "%s != NULL\n", #secret); \
+ return PR_FALSE; \
+ }
+
+PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ CHECK_SECRET(currentSecret);
+ CHECK_SECRET(resumptionMasterSecret);
+ CHECK_SECRET(dheSecret);
+ CHECK_SECRET(clientEarlyTrafficSecret);
+ CHECK_SECRET(clientHsTrafficSecret);
+ CHECK_SECRET(serverHsTrafficSecret);
+
+ return PR_TRUE;
+}
+
+PRBool sslint_DamageTrafficSecret(PRFileDesc *fd, size_t offset) {
+ unsigned char data[32] = {0};
+ PK11SymKey **keyPtr;
+ PK11SlotInfo *slot = PK11_GetInternalSlot();
+ SECItem key_item = {siBuffer, data, sizeof(data)};
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+ if (!slot) {
+ return PR_FALSE;
+ }
+ keyPtr = (PK11SymKey **)((char *)&ss->ssl3.hs + offset);
+ if (!*keyPtr) {
+ return PR_FALSE;
+ }
+ PK11_FreeSymKey(*keyPtr);
+ *keyPtr = PK11_ImportSymKey(slot, CKM_NSS_HKDF_SHA256, PK11_OriginUnwrap,
+ CKA_DERIVE, &key_item, NULL);
+ PK11_FreeSlot(slot);
+ if (!*keyPtr) {
+ return PR_FALSE;
+ }
+
+ return PR_TRUE;
+}
+
+PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, clientHsTrafficSecret));
+}
+
+PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, serverHsTrafficSecret));
+}
+
+PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, clientEarlyTrafficSecret));
+}
+
+SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE;
+ if (ss->xtnData.nextProto.data) {
+ SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE);
+ }
+ if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) return SECFailure;
+ PORT_Memcpy(ss->xtnData.nextProto.data, data, len);
+
+ return SECSuccess;
+}
+
+PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ return (PRBool)(!!ssl_FindServerCertByAuthType(ss, authType));
+}
+
+PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ SECStatus rv = SSL3_SendAlert(ss, level, type);
+ if (rv != SECSuccess) return PR_FALSE;
+
+ return PR_TRUE;
+}
+
+PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ ssl_GetSSL3HandshakeLock(ss);
+ ssl_GetXmitBufLock(ss);
+
+ SECStatus rv = tls13_SendNewSessionTicket(ss);
+ if (rv == SECSuccess) {
+ rv = ssl3_FlushHandshake(ss, 0);
+ }
+
+ ssl_ReleaseXmitBufLock(ss);
+ ssl_ReleaseSSL3HandshakeLock(ss);
+
+ return rv == SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
+ PRUint64 epoch;
+ sslSocket *ss;
+ ssl3CipherSpec *spec;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ if (to >= (1ULL << 48)) {
+ return SECFailure;
+ }
+ ssl_GetSpecWriteLock(ss);
+ spec = ss->ssl3.crSpec;
+ epoch = spec->read_seq_num >> 48;
+ spec->read_seq_num = (epoch << 48) | to;
+
+ /* For DTLS, we need to fix the record sequence number. For this, we can just
+ * scrub the entire structure on the assumption that the new sequence number
+ * is far enough past the last received sequence number. */
+ if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ return SECFailure;
+ }
+ dtls_RecordSetRecvd(&spec->recvdRecords, to);
+
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
+ PRUint64 epoch;
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ if (to >= (1ULL << 48)) {
+ return SECFailure;
+ }
+ ssl_GetSpecWriteLock(ss);
+ epoch = ss->ssl3.cwSpec->write_seq_num >> 48;
+ ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to;
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
+ sslSocket *ss;
+ sslSequenceNumber to;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ ssl_GetSpecReadLock(ss);
+ to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra;
+ ssl_ReleaseSpecReadLock(ss);
+ return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX);
+}
+
+SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
+ const sslNamedGroupDef *groupDef = ssl_LookupNamedGroup(group);
+ if (!groupDef) return ssl_kea_null;
+
+ return groupDef->keaType;
+}
+
+SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ss->opt.enableShortHeaders = PR_TRUE;
+ return SECSuccess;
+}
+
+SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ *result = ss->ssl3.hs.shortHeaders;
+
+ return SECSuccess;
+}