summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest
diff options
context:
space:
mode:
authorLorry Tar Creator <lorry-tar-importer@lorry>2017-01-04 14:24:24 +0000
committerLorry Tar Creator <lorry-tar-importer@lorry>2017-01-04 14:24:24 +0000
commitdc1565216a5d20ae0d75872151523252309a1292 (patch)
treed57454ba9a40386552179eddf60d28bd1e8f3d54 /nss/gtests/ssl_gtest
parent26c046fbc57d53136b4fb3b5e0d18298318125d4 (diff)
downloadnss-dc1565216a5d20ae0d75872151523252309a1292.tar.gz
nss-3.28.1nss-3.28.1
Diffstat (limited to 'nss/gtests/ssl_gtest')
-rw-r--r--nss/gtests/ssl_gtest/Makefile59
-rw-r--r--nss/gtests/ssl_gtest/databuffer.h191
-rw-r--r--nss/gtests/ssl_gtest/gtest_utils.h57
-rw-r--r--nss/gtests/ssl_gtest/libssl_internals.c340
-rw-r--r--nss/gtests/ssl_gtest/libssl_internals.h43
-rw-r--r--nss/gtests/ssl_gtest/manifest.mn54
-rw-r--r--nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc203
-rw-r--r--nss/gtests/ssl_gtest/ssl_agent_unittest.cc210
-rw-r--r--nss/gtests/ssl_gtest/ssl_auth_unittest.cc736
-rw-r--r--nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc214
-rw-r--r--nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc455
-rw-r--r--nss/gtests/ssl_gtest/ssl_damage_unittest.cc61
-rw-r--r--nss/gtests/ssl_gtest/ssl_dhe_unittest.cc609
-rw-r--r--nss/gtests/ssl_gtest/ssl_drop_unittest.cc133
-rw-r--r--nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc532
-rw-r--r--nss/gtests/ssl_gtest/ssl_ems_unittest.cc100
-rw-r--r--nss/gtests/ssl_gtest/ssl_exporter_unittest.cc122
-rw-r--r--nss/gtests/ssl_gtest/ssl_extension_unittest.cc985
-rw-r--r--nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc223
-rw-r--r--nss/gtests/ssl_gtest/ssl_gtest.cc44
-rw-r--r--nss/gtests/ssl_gtest/ssl_gtest.gyp101
-rw-r--r--nss/gtests/ssl_gtest/ssl_hrr_unittest.cc285
-rw-r--r--nss/gtests/ssl_gtest/ssl_loopback_unittest.cc274
-rw-r--r--nss/gtests/ssl_gtest/ssl_record_unittest.cc111
-rw-r--r--nss/gtests/ssl_gtest/ssl_resumption_unittest.cc582
-rw-r--r--nss/gtests/ssl_gtest/ssl_skip_unittest.cc158
-rw-r--r--nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc123
-rw-r--r--nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc351
-rw-r--r--nss/gtests/ssl_gtest/ssl_version_unittest.cc300
-rw-r--r--nss/gtests/ssl_gtest/test_io.cc536
-rw-r--r--nss/gtests/ssl_gtest/test_io.h152
-rw-r--r--nss/gtests/ssl_gtest/tls_agent.cc992
-rw-r--r--nss/gtests/ssl_gtest/tls_agent.h457
-rw-r--r--nss/gtests/ssl_gtest/tls_connect.cc708
-rw-r--r--nss/gtests/ssl_gtest/tls_connect.h274
-rw-r--r--nss/gtests/ssl_gtest/tls_filter.cc503
-rw-r--r--nss/gtests/ssl_gtest/tls_filter.h343
-rw-r--r--nss/gtests/ssl_gtest/tls_hkdf_unittest.cc262
-rw-r--r--nss/gtests/ssl_gtest/tls_parser.cc73
-rw-r--r--nss/gtests/ssl_gtest/tls_parser.h131
40 files changed, 12087 insertions, 0 deletions
diff --git a/nss/gtests/ssl_gtest/Makefile b/nss/gtests/ssl_gtest/Makefile
new file mode 100644
index 0000000..dfb8df9
--- /dev/null
+++ b/nss/gtests/ssl_gtest/Makefile
@@ -0,0 +1,59 @@
+#! gmake
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+#######################################################################
+# (1) Include initial platform-independent assignments (MANDATORY). #
+#######################################################################
+
+include manifest.mn
+
+#######################################################################
+# (2) Include "global" configuration information. (OPTIONAL) #
+#######################################################################
+
+include $(CORE_DEPTH)/coreconf/config.mk
+
+#######################################################################
+# (3) Include "component" configuration information. (OPTIONAL) #
+#######################################################################
+
+
+#######################################################################
+# (4) Include "local" platform-dependent assignments (OPTIONAL). #
+#######################################################################
+
+include ../common/gtest.mk
+
+CFLAGS += -I$(CORE_DEPTH)/lib/ssl
+
+ifdef NSS_SSL_ENABLE_ZLIB
+include $(CORE_DEPTH)/coreconf/zlib.mk
+endif
+
+ifndef NSS_ENABLE_TLS_1_3
+NSS_DISABLE_TLS_1_3=1
+endif
+
+ifdef NSS_DISABLE_TLS_1_3
+# Run parameterized tests only, for which we can easily exclude TLS 1.3
+CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS))
+CFLAGS += -DNSS_DISABLE_TLS_1_3
+endif
+
+#######################################################################
+# (5) Execute "global" rules. (OPTIONAL) #
+#######################################################################
+
+include $(CORE_DEPTH)/coreconf/rules.mk
+
+#######################################################################
+# (6) Execute "component" rules. (OPTIONAL) #
+#######################################################################
+
+
+#######################################################################
+# (7) Execute "local" rules. (OPTIONAL). #
+#######################################################################
diff --git a/nss/gtests/ssl_gtest/databuffer.h b/nss/gtests/ssl_gtest/databuffer.h
new file mode 100644
index 0000000..e7236d4
--- /dev/null
+++ b/nss/gtests/ssl_gtest/databuffer.h
@@ -0,0 +1,191 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef databuffer_h__
+#define databuffer_h__
+
+#include <algorithm>
+#include <cassert>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#if defined(WIN32) || defined(WIN64)
+#include <winsock2.h>
+#else
+#include <arpa/inet.h>
+#endif
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+class DataBuffer {
+ public:
+ DataBuffer() : data_(nullptr), len_(0) {}
+ DataBuffer(const uint8_t* data, size_t len) : data_(nullptr), len_(0) {
+ Assign(data, len);
+ }
+ DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) {
+ Assign(other);
+ }
+ ~DataBuffer() { delete[] data_; }
+
+ DataBuffer& operator=(const DataBuffer& other) {
+ if (&other != this) {
+ Assign(other);
+ }
+ return *this;
+ }
+
+ void Allocate(size_t len) {
+ delete[] data_;
+ data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0].
+ len_ = len;
+ }
+
+ void Truncate(size_t len) { len_ = std::min(len_, len); }
+
+ void Assign(const DataBuffer& other) { Assign(other.data(), other.len()); }
+
+ void Assign(const uint8_t* data, size_t len) {
+ if (data) {
+ Allocate(len);
+ memcpy(static_cast<void*>(data_), static_cast<const void*>(data), len);
+ } else {
+ assert(len == 0);
+ data_ = nullptr;
+ len_ = 0;
+ }
+ }
+
+ // Write will do a new allocation and expand the size of the buffer if needed.
+ // Returns the offset of the end of the write.
+ size_t Write(size_t index, const uint8_t* val, size_t count) {
+ assert(val);
+ if (index + count > len_) {
+ size_t newlen = index + count;
+ uint8_t* tmp = new uint8_t[newlen]; // Always > 0.
+ if (data_) {
+ memcpy(static_cast<void*>(tmp), static_cast<const void*>(data_), len_);
+ }
+ if (index > len_) {
+ memset(static_cast<void*>(tmp + len_), 0, index - len_);
+ }
+ delete[] data_;
+ data_ = tmp;
+ len_ = newlen;
+ }
+ if (data_) {
+ memcpy(static_cast<void*>(data_ + index), static_cast<const void*>(val),
+ count);
+ }
+ return index + count;
+ }
+
+ size_t Write(size_t index, const DataBuffer& buf) {
+ return Write(index, buf.data(), buf.len());
+ }
+
+ // Write an integer, also performing host-to-network order conversion.
+ // Returns the offset of the end of the write.
+ size_t Write(size_t index, uint32_t val, size_t count) {
+ assert(count <= sizeof(uint32_t));
+ uint32_t nvalue = htonl(val);
+ auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
+ return Write(index, addr + sizeof(uint32_t) - count, count);
+ }
+
+ // This can't use the same trick as Write(), since we might be reading from a
+ // smaller data source.
+ bool Read(size_t index, size_t count, uint32_t* val) const {
+ assert(count < sizeof(uint32_t));
+ assert(val);
+ if ((index > len()) || (count > (len() - index))) {
+ return false;
+ }
+ *val = 0;
+ for (size_t i = 0; i < count; ++i) {
+ *val = (*val << 8) | data()[index + i];
+ }
+ return true;
+ }
+
+ // Starting at |index|, remove |remove| bytes and replace them with the
+ // contents of |buf|.
+ void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) {
+ Splice(buf.data(), buf.len(), index, remove);
+ }
+
+ void Splice(const uint8_t* ins, size_t ins_len, size_t index,
+ size_t remove = 0) {
+ assert(ins);
+ uint8_t* old_value = data_;
+ size_t old_len = len_;
+
+ // The amount of stuff remaining from the tail of the old.
+ size_t tail_len = old_len - std::min(old_len, index + remove);
+ // The new length: the head of the old, the new, and the tail of the old.
+ len_ = index + ins_len + tail_len;
+ data_ = new uint8_t[len_ ? len_ : 1];
+
+ // The head of the old.
+ if (old_value) {
+ Write(0, old_value, std::min(old_len, index));
+ }
+ // Maybe a gap.
+ if (old_value && index > old_len) {
+ memset(old_value + index, 0, index - old_len);
+ }
+ // The new.
+ Write(index, ins, ins_len);
+ // The tail of the old.
+ if (tail_len > 0) {
+ Write(index + ins_len, old_value + index + remove, tail_len);
+ }
+
+ delete[] old_value;
+ }
+
+ void Append(const DataBuffer& buf) { Splice(buf, len_); }
+
+ const uint8_t* data() const { return data_; }
+ uint8_t* data() { return data_; }
+ size_t len() const { return len_; }
+ bool empty() const { return len_ == 0; }
+
+ private:
+ uint8_t* data_;
+ size_t len_;
+};
+
+static const size_t kMaxBufferPrint = 32;
+
+inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) {
+ stream << "[" << buf.len() << "] ";
+ for (size_t i = 0; i < buf.len(); ++i) {
+ if (!g_ssl_gtest_verbose && i >= kMaxBufferPrint) {
+ stream << "...";
+ break;
+ }
+ stream << std::hex << std::setfill('0') << std::setw(2)
+ << static_cast<unsigned>(buf.data()[i]);
+ }
+ stream << std::dec;
+ return stream;
+}
+
+inline bool operator==(const DataBuffer& a, const DataBuffer& b) {
+ return (a.empty() && b.empty()) ||
+ (a.len() == b.len() && 0 == memcmp(a.data(), b.data(), a.len()));
+}
+
+inline bool operator!=(const DataBuffer& a, const DataBuffer& b) {
+ return !(a == b);
+}
+
+} // namespace nss_test
+
+#endif
diff --git a/nss/gtests/ssl_gtest/gtest_utils.h b/nss/gtests/ssl_gtest/gtest_utils.h
new file mode 100644
index 0000000..3ecd96c
--- /dev/null
+++ b/nss/gtests/ssl_gtest/gtest_utils.h
@@ -0,0 +1,57 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef gtest_utils_h__
+#define gtest_utils_h__
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "test_io.h"
+
+namespace nss_test {
+
+// Gtest utilities
+class Timeout : public PollTarget {
+ public:
+ Timeout(int32_t timer_ms) : handle_(nullptr) {
+ Poller::Instance()->SetTimer(timer_ms, this, &Timeout::ExpiredCallback,
+ &handle_);
+ }
+ ~Timeout() {
+ if (handle_) {
+ handle_->Cancel();
+ }
+ }
+
+ static void ExpiredCallback(PollTarget* target, Event event) {
+ Timeout* timeout = static_cast<Timeout*>(target);
+ timeout->handle_ = nullptr;
+ }
+
+ bool timed_out() const { return !handle_; }
+
+ private:
+ Poller::Timer* handle_;
+};
+
+} // namespace nss_test
+
+#define WAIT_(expression, timeout) \
+ do { \
+ Timeout tm(timeout); \
+ while (!(expression)) { \
+ Poller::Instance()->Poll(); \
+ if (tm.timed_out()) break; \
+ } \
+ } while (0)
+
+#define ASSERT_TRUE_WAIT(expression, timeout) \
+ do { \
+ WAIT_(expression, timeout); \
+ ASSERT_TRUE(expression); \
+ } while (0)
+
+#endif
diff --git a/nss/gtests/ssl_gtest/libssl_internals.c b/nss/gtests/ssl_gtest/libssl_internals.c
new file mode 100644
index 0000000..5136ee8
--- /dev/null
+++ b/nss/gtests/ssl_gtest/libssl_internals.c
@@ -0,0 +1,340 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+/* This file contains functions for frobbing the internals of libssl */
+#include "libssl_internals.h"
+
+#include "nss.h"
+#include "pk11pub.h"
+#include "seccomon.h"
+#include "ssl.h"
+#include "sslimpl.h"
+
+SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ++ss->clientHelloVersion;
+
+ return SECSuccess;
+}
+
+/* Use this function to update the ClientRandom of a client's handshake state
+ * after replacing its ClientHello message. We for example need to do this
+ * when replacing an SSLv3 ClientHello with its SSLv2 equivalent. */
+SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
+ size_t rnd_len, uint8_t *msg,
+ size_t msg_len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ SECStatus rv = ssl3_InitState(ss);
+ if (rv != SECSuccess) {
+ return rv;
+ }
+
+ rv = ssl3_RestartHandshakeHashes(ss);
+ if (rv != SECSuccess) {
+ return rv;
+ }
+
+ // Ensure we don't overrun hs.client_random.
+ rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
+
+ // Zero the client_random struct.
+ PORT_Memset(&ss->ssl3.hs.client_random, 0, SSL3_RANDOM_LENGTH);
+
+ // Copy over the challenge bytes.
+ size_t offset = SSL3_RANDOM_LENGTH - rnd_len;
+ PORT_Memcpy(&ss->ssl3.hs.client_random.rand[offset], rnd, rnd_len);
+
+ // Rehash the SSLv2 client hello message.
+ return ssl3_UpdateHandshakeHashes(ss, msg, msg_len);
+}
+
+PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext));
+}
+
+void SSLInt_ClearSessionTicketKey() {
+ ssl3_SessionTicketShutdown(NULL, NULL);
+ NSS_UnregisterShutdown(ssl3_SessionTicketShutdown, NULL);
+}
+
+SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (ss) {
+ ss->ssl3.mtu = mtu;
+ return SECSuccess;
+ }
+ return SECFailure;
+}
+
+PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
+ PRCList *cur_p;
+ PRInt32 ct = 0;
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return -1;
+ }
+
+ for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
+ cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
+ ++ct;
+ }
+ return ct;
+}
+
+void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd) {
+ PRCList *cur_p;
+
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return;
+ }
+
+ fprintf(stderr, "Cipher specs\n");
+ for (cur_p = PR_NEXT_LINK(&ss->ssl3.hs.cipherSpecs);
+ cur_p != &ss->ssl3.hs.cipherSpecs; cur_p = PR_NEXT_LINK(cur_p)) {
+ ssl3CipherSpec *spec = (ssl3CipherSpec *)cur_p;
+ fprintf(stderr, " %s\n", spec->phase);
+ }
+}
+
+/* Force a timer expiry by backdating when the timer was started.
+ * We could set the remaining time to 0 but then backoff would not
+ * work properly if we decide to test it. */
+void SSLInt_ForceTimerExpiry(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return;
+ }
+
+ if (!ss->ssl3.hs.rtTimerCb) return;
+
+ ss->ssl3.hs.rtTimerStarted =
+ PR_IntervalNow() - PR_MillisecondsToInterval(ss->ssl3.hs.rtTimeoutMs + 1);
+}
+
+#define CHECK_SECRET(secret) \
+ if (ss->ssl3.hs.secret) { \
+ fprintf(stderr, "%s != NULL\n", #secret); \
+ return PR_FALSE; \
+ }
+
+PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ CHECK_SECRET(currentSecret);
+ CHECK_SECRET(resumptionMasterSecret);
+ CHECK_SECRET(dheSecret);
+ CHECK_SECRET(clientEarlyTrafficSecret);
+ CHECK_SECRET(clientHsTrafficSecret);
+ CHECK_SECRET(serverHsTrafficSecret);
+
+ return PR_TRUE;
+}
+
+PRBool sslint_DamageTrafficSecret(PRFileDesc *fd, size_t offset) {
+ unsigned char data[32] = {0};
+ PK11SymKey **keyPtr;
+ PK11SlotInfo *slot = PK11_GetInternalSlot();
+ SECItem key_item = {siBuffer, data, sizeof(data)};
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+ if (!slot) {
+ return PR_FALSE;
+ }
+ keyPtr = (PK11SymKey **)((char *)&ss->ssl3.hs + offset);
+ if (!*keyPtr) {
+ return PR_FALSE;
+ }
+ PK11_FreeSymKey(*keyPtr);
+ *keyPtr = PK11_ImportSymKey(slot, CKM_NSS_HKDF_SHA256, PK11_OriginUnwrap,
+ CKA_DERIVE, &key_item, NULL);
+ PK11_FreeSlot(slot);
+ if (!*keyPtr) {
+ return PR_FALSE;
+ }
+
+ return PR_TRUE;
+}
+
+PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, clientHsTrafficSecret));
+}
+
+PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, serverHsTrafficSecret));
+}
+
+PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd) {
+ return sslint_DamageTrafficSecret(
+ fd, offsetof(SSL3HandshakeState, clientEarlyTrafficSecret));
+}
+
+SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ss->xtnData.nextProtoState = SSL_NEXT_PROTO_EARLY_VALUE;
+ if (ss->xtnData.nextProto.data) {
+ SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE);
+ }
+ if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) return SECFailure;
+ PORT_Memcpy(ss->xtnData.nextProto.data, data, len);
+
+ return SECSuccess;
+}
+
+PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ return (PRBool)(!!ssl_FindServerCertByAuthType(ss, authType));
+}
+
+PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ SECStatus rv = SSL3_SendAlert(ss, level, type);
+ if (rv != SECSuccess) return PR_FALSE;
+
+ return PR_TRUE;
+}
+
+PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd) {
+ sslSocket *ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return PR_FALSE;
+ }
+
+ ssl_GetSSL3HandshakeLock(ss);
+ ssl_GetXmitBufLock(ss);
+
+ SECStatus rv = tls13_SendNewSessionTicket(ss);
+ if (rv == SECSuccess) {
+ rv = ssl3_FlushHandshake(ss, 0);
+ }
+
+ ssl_ReleaseXmitBufLock(ss);
+ ssl_ReleaseSSL3HandshakeLock(ss);
+
+ return rv == SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
+ PRUint64 epoch;
+ sslSocket *ss;
+ ssl3CipherSpec *spec;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ if (to >= (1ULL << 48)) {
+ return SECFailure;
+ }
+ ssl_GetSpecWriteLock(ss);
+ spec = ss->ssl3.crSpec;
+ epoch = spec->read_seq_num >> 48;
+ spec->read_seq_num = (epoch << 48) | to;
+
+ /* For DTLS, we need to fix the record sequence number. For this, we can just
+ * scrub the entire structure on the assumption that the new sequence number
+ * is far enough past the last received sequence number. */
+ if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ return SECFailure;
+ }
+ dtls_RecordSetRecvd(&spec->recvdRecords, to);
+
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
+ PRUint64 epoch;
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ if (to >= (1ULL << 48)) {
+ return SECFailure;
+ }
+ ssl_GetSpecWriteLock(ss);
+ epoch = ss->ssl3.cwSpec->write_seq_num >> 48;
+ ss->ssl3.cwSpec->write_seq_num = (epoch << 48) | to;
+ ssl_ReleaseSpecWriteLock(ss);
+ return SECSuccess;
+}
+
+SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra) {
+ sslSocket *ss;
+ sslSequenceNumber to;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+ ssl_GetSpecReadLock(ss);
+ to = ss->ssl3.cwSpec->write_seq_num + DTLS_RECVD_RECORDS_WINDOW + extra;
+ ssl_ReleaseSpecReadLock(ss);
+ return SSLInt_AdvanceWriteSeqNum(fd, to & RECORD_SEQ_MAX);
+}
+
+SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
+ const sslNamedGroupDef *groupDef = ssl_LookupNamedGroup(group);
+ if (!groupDef) return ssl_kea_null;
+
+ return groupDef->keaType;
+}
+
+SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ss->opt.enableShortHeaders = PR_TRUE;
+ return SECSuccess;
+}
+
+SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ *result = ss->ssl3.hs.shortHeaders;
+
+ return SECSuccess;
+}
diff --git a/nss/gtests/ssl_gtest/libssl_internals.h b/nss/gtests/ssl_gtest/libssl_internals.h
new file mode 100644
index 0000000..6ea66db
--- /dev/null
+++ b/nss/gtests/ssl_gtest/libssl_internals.h
@@ -0,0 +1,43 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef libssl_internals_h_
+#define libssl_internals_h_
+
+#include <stdint.h>
+
+#include "prio.h"
+#include "seccomon.h"
+#include "sslt.h"
+
+SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd);
+
+SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
+ size_t rnd_len, uint8_t *msg,
+ size_t msg_len);
+
+PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext);
+void SSLInt_ClearSessionTicketKey();
+PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd);
+void SSLInt_PrintTls13CipherSpecs(PRFileDesc *fd);
+void SSLInt_ForceTimerExpiry(PRFileDesc *fd);
+SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu);
+PRBool SSLInt_CheckSecretsDestroyed(PRFileDesc *fd);
+PRBool SSLInt_DamageClientHsTrafficSecret(PRFileDesc *fd);
+PRBool SSLInt_DamageServerHsTrafficSecret(PRFileDesc *fd);
+PRBool SSLInt_DamageEarlyTrafficSecret(PRFileDesc *fd);
+SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len);
+PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType);
+PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type);
+PRBool SSLInt_SendNewSessionTicket(PRFileDesc *fd);
+SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to);
+SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to);
+SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra);
+SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group);
+SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd);
+SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result);
+
+#endif // ndef libssl_internals_h_
diff --git a/nss/gtests/ssl_gtest/manifest.mn b/nss/gtests/ssl_gtest/manifest.mn
new file mode 100644
index 0000000..391db81
--- /dev/null
+++ b/nss/gtests/ssl_gtest/manifest.mn
@@ -0,0 +1,54 @@
+#
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+CORE_DEPTH = ../..
+DEPTH = ../..
+MODULE = nss
+
+# These sources have access to libssl internals
+CSRCS = \
+ libssl_internals.c \
+ $(NULL)
+
+CPPSRCS = \
+ ssl_0rtt_unittest.cc \
+ ssl_agent_unittest.cc \
+ ssl_auth_unittest.cc \
+ ssl_cert_ext_unittest.cc \
+ ssl_ciphersuite_unittest.cc \
+ ssl_damage_unittest.cc \
+ ssl_dhe_unittest.cc \
+ ssl_drop_unittest.cc \
+ ssl_ecdh_unittest.cc \
+ ssl_ems_unittest.cc \
+ ssl_exporter_unittest.cc \
+ ssl_extension_unittest.cc \
+ ssl_fuzz_unittest.cc \
+ ssl_gtest.cc \
+ ssl_hrr_unittest.cc \
+ ssl_loopback_unittest.cc \
+ ssl_record_unittest.cc \
+ ssl_resumption_unittest.cc \
+ ssl_skip_unittest.cc \
+ ssl_staticrsa_unittest.cc \
+ ssl_v2_client_hello_unittest.cc \
+ ssl_version_unittest.cc \
+ test_io.cc \
+ tls_agent.cc \
+ tls_connect.cc \
+ tls_hkdf_unittest.cc \
+ tls_filter.cc \
+ tls_parser.cc \
+ $(NULL)
+
+INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \
+ -I$(CORE_DEPTH)/gtests/common
+
+REQUIRES = nspr nss libdbm gtest
+
+PROGRAM = ssl_gtest
+EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) \
+ $(DIST)/lib/$(LIB_PREFIX)softokn.$(LIB_SUFFIX)
+
+USE_STATIC_LIBS = 1
diff --git a/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
new file mode 100644
index 0000000..cf5a27f
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -0,0 +1,203 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectTls13, ZeroRtt) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true);
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttServerRejectByOption) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+// Test that we don't try to send 0-RTT data when the server sent
+// us a ticket without the 0-RTT flags.
+TEST_P(TlsConnectTls13, ZeroRttOptionsSetLate) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+ Reset();
+ server_->StartConnect();
+ client_->StartConnect();
+ // Now turn on 0-RTT but too late for the ticket.
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(false, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttServerForgetTicket) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ClearServerCache();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ ZeroRttSendReceive(true, false);
+ Handshake();
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ZeroRttServerOnly) {
+ ExpectResumption(RESUME_NONE);
+ server_->Set0RttEnabled(true);
+ client_->StartConnect();
+ server_->StartConnect();
+
+ // Client sends ordinary ClientHello.
+ client_->Handshake();
+
+ // Verify that the server doesn't get data.
+ uint8_t buf[100];
+ PRInt32 rv = PR_Read(server_->ssl_fd(), buf, sizeof(buf));
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Now make sure that things complete.
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpn) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ EnableAlpn();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ExpectEarlyDataAccepted(true);
+ ZeroRttSendReceive(true, true, [this]() {
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
+ return true;
+ });
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckAlpn("a");
+}
+
+// Have the server negotiate a different ALPN value, and therefore
+// reject 0-RTT.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeServer) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ static const uint8_t client_alpn[] = {0x01, 0x61, 0x01, 0x62}; // "a", "b"
+ static const uint8_t server_alpn[] = {0x01, 0x62}; // "b"
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ server_->EnableAlpn(server_alpn, sizeof(server_alpn));
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, [this]() {
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
+ return true;
+ });
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckAlpn("b");
+}
+
+// Check that the client validates the ALPN selection of the server.
+// Stomp the ALPN on the client after sending the ClientHello so
+// that the server selection appears to be incorrect. The client
+// should then fail the connection.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ EnableAlpn();
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true, [this]() {
+ PRUint8 b[] = {'b'};
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
+ EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b)));
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
+ return true;
+ });
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Set up with no ALPN and then set the client so it thinks it has ALPN.
+// The server responds without the extension and the client returns an
+// error.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, true, [this]() {
+ PRUint8 b[] = {'b'};
+ EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1));
+ client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
+ return true;
+ });
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NEXT_PROTOCOL_DATA_INVALID);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+// Remove the old ALPN value and so the client will not offer early data.
+TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) {
+ EnableAlpn();
+ SetupForZeroRtt();
+ static const uint8_t alpn[] = {0x01, 0x62}; // "b"
+ EnableAlpn(alpn, sizeof(alpn));
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+ ZeroRttSendReceive(true, false, [this]() {
+ client_->CheckAlpn(SSL_NEXT_PROTO_NO_SUPPORT);
+ return false;
+ });
+ Handshake();
+ CheckConnected();
+ SendReceive();
+ CheckAlpn("b");
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
new file mode 100644
index 0000000..0e6ddae
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -0,0 +1,210 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+// This is an internal header, used to get TLS_1_3_DRAFT_VERSION.
+#include "ssl3prot.h"
+
+#include <memory>
+
+#include "databuffer.h"
+#include "tls_agent.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+static const uint8_t kD13 = TLS_1_3_DRAFT_VERSION;
+// This is a 1-RTT ClientHello with ECDHE.
+const static uint8_t kCannedTls13ClientHello[] = {
+ 0x01, 0x00, 0x00, 0xcf, 0x03, 0x03, 0x6c, 0xb3, 0x46, 0x81, 0xc8, 0x1a,
+ 0xf9, 0xd2, 0x05, 0x97, 0x48, 0x7c, 0xa8, 0x31, 0x03, 0x1c, 0x06, 0xa8,
+ 0x62, 0xb1, 0x90, 0xd6, 0x21, 0x44, 0x7f, 0xc1, 0x9b, 0x87, 0x3e, 0xad,
+ 0x91, 0x85, 0x00, 0x00, 0x06, 0x13, 0x01, 0x13, 0x03, 0x13, 0x02, 0x01,
+ 0x00, 0x00, 0xa0, 0x00, 0x00, 0x00, 0x0b, 0x00, 0x09, 0x00, 0x00, 0x06,
+ 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00,
+ 0x0a, 0x00, 0x12, 0x00, 0x10, 0x00, 0x17, 0x00, 0x18, 0x00, 0x19, 0x01,
+ 0x00, 0x01, 0x01, 0x01, 0x02, 0x01, 0x03, 0x01, 0x04, 0x00, 0x28, 0x00,
+ 0x47, 0x00, 0x45, 0x00, 0x17, 0x00, 0x41, 0x04, 0x86, 0x4a, 0xb9, 0xdc,
+ 0x6a, 0x38, 0xa7, 0xce, 0xe7, 0xc2, 0x4f, 0xa6, 0x28, 0xb9, 0xdc, 0x65,
+ 0xbf, 0x73, 0x47, 0x3c, 0x9c, 0x65, 0x8c, 0x47, 0x6d, 0x57, 0x22, 0x8a,
+ 0xc2, 0xb3, 0xc6, 0x80, 0x72, 0x86, 0x08, 0x86, 0x8f, 0x52, 0xc5, 0xcb,
+ 0xbf, 0x2a, 0xb5, 0x59, 0x64, 0xcc, 0x0c, 0x49, 0x95, 0x36, 0xe4, 0xd9,
+ 0x2f, 0xd4, 0x24, 0x66, 0x71, 0x6f, 0x5d, 0x70, 0xe2, 0xa0, 0xea, 0x26,
+ 0x00, 0x2b, 0x00, 0x03, 0x02, 0x7f, kD13, 0x00, 0x0d, 0x00, 0x20, 0x00,
+ 0x1e, 0x04, 0x03, 0x05, 0x03, 0x06, 0x03, 0x02, 0x03, 0x08, 0x04, 0x08,
+ 0x05, 0x08, 0x06, 0x04, 0x01, 0x05, 0x01, 0x06, 0x01, 0x02, 0x01, 0x04,
+ 0x02, 0x05, 0x02, 0x06, 0x02, 0x02, 0x02};
+
+const static uint8_t kCannedTls13ServerHello[] = {
+ 0x7f, kD13, 0x9c, 0xbc, 0x14, 0x9b, 0x0e, 0x2e, 0xfa, 0x0d, 0xf3, 0xf0,
+ 0x5c, 0x70, 0x7a, 0xe0, 0xd1, 0x9b, 0x3e, 0x5a, 0x44, 0x6b, 0xdf, 0xe5,
+ 0xc2, 0x28, 0x64, 0xf7, 0x00, 0xc1, 0x9c, 0x08, 0x76, 0x08, 0x13, 0x01,
+ 0x00, 0x28, 0x00, 0x28, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0xc2, 0xcf,
+ 0x23, 0x17, 0x64, 0x23, 0x03, 0xf0, 0xfb, 0x45, 0x98, 0x26, 0xd1, 0x65,
+ 0x24, 0xa1, 0x6c, 0xa9, 0x80, 0x8f, 0x2c, 0xac, 0x0a, 0xea, 0x53, 0x3a,
+ 0xcb, 0xe3, 0x08, 0x84, 0xae, 0x19};
+static const char *k0RttData = "ABCDEF";
+
+TEST_P(TlsAgentTest, EarlyFinished) {
+ DataBuffer buffer;
+ MakeTrivialHandshakeRecord(kTlsHandshakeFinished, 0, &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+TEST_P(TlsAgentTest, EarlyCertificateVerify) {
+ DataBuffer buffer;
+ MakeTrivialHandshakeRecord(kTlsHandshakeCertificateVerify, 0, &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(TlsAgentTestClient, CannedHello) {
+ DataBuffer buffer;
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ DataBuffer server_hello;
+ MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello,
+ sizeof(kCannedTls13ServerHello), &server_hello);
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello.data(), server_hello.len(), &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+}
+
+TEST_P(TlsAgentTestClient, EncryptedExtensionsInClear) {
+ DataBuffer server_hello;
+ MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello,
+ sizeof(kCannedTls13ServerHello), &server_hello);
+ DataBuffer encrypted_extensions;
+ MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0,
+ &encrypted_extensions, 1);
+ server_hello.Append(encrypted_extensions);
+ DataBuffer buffer;
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello.data(), server_hello.len(), &buffer);
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
+}
+
+TEST_F(TlsAgentStreamTestClient, EncryptedExtensionsInClearTwoPieces) {
+ DataBuffer server_hello;
+ MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello,
+ sizeof(kCannedTls13ServerHello), &server_hello);
+ DataBuffer encrypted_extensions;
+ MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0,
+ &encrypted_extensions, 1);
+ server_hello.Append(encrypted_extensions);
+ DataBuffer buffer;
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello.data(), 20, &buffer);
+
+ DataBuffer buffer2;
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello.data() + 20, server_hello.len() - 20, &buffer2);
+
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ ProcessMessage(buffer2, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
+}
+
+TEST_F(TlsAgentDgramTestClient, EncryptedExtensionsInClearTwoPieces) {
+ DataBuffer server_hello_frag1;
+ MakeHandshakeMessageFragment(
+ kTlsHandshakeServerHello, kCannedTls13ServerHello,
+ sizeof(kCannedTls13ServerHello), &server_hello_frag1, 0, 0, 20);
+ DataBuffer server_hello_frag2;
+ MakeHandshakeMessageFragment(
+ kTlsHandshakeServerHello, kCannedTls13ServerHello + 20,
+ sizeof(kCannedTls13ServerHello), &server_hello_frag2, 0, 20,
+ sizeof(kCannedTls13ServerHello) - 20);
+ DataBuffer encrypted_extensions;
+ MakeHandshakeMessage(kTlsHandshakeEncryptedExtensions, nullptr, 0,
+ &encrypted_extensions, 1);
+ server_hello_frag2.Append(encrypted_extensions);
+ DataBuffer buffer;
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello_frag1.data(), server_hello_frag1.len(), &buffer);
+
+ DataBuffer buffer2;
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ server_hello_frag2.data(), server_hello_frag2.len(), &buffer2, 1);
+
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ ProcessMessage(buffer2, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
+}
+
+TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ agent_->Set0RttEnabled(true);
+ auto filter =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeClientHello);
+ agent_->SetPacketFilter(filter);
+ PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
+ EXPECT_EQ(-1, rv);
+ int32_t err = PORT_GetError();
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, err);
+ EXPECT_LT(0UL, filter->buffer().len());
+}
+
+TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenRead) {
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ agent_->Set0RttEnabled(true);
+ DataBuffer buffer;
+ MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3,
+ reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData),
+ &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
+}
+
+// The server is allowing 0-RTT but the client doesn't offer it,
+// so trial decryption isn't engaged and 0-RTT messages cause
+// an error.
+TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) {
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ agent_->Set0RttEnabled(true);
+ DataBuffer buffer;
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
+ kCannedTls13ClientHello, sizeof(kCannedTls13ClientHello), &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3,
+ reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData),
+ &buffer);
+ ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ AgentTests, TlsAgentTest,
+ ::testing::Combine(TlsAgentTestBase::kTlsRolesAll,
+ TlsConnectTestBase::kTlsModesStream));
+INSTANTIATE_TEST_CASE_P(ClientTests, TlsAgentTestClient,
+ TlsConnectTestBase::kTlsModesAll);
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
new file mode 100644
index 0000000..e407d55
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -0,0 +1,736 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, ServerAuthBigRsa) {
+ Reset(TlsAgent::kRsa2048);
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectGeneric, ServerAuthRsaChain) {
+ Reset(TlsAgent::kServerRsaChain);
+ Connect();
+ CheckKeys();
+ size_t chain_length;
+ EXPECT_TRUE(client_->GetPeerChainLength(&chain_length));
+ EXPECT_EQ(2UL, chain_length);
+}
+
+TEST_P(TlsConnectGeneric, ClientAuth) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+}
+
+// In TLS 1.3, the client sends its cert rejection on the
+// second flight, and since it has already received the
+// server's Finished, it transitions to complete and
+// then gets an alert from the server. The test harness
+// doesn't handle this right yet.
+TEST_P(TlsConnectStream, DISABLED_ClientAuthRequiredRejected) {
+ server_->RequestClientAuth(true);
+ ConnectExpectFail();
+}
+
+TEST_P(TlsConnectGeneric, ClientAuthRequestedRejected) {
+ server_->RequestClientAuth(false);
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectGeneric, ClientAuthEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+}
+
+TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
+ Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+}
+
+// Offset is the position in the captured buffer where the signature sits.
+static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture,
+ size_t offset, TlsAgent* peer,
+ uint16_t expected_scheme, size_t expected_size) {
+ EXPECT_LT(offset + 2U, capture->buffer().len());
+
+ uint32_t scheme = 0;
+ capture->buffer().Read(offset, 2, &scheme);
+ EXPECT_EQ(expected_scheme, static_cast<uint16_t>(scheme));
+
+ ScopedCERTCertificate remote_cert(SSL_PeerCertificate(peer->ssl_fd()));
+ ScopedSECKEYPublicKey remote_key(CERT_ExtractPublicKey(remote_cert.get()));
+ EXPECT_EQ(expected_size, SECKEY_PublicKeyStrengthInBits(remote_key.get()));
+}
+
+// The server should prefer SHA-256 by default, even for the small key size used
+// in the default certificate.
+TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
+ EnsureTlsSetup();
+ auto capture_ske =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ server_->SetPacketFilter(capture_ske);
+ Connect();
+ CheckKeys();
+
+ const DataBuffer& buffer = capture_ske->buffer();
+ EXPECT_LT(3U, buffer.len());
+ EXPECT_EQ(3U, buffer.data()[0]) << "curve_type == named_curve";
+ uint32_t tmp;
+ EXPECT_TRUE(buffer.Read(1, 2, &tmp)) << "read NamedCurve";
+ EXPECT_EQ(ssl_grp_ec_curve25519, tmp);
+ EXPECT_TRUE(buffer.Read(3, 1, &tmp)) << " read ECPoint";
+ CheckSigScheme(capture_ske, 4 + tmp, client_, ssl_sig_rsa_pss_sha256, 1024);
+}
+
+TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
+ EnsureTlsSetup();
+ auto capture_cert_verify =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
+ client_->SetPacketFilter(capture_cert_verify);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024);
+}
+
+TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
+ Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
+ auto capture_cert_verify =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
+ client_->SetPacketFilter(capture_cert_verify);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ Connect();
+ CheckKeys();
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048);
+}
+
+static const SSLSignatureScheme SignatureSchemeEcdsaSha384[] = {
+ ssl_sig_ecdsa_secp384r1_sha384};
+static const SSLSignatureScheme SignatureSchemeEcdsaSha256[] = {
+ ssl_sig_ecdsa_secp256r1_sha256};
+static const SSLSignatureScheme SignatureSchemeRsaSha384[] = {
+ ssl_sig_rsa_pkcs1_sha384};
+static const SSLSignatureScheme SignatureSchemeRsaSha256[] = {
+ ssl_sig_rsa_pkcs1_sha256};
+
+static SSLNamedGroup NamedGroupForEcdsa384(uint16_t version) {
+ // NSS tries to match the group size to the symmetric cipher. In TLS 1.1 and
+ // 1.0, TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA is the highest priority suite, so
+ // we use P-384. With TLS 1.2 on we pick AES-128 GCM so use x25519.
+ if (version <= SSL_LIBRARY_VERSION_TLS_1_1) {
+ return ssl_grp_ec_secp384r1;
+ }
+ return ssl_grp_ec_curve25519;
+}
+
+// When signature algorithms match up, this should connect successfully; even
+// for TLS 1.1 and 1.0, where they should be ignored.
+TEST_P(TlsConnectGeneric, SignatureAlgorithmServerAuth) {
+ Reset(TlsAgent::kServerEcdsa384);
+ client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+// Here the client picks a single option, which should work in all versions.
+// Defaults on the server include the first option.
+TEST_P(TlsConnectGeneric, SignatureAlgorithmClientOnly) {
+ const SSLSignatureAndHashAlg clientAlgorithms[] = {
+ {ssl_hash_sha384, ssl_sign_ecdsa},
+ {ssl_hash_sha384, ssl_sign_rsa}, // supported but unusable
+ {ssl_hash_md5, ssl_sign_ecdsa} // unsupported and ignored
+ };
+ Reset(TlsAgent::kServerEcdsa384);
+ EnsureTlsSetup();
+ // Use the old API for this function.
+ EXPECT_EQ(SECSuccess,
+ SSL_SignaturePrefSet(client_->ssl_fd(), clientAlgorithms,
+ PR_ARRAY_SIZE(clientAlgorithms)));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+// Here the server picks a single option, which should work in all versions.
+// Defaults on the client include the provided option.
+TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) {
+ Reset(TlsAgent::kServerEcdsa384);
+ server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ Connect();
+ CheckKeys(ssl_kea_ecdh, NamedGroupForEcdsa384(version_), ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+TEST_P(TlsConnectTls12Plus, SignatureSchemeCurveMismatch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectTls12Plus, SignatureSchemeBadConfig) {
+ Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used.
+ server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Where there is no overlap on signature schemes, we still connect successfully
+// if we aren't going to use a signature.
+TEST_P(TlsConnectGenericPre13, SignatureAlgorithmNoOverlapStaticRsa) {
+ client_->SetSignatureSchemes(SignatureSchemeRsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeRsaSha384));
+ server_->SetSignatureSchemes(SignatureSchemeRsaSha256,
+ PR_ARRAY_SIZE(SignatureSchemeRsaSha256));
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_auth_rsa_decrypt);
+}
+
+TEST_P(TlsConnectTls12Plus, SignatureAlgorithmNoOverlapEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ server_->SetSignatureSchemes(SignatureSchemeEcdsaSha256,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha256));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
+}
+
+// Pre 1.2, a mismatch on signature algorithms shouldn't affect anything.
+TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ server_->SetSignatureSchemes(SignatureSchemeEcdsaSha256,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha256));
+ Connect();
+}
+
+// The signature_algorithms extension is mandatory in TLS 1.3.
+TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
+ client_->SetPacketFilter(
+ new TlsExtensionDropper(ssl_signature_algorithms_xtn));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
+}
+
+// TLS 1.2 has trouble detecting this sort of modification: it uses SHA1 and
+// only fails when the Finished is checked.
+TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
+ client_->SetPacketFilter(
+ new TlsExtensionDropper(ssl_signature_algorithms_xtn));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+TEST_P(TlsConnectTls12Plus, RequestClientAuthWithSha384) {
+ server_->SetSignatureSchemes(SignatureSchemeRsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeRsaSha384));
+ server_->RequestClientAuth(false);
+ Connect();
+}
+
+class BeforeFinished : public TlsRecordFilter {
+ private:
+ enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE };
+
+ public:
+ BeforeFinished(TlsAgent* client, TlsAgent* server, VoidFunction before_ccs,
+ VoidFunction before_finished)
+ : client_(client),
+ server_(server),
+ before_ccs_(before_ccs),
+ before_finished_(before_finished),
+ state_(BEFORE_CCS) {}
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) {
+ switch (state_) {
+ case BEFORE_CCS:
+ // Awaken when we see the CCS.
+ if (header.content_type() == kTlsChangeCipherSpecType) {
+ before_ccs_();
+
+ // Write the CCS out as a separate write, so that we can make
+ // progress. Ordinarily, libssl sends the CCS and Finished together,
+ // but that means that they both get processed together.
+ DataBuffer ccs;
+ header.Write(&ccs, 0, body);
+ server_->SendDirect(ccs);
+ client_->Handshake();
+ state_ = AFTER_CCS;
+ // Request that the original record be dropped by the filter.
+ return DROP;
+ }
+ break;
+
+ case AFTER_CCS:
+ EXPECT_EQ(kTlsHandshakeType, header.content_type());
+ // This could check that data contains a Finished message, but it's
+ // encrypted, so that's too much extra work.
+
+ before_finished_();
+ state_ = DONE;
+ break;
+
+ case DONE:
+ break;
+ }
+ return KEEP;
+ }
+
+ private:
+ TlsAgent* client_;
+ TlsAgent* server_;
+ VoidFunction before_ccs_;
+ VoidFunction before_finished_;
+ HandshakeState state_;
+};
+
+// Running code after the client has started processing the encrypted part of
+// the server's first flight, but before the Finished is processed is very hard
+// in TLS 1.3. These encrypted messages are sent in a single encrypted blob.
+// The following test uses DTLS to make it possible to force the client to
+// process the handshake in pieces.
+//
+// The first encrypted message from the server is dropped, and the MTU is
+// reduced to just below the original message size so that the server sends two
+// messages. The Finished message is then processed separately.
+class BeforeFinished13 : public PacketFilter {
+ private:
+ enum HandshakeState {
+ INIT,
+ BEFORE_FIRST_FRAGMENT,
+ BEFORE_SECOND_FRAGMENT,
+ DONE
+ };
+
+ public:
+ BeforeFinished13(TlsAgent* client, TlsAgent* server,
+ VoidFunction before_finished)
+ : client_(client),
+ server_(server),
+ before_finished_(before_finished),
+ records_(0) {}
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ switch (++records_) {
+ case 1:
+ // Packet 1 is the server's entire first flight. Drop it.
+ EXPECT_EQ(SECSuccess,
+ SSLInt_SetMTU(server_->ssl_fd(), input.len() - 1));
+ return DROP;
+
+ // Packet 2 is the first part of the server's retransmitted first
+ // flight. Keep that.
+
+ case 3:
+ // Packet 3 is the second part of the server's retransmitted first
+ // flight. Before passing that on, make sure that the client processes
+ // packet 2, then call the before_finished_() callback.
+ client_->Handshake();
+ before_finished_();
+ break;
+
+ default:
+ break;
+ }
+ return KEEP;
+ }
+
+ private:
+ TlsAgent* client_;
+ TlsAgent* server_;
+ VoidFunction before_finished_;
+ size_t records_;
+};
+
+static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
+ return SECWouldBlock;
+}
+
+// This test uses an AuthCertificateCallback that blocks. A filter is used to
+// split the server's first flight into two pieces. Before the second piece is
+// processed by the client, SSL_AuthCertificateComplete() is called.
+TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+ server_->SetPacketFilter(new BeforeFinished13(client_, server_, [this]() {
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ }));
+ Connect();
+}
+
+static void TriggerAuthComplete(PollTarget* target, Event event) {
+ std::cerr << "client: call SSL_AuthCertificateComplete" << std::endl;
+ EXPECT_EQ(TIMER_EVENT, event);
+ TlsAgent* client = static_cast<TlsAgent*>(target);
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client->ssl_fd(), 0));
+}
+
+// This test uses a simple AuthCertificateCallback. Due to the way that the
+// entire server flight is processed, the call to SSL_AuthCertificateComplete
+// will trigger after the Finished message is processed.
+TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {
+ client_->SetAuthCertificateCallback(
+ [this](TlsAgent*, PRBool, PRBool) -> SECStatus {
+ Poller::Timer* timer_handle;
+ // This is really just to unroll the stack.
+ Poller::Instance()->SetTimer(1U, client_, TriggerAuthComplete,
+ &timer_handle);
+ return SECWouldBlock;
+ });
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
+ client_->EnableFalseStart();
+ server_->SetPacketFilter(new BeforeFinished(
+ client_, server_,
+ [this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
+ [this]() {
+ // Write something, which used to fail: bug 1235366.
+ client_->SendData(10);
+ }));
+
+ Connect();
+ server_->SendData(10);
+ Receive(10);
+}
+
+TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
+ client_->EnableFalseStart();
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+ server_->SetPacketFilter(new BeforeFinished(
+ client_, server_,
+ []() {
+ // Do nothing before CCS
+ },
+ [this]() {
+ EXPECT_FALSE(client_->can_falsestart_hook_called());
+ // AuthComplete before Finished still enables false start.
+ EXPECT_EQ(SECSuccess,
+ SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ EXPECT_TRUE(client_->can_falsestart_hook_called());
+ client_->SendData(10);
+ }));
+
+ Connect();
+ server_->SendData(10);
+ Receive(10);
+}
+
+class EnforceNoActivity : public PacketFilter {
+ protected:
+ PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ std::cerr << "Unexpected packet: " << input << std::endl;
+ EXPECT_TRUE(false) << "should not send anything";
+ return KEEP;
+ }
+};
+
+// In this test, we want to make sure that the server completes its handshake,
+// but the client does not. Because the AuthCertificate callback blocks and we
+// never call SSL_AuthCertificateComplete(), the client should never report that
+// it has completed the handshake. Manually call Handshake(), alternating sides
+// between client and server, until the desired state is reached.
+TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ server_->StartConnect();
+ client_->StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ client_->Handshake(); // Send ClientKeyExchange and Finished
+ server_->Handshake(); // Send Finished
+ // The server should now report that it is connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ // The client should send nothing from here on.
+ client_->SetPacketFilter(new EnforceNoActivity());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // This should allow the handshake to complete now.
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake(); // Transition to connected
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ // Remove this before closing or the close_notify alert will trigger it.
+ client_->SetPacketFilter(nullptr);
+}
+
+// TLS 1.3 handles a delayed AuthComplete callback differently since the
+// shape of the handshake is different.
+TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
+ client_->SetAuthCertificateCallback(AuthCompleteBlock);
+
+ server_->StartConnect();
+ client_->StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send ServerHello
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
+
+ // The client will send nothing until AuthCertificateComplete is called.
+ client_->SetPacketFilter(new EnforceNoActivity());
+ client_->Handshake();
+ EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
+
+ // This should allow the handshake to complete now.
+ client_->SetPacketFilter(nullptr);
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ client_->Handshake(); // Send Finished
+ server_->Handshake(); // Transition to connected and send NewSessionTicket
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+}
+
+static const SSLExtraServerCertData ServerCertDataRsaPkcs1Decrypt = {
+ ssl_auth_rsa_decrypt, nullptr, nullptr, nullptr};
+static const SSLExtraServerCertData ServerCertDataRsaPkcs1Sign = {
+ ssl_auth_rsa_sign, nullptr, nullptr, nullptr};
+static const SSLExtraServerCertData ServerCertDataRsaPss = {
+ ssl_auth_rsa_pss, nullptr, nullptr, nullptr};
+
+// Test RSA cert with usage=[signature, encipherment].
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1SignAndKEX) {
+ Reset(TlsAgent::kServerRsa);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_sign, rsa_pss, or rsa_decrypt should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsa, false,
+ &ServerCertDataRsaPss));
+}
+
+// Test RSA cert with usage=[signature].
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1Sign) {
+ Reset(TlsAgent::kServerRsaSign);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_decrypt should fail.
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+
+ // Configuring for only rsa_sign or rsa_pss should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaSign, false,
+ &ServerCertDataRsaPss));
+}
+
+// Test RSA cert with usage=[encipherment].
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPkcs1KEX) {
+ Reset(TlsAgent::kServerRsaDecrypt);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_sign or rsa_pss should fail.
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false,
+ &ServerCertDataRsaPss));
+
+ // Configuring for only rsa_decrypt should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaDecrypt, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+}
+
+// Test configuring an RSA-PSS cert.
+TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) {
+ Reset(TlsAgent::kServerRsaPss);
+
+ PRFileDesc* ssl_fd = agent_->ssl_fd();
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_decrypt));
+ EXPECT_FALSE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_sign));
+ EXPECT_TRUE(SSLInt_HasCertWithAuthType(ssl_fd, ssl_auth_rsa_pss));
+
+ // Configuring for only rsa_sign or rsa_decrypt should fail.
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false,
+ &ServerCertDataRsaPkcs1Sign));
+ EXPECT_FALSE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false,
+ &ServerCertDataRsaPkcs1Decrypt));
+
+ // Configuring for only rsa_pss should work.
+ EXPECT_TRUE(agent_->ConfigServerCert(TlsAgent::kServerRsaPss, false,
+ &ServerCertDataRsaPss));
+}
+
+// mode, version, certificate, auth type, signature scheme
+typedef std::tuple<std::string, uint16_t, std::string, SSLAuthType,
+ SSLSignatureScheme>
+ SignatureSchemeProfile;
+
+class TlsSignatureSchemeConfiguration
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SignatureSchemeProfile> {
+ public:
+ TlsSignatureSchemeConfiguration()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())),
+ certificate_(std::get<2>(GetParam())),
+ auth_type_(std::get<3>(GetParam())),
+ signature_scheme_(std::get<4>(GetParam())) {}
+
+ protected:
+ void TestSignatureSchemeConfig(TlsAgent* configPeer) {
+ EnsureTlsSetup();
+ configPeer->SetSignatureSchemes(&signature_scheme_, 1);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_,
+ signature_scheme_);
+ }
+
+ std::string certificate_;
+ SSLAuthType auth_type_;
+ SSLSignatureScheme signature_scheme_;
+};
+
+TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {
+ Reset(certificate_);
+ TestSignatureSchemeConfig(server_);
+}
+
+TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
+ Reset(certificate_);
+ TlsExtensionCapture* capture =
+ new TlsExtensionCapture(ssl_signature_algorithms_xtn);
+ client_->SetPacketFilter(capture);
+ TestSignatureSchemeConfig(client_);
+
+ const DataBuffer& ext = capture->extension();
+ ASSERT_EQ(2U + 2U, ext.len());
+ uint32_t v = 0;
+ ASSERT_TRUE(ext.Read(0, 2, &v));
+ EXPECT_EQ(2U, v);
+ ASSERT_TRUE(ext.Read(2, 2, &v));
+ EXPECT_EQ(signature_scheme_, static_cast<SSLSignatureScheme>(v));
+}
+
+TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) {
+ Reset(certificate_);
+ EnsureTlsSetup();
+ client_->SetSignatureSchemes(&signature_scheme_, 1);
+ server_->SetSignatureSchemes(&signature_scheme_, 1);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, auth_type_, signature_scheme_);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeRsa, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(
+ TlsConnectTestBase::kTlsModesAll, TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerRsaSign),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
+ ssl_sig_rsa_pkcs1_sha512, ssl_sig_rsa_pss_sha256,
+ ssl_sig_rsa_pss_sha384)));
+// PSS with SHA-512 needs a bigger key to work.
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kRsa2048),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pss_sha512)));
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12,
+ ::testing::Values(TlsAgent::kServerRsa),
+ ::testing::Values(ssl_auth_rsa_sign),
+ ::testing::Values(ssl_sig_rsa_pkcs1_sha1)));
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerEcdsa256),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_secp256r1_sha256)));
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerEcdsa384),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384)));
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus,
+ ::testing::Values(TlsAgent::kServerEcdsa521),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_secp521r1_sha512)));
+INSTANTIATE_TEST_CASE_P(
+ SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12,
+ ::testing::Values(TlsAgent::kServerEcdsa256,
+ TlsAgent::kServerEcdsa384),
+ ::testing::Values(ssl_auth_ecdsa),
+ ::testing::Values(ssl_sig_ecdsa_sha1)));
+}
diff --git a/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
new file mode 100644
index 0000000..876c368
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -0,0 +1,214 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// Tests for Certificate Transparency (RFC 6962)
+// These don't work with TLS 1.3: see bug 1252745.
+
+// Helper class - stores signed certificate timestamps as provided
+// by the relevant callbacks on the client.
+class SignedCertificateTimestampsExtractor {
+ public:
+ SignedCertificateTimestampsExtractor(TlsAgent* client) : client_(client) {
+ client_->SetAuthCertificateCallback(
+ [&](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus {
+ const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd());
+ EXPECT_TRUE(scts);
+ if (!scts) {
+ return SECFailure;
+ }
+ auth_timestamps_.reset(new DataBuffer(scts->data, scts->len));
+ return SECSuccess;
+ });
+ client_->SetHandshakeCallback([&](TlsAgent* agent) {
+ const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd());
+ ASSERT_TRUE(scts);
+ handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len));
+ });
+ }
+
+ void assertTimestamps(const DataBuffer& timestamps) {
+ EXPECT_TRUE(auth_timestamps_);
+ EXPECT_EQ(timestamps, *auth_timestamps_);
+
+ EXPECT_TRUE(handshake_timestamps_);
+ EXPECT_EQ(timestamps, *handshake_timestamps_);
+
+ const SECItem* current = SSL_PeerSignedCertTimestamps(client_->ssl_fd());
+ EXPECT_EQ(timestamps, DataBuffer(current->data, current->len));
+ }
+
+ private:
+ TlsAgent* client_;
+ std::unique_ptr<DataBuffer> auth_timestamps_;
+ std::unique_ptr<DataBuffer> handshake_timestamps_;
+};
+
+static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89};
+static const SECItem kSctItem = {siBuffer, const_cast<uint8_t*>(kSctValue),
+ sizeof(kSctValue)};
+static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue));
+
+// Test timestamps extraction during a successful handshake.
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsHandshake) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
+ &kSctItem, ssl_kea_rsa));
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
+ PR_TRUE));
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+
+ timestamps_extractor.assertTimestamps(kSctBuffer);
+}
+
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsConfig) {
+ static const SSLExtraServerCertData kExtraData = {ssl_auth_rsa_sign, nullptr,
+ nullptr, &kSctItem};
+
+ EnsureTlsSetup();
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraData));
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
+ PR_TRUE));
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+
+ timestamps_extractor.assertTimestamps(kSctBuffer);
+}
+
+// Test SSL_PeerSignedCertTimestamps returning zero-length SECItem
+// when the client / the server / both have not enabled the feature.
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
+ &kSctItem, ssl_kea_rsa));
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+ timestamps_extractor.assertTimestamps(DataBuffer());
+}
+
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveServer) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
+ PR_TRUE));
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+ timestamps_extractor.assertTimestamps(DataBuffer());
+}
+
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveBoth) {
+ EnsureTlsSetup();
+ SignedCertificateTimestampsExtractor timestamps_extractor(client_);
+
+ Connect();
+ timestamps_extractor.assertTimestamps(DataBuffer());
+}
+
+// Check that the given agent doesn't have an OCSP response for its peer.
+static SECStatus CheckNoOCSP(TlsAgent* agent, bool checksig, bool isServer) {
+ const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd());
+ EXPECT_TRUE(ocsp);
+ EXPECT_EQ(0U, ocsp->len);
+ return SECSuccess;
+}
+
+static const uint8_t kOcspValue1[] = {1, 2, 3, 4, 5, 6};
+static const uint8_t kOcspValue2[] = {7, 8, 9};
+static const SECItem kOcspItems[] = {
+ {siBuffer, const_cast<uint8_t*>(kOcspValue1), sizeof(kOcspValue1)},
+ {siBuffer, const_cast<uint8_t*>(kOcspValue2), sizeof(kOcspValue2)}};
+static const SECItemArray kOcspResponses = {const_cast<SECItem*>(kOcspItems),
+ PR_ARRAY_SIZE(kOcspItems)};
+const static SSLExtraServerCertData kOcspExtraData = {
+ ssl_auth_rsa_sign, nullptr, &kOcspResponses, nullptr};
+
+TEST_P(TlsConnectGeneric, NoOcsp) {
+ EnsureTlsSetup();
+ client_->SetAuthCertificateCallback(CheckNoOCSP);
+ Connect();
+}
+
+// The client doesn't get OCSP stapling unless it asks.
+TEST_P(TlsConnectGeneric, OcspNotRequested) {
+ EnsureTlsSetup();
+ client_->SetAuthCertificateCallback(CheckNoOCSP);
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
+ Connect();
+}
+
+// Even if the client asks, the server has nothing unless it is configured.
+TEST_P(TlsConnectGeneric, OcspNotProvided) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ client_->SetAuthCertificateCallback(CheckNoOCSP);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, OcspMangled) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
+
+ static const uint8_t val[] = {1};
+ auto replacer = new TlsExtensionReplacer(ssl_cert_status_xtn,
+ DataBuffer(val, sizeof(val)));
+ server_->SetPacketFilter(replacer);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+TEST_P(TlsConnectGeneric, OcspSuccess) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+ auto capture_ocsp = new TlsExtensionCapture(ssl_cert_status_xtn);
+ server_->SetPacketFilter(capture_ocsp);
+
+ // The value should be available during the AuthCertificateCallback
+ client_->SetAuthCertificateCallback([](TlsAgent* agent, bool checksig,
+ bool isServer) -> SECStatus {
+ const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd());
+ if (!ocsp) {
+ return SECFailure;
+ }
+ EXPECT_EQ(1U, ocsp->len) << "We only provide the first item";
+ EXPECT_EQ(0, SECITEM_CompareItem(&kOcspItems[0], &ocsp->items[0]));
+ return SECSuccess;
+ });
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
+
+ Connect();
+ // In TLS 1.3, the server doesn't provide a visible ServerHello extension.
+ // For earlier versions, the extension is just empty.
+ EXPECT_EQ(0U, capture_ocsp->extension().len());
+}
+
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
new file mode 100644
index 0000000..ab10a84
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -0,0 +1,455 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// mode, version, cipher suite
+typedef std::tuple<std::string, uint16_t, uint16_t, SSLNamedGroup,
+ SSLSignatureScheme>
+ CipherSuiteProfile;
+
+class TlsCipherSuiteTestBase : public TlsConnectTestBase {
+ public:
+ TlsCipherSuiteTestBase(const std::string &mode, uint16_t version,
+ uint16_t cipher_suite, SSLNamedGroup group,
+ SSLSignatureScheme signature_scheme)
+ : TlsConnectTestBase(mode, version),
+ cipher_suite_(cipher_suite),
+ group_(group),
+ signature_scheme_(signature_scheme),
+ csinfo_({0}) {
+ SECStatus rv =
+ SSL_GetCipherSuiteInfo(cipher_suite_, &csinfo_, sizeof(csinfo_));
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv == SECSuccess) {
+ std::cerr << "Cipher suite: " << csinfo_.cipherSuiteName << std::endl;
+ }
+ auth_type_ = csinfo_.authType;
+ kea_type_ = csinfo_.keaType;
+ }
+
+ protected:
+ void EnableSingleCipher() {
+ EnsureTlsSetup();
+ // It doesn't matter which does this, but the test is better if both do it.
+ client_->EnableSingleCipher(cipher_suite_);
+ server_->EnableSingleCipher(cipher_suite_);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ std::vector<SSLNamedGroup> groups = {group_};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ kea_type_ = SSLInt_GetKEAType(group_);
+
+ client_->SetSignatureSchemes(&signature_scheme_, 1);
+ server_->SetSignatureSchemes(&signature_scheme_, 1);
+ }
+ }
+
+ virtual void SetupCertificate() {
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ switch (signature_scheme_) {
+ case ssl_sig_rsa_pkcs1_sha256:
+ case ssl_sig_rsa_pkcs1_sha384:
+ case ssl_sig_rsa_pkcs1_sha512:
+ Reset(TlsAgent::kServerRsaSign);
+ auth_type_ = ssl_auth_rsa_sign;
+ break;
+ case ssl_sig_rsa_pss_sha256:
+ case ssl_sig_rsa_pss_sha384:
+ Reset(TlsAgent::kServerRsaSign);
+ auth_type_ = ssl_auth_rsa_sign;
+ break;
+ case ssl_sig_rsa_pss_sha512:
+ // You can't fit SHA-512 PSS in a 1024-bit key.
+ Reset(TlsAgent::kRsa2048);
+ auth_type_ = ssl_auth_rsa_sign;
+ break;
+ case ssl_sig_ecdsa_secp256r1_sha256:
+ Reset(TlsAgent::kServerEcdsa256);
+ auth_type_ = ssl_auth_ecdsa;
+ break;
+ case ssl_sig_ecdsa_secp384r1_sha384:
+ Reset(TlsAgent::kServerEcdsa384);
+ auth_type_ = ssl_auth_ecdsa;
+ break;
+ default:
+ ASSERT_TRUE(false) << "Unsupported signature scheme: "
+ << signature_scheme_;
+ break;
+ }
+ } else {
+ switch (csinfo_.authType) {
+ case ssl_auth_rsa_sign:
+ Reset(TlsAgent::kServerRsaSign);
+ break;
+ case ssl_auth_rsa_decrypt:
+ Reset(TlsAgent::kServerRsaDecrypt);
+ break;
+ case ssl_auth_ecdsa:
+ Reset(TlsAgent::kServerEcdsa256);
+ break;
+ case ssl_auth_ecdh_ecdsa:
+ Reset(TlsAgent::kServerEcdhEcdsa);
+ break;
+ case ssl_auth_ecdh_rsa:
+ Reset(TlsAgent::kServerEcdhRsa);
+ break;
+ case ssl_auth_dsa:
+ Reset(TlsAgent::kServerDsa);
+ break;
+ default:
+ ASSERT_TRUE(false) << "Unsupported cipher suite: " << cipher_suite_;
+ break;
+ }
+ }
+ }
+
+ void ConnectAndCheckCipherSuite() {
+ Connect();
+ SendReceive();
+
+ // Check that we used the right cipher suite.
+ uint16_t actual;
+ EXPECT_TRUE(client_->cipher_suite(&actual) && actual == cipher_suite_);
+ EXPECT_TRUE(server_->cipher_suite(&actual) && actual == cipher_suite_);
+ SSLAuthType auth;
+ EXPECT_TRUE(client_->auth_type(&auth) && auth == auth_type_);
+ EXPECT_TRUE(server_->auth_type(&auth) && auth == auth_type_);
+ SSLKEAType kea;
+ EXPECT_TRUE(client_->kea_type(&kea) && kea == kea_type_);
+ EXPECT_TRUE(server_->kea_type(&kea) && kea == kea_type_);
+ }
+
+ // Get the expected limit on the number of records that can be sent for the
+ // cipher suite.
+ uint64_t record_limit() const {
+ switch (csinfo_.symCipher) {
+ case ssl_calg_rc4:
+ case ssl_calg_3des:
+ return 1ULL << 20;
+ case ssl_calg_aes:
+ case ssl_calg_aes_gcm:
+ return 0x5aULL << 28;
+ case ssl_calg_null:
+ case ssl_calg_chacha20:
+ return (1ULL << 48) - 1;
+ case ssl_calg_rc2:
+ case ssl_calg_des:
+ case ssl_calg_idea:
+ case ssl_calg_fortezza:
+ case ssl_calg_camellia:
+ case ssl_calg_seed:
+ break;
+ }
+ EXPECT_TRUE(false) << "No limit for " << csinfo_.cipherSuiteName;
+ return 1ULL < 48;
+ }
+
+ uint64_t last_safe_write() const {
+ uint64_t limit = record_limit() - 1;
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1 &&
+ (csinfo_.symCipher == ssl_calg_3des ||
+ csinfo_.symCipher == ssl_calg_aes)) {
+ // 1/n-1 record splitting needs space for two records.
+ limit--;
+ }
+ return limit;
+ }
+
+ protected:
+ uint16_t cipher_suite_;
+ SSLAuthType auth_type_;
+ SSLKEAType kea_type_;
+ SSLNamedGroup group_;
+ SSLSignatureScheme signature_scheme_;
+ SSLCipherSuiteInfo csinfo_;
+};
+
+class TlsCipherSuiteTest
+ : public TlsCipherSuiteTestBase,
+ public ::testing::WithParamInterface<CipherSuiteProfile> {
+ public:
+ TlsCipherSuiteTest()
+ : TlsCipherSuiteTestBase(std::get<0>(GetParam()), std::get<1>(GetParam()),
+ std::get<2>(GetParam()), std::get<3>(GetParam()),
+ std::get<4>(GetParam())) {}
+
+ protected:
+ bool SkipIfCipherSuiteIsDSA() {
+ bool isDSA = csinfo_.authType == ssl_auth_dsa;
+ if (isDSA) {
+ std::cerr << "Skipping DSA suite: " << csinfo_.cipherSuiteName
+ << std::endl;
+ }
+ return isDSA;
+ }
+};
+
+TEST_P(TlsCipherSuiteTest, SingleCipherSuite) {
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+}
+
+TEST_P(TlsCipherSuiteTest, ResumeCipherSuite) {
+ if (SkipIfCipherSuiteIsDSA()) {
+ return; // Tickets don't work with DSA (bug 1174677).
+ }
+
+ SetupCertificate(); // This is only needed once.
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableSingleCipher();
+
+ ConnectAndCheckCipherSuite();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableSingleCipher();
+ ExpectResumption(RESUME_TICKET);
+ ConnectAndCheckCipherSuite();
+}
+
+// This only works for stream ciphers because we modify the sequence number -
+// which is included explicitly in the DTLS record header - and that trips a
+// different error code. Note that the message that the client sends would not
+// decrypt (the nonce/IV wouldn't match), but the record limit is hit before
+// attempting to decrypt a record.
+TEST_P(TlsCipherSuiteTest, ReadLimit) {
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write()));
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceReadSeqNum(server_->ssl_fd(), last_safe_write()));
+
+ client_->SendData(10, 10);
+ server_->ReadBytes(); // This should be OK.
+
+ // The payload needs to be big enough to pass for encrypted. In the extreme
+ // case (TLS 1.3), this means 1 for payload, 1 for content type and 16 for
+ // authentication tag.
+ static const uint8_t payload[18] = {6};
+ DataBuffer record;
+ uint64_t epoch = 0;
+ if (mode_ == DGRAM) {
+ epoch++;
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ epoch++;
+ }
+ }
+ TlsAgentTestBase::MakeRecord(mode_, kTlsApplicationDataType, version_,
+ payload, sizeof(payload), &record,
+ (epoch << 48) | record_limit());
+ server_->adapter()->PacketReceived(record);
+ server_->ExpectReadWriteError();
+ server_->ReadBytes();
+ EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, server_->error_code());
+}
+
+TEST_P(TlsCipherSuiteTest, WriteLimit) {
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), last_safe_write()));
+ client_->SendData(10, 10);
+ client_->ExpectReadWriteError();
+ client_->SendData(10, 10);
+ EXPECT_EQ(SSL_ERROR_TOO_MANY_RECORDS, client_->error_code());
+}
+
+// This awful macro makes the test instantiations easier to read.
+#define INSTANTIATE_CIPHER_TEST_P(name, modes, versions, groups, sigalgs, ...) \
+ static const uint16_t k##name##CiphersArr[] = {__VA_ARGS__}; \
+ static const ::testing::internal::ParamGenerator<uint16_t> \
+ k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \
+ INSTANTIATE_TEST_CASE_P( \
+ CipherSuite##name, TlsCipherSuiteTest, \
+ ::testing::Combine(TlsConnectTestBase::kTlsModes##modes, \
+ TlsConnectTestBase::kTls##versions, k##name##Ciphers, \
+ groups, sigalgs));
+
+static const auto kDummyNamedGroupParams = ::testing::Values(ssl_grp_none);
+static const auto kDummySignatureSchemesParams =
+ ::testing::Values(ssl_sig_none);
+
+#ifndef NSS_DISABLE_TLS_1_3
+static SSLSignatureScheme kSignatureSchemesParamsArr[] = {
+ ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
+ ssl_sig_rsa_pkcs1_sha512, ssl_sig_ecdsa_secp256r1_sha256,
+ ssl_sig_ecdsa_secp384r1_sha384, ssl_sig_rsa_pss_sha256,
+ ssl_sig_rsa_pss_sha384, ssl_sig_rsa_pss_sha512,
+};
+#endif
+
+INSTANTIATE_CIPHER_TEST_P(RC4, Stream, V10ToV12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams,
+ TLS_RSA_WITH_RC4_128_SHA,
+ TLS_ECDH_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ TLS_ECDH_RSA_WITH_RC4_128_SHA,
+ TLS_ECDHE_RSA_WITH_RC4_128_SHA);
+INSTANTIATE_CIPHER_TEST_P(AEAD12, All, V12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams,
+ TLS_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_DHE_DSS_WITH_AES_128_GCM_SHA256,
+ TLS_DHE_DSS_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384);
+INSTANTIATE_CIPHER_TEST_P(AEAD, All, V12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams,
+ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_DHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_DHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
+ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+ TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256);
+INSTANTIATE_CIPHER_TEST_P(
+ CBC12, All, V12, kDummyNamedGroupParams, kDummySignatureSchemesParams,
+ TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, TLS_RSA_WITH_AES_256_CBC_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_RSA_WITH_AES_128_CBC_SHA256,
+ TLS_RSA_WITH_AES_128_CBC_SHA256, TLS_DHE_DSS_WITH_AES_128_CBC_SHA256,
+ TLS_DHE_DSS_WITH_AES_256_CBC_SHA256);
+INSTANTIATE_CIPHER_TEST_P(
+ CBCStream, Stream, V10ToV12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_NULL_SHA,
+ TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_ECDSA_WITH_NULL_SHA,
+ TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_NULL_SHA,
+ TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDH_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDH_RSA_WITH_AES_256_CBC_SHA, TLS_ECDHE_RSA_WITH_NULL_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+INSTANTIATE_CIPHER_TEST_P(
+ CBCDatagram, Datagram, V11V12, kDummyNamedGroupParams,
+ kDummySignatureSchemesParams, TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA, TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA,
+ TLS_ECDH_RSA_WITH_AES_128_CBC_SHA, TLS_ECDH_RSA_WITH_AES_256_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_CIPHER_TEST_P(TLS13, All, V13,
+ ::testing::ValuesIn(kFasterDHEGroups),
+ ::testing::ValuesIn(kSignatureSchemesParamsArr),
+ TLS_AES_128_GCM_SHA256, TLS_CHACHA20_POLY1305_SHA256,
+ TLS_AES_256_GCM_SHA384);
+INSTANTIATE_CIPHER_TEST_P(TLS13AllGroups, All, V13,
+ ::testing::ValuesIn(kAllDHEGroups),
+ ::testing::Values(ssl_sig_ecdsa_secp384r1_sha384),
+ TLS_AES_256_GCM_SHA384);
+#endif
+
+// Fields are: version, cipher suite, bulk cipher name, secretKeySize
+struct SecStatusParams {
+ uint16_t version;
+ uint16_t cipher_suite;
+ std::string name;
+ int keySize;
+};
+
+inline std::ostream &operator<<(std::ostream &stream,
+ const SecStatusParams &vals) {
+ SSLCipherSuiteInfo csinfo;
+ SECStatus rv =
+ SSL_GetCipherSuiteInfo(vals.cipher_suite, &csinfo, sizeof(csinfo));
+ if (rv != SECSuccess) {
+ return stream << "Error invoking SSL_GetCipherSuiteInfo()";
+ }
+
+ return stream << "TLS " << VersionString(vals.version) << ", "
+ << csinfo.cipherSuiteName << ", name = \"" << vals.name
+ << "\", key size = " << vals.keySize;
+}
+
+class SecurityStatusTest
+ : public TlsCipherSuiteTestBase,
+ public ::testing::WithParamInterface<SecStatusParams> {
+ public:
+ SecurityStatusTest()
+ : TlsCipherSuiteTestBase("TLS", GetParam().version,
+ GetParam().cipher_suite, ssl_grp_none,
+ ssl_sig_none) {}
+};
+
+// SSL_SecurityStatus produces fairly useless output when compared to
+// SSL_GetCipherSuiteInfo and SSL_GetChannelInfo, but we can't break it, so we
+// need to check it.
+TEST_P(SecurityStatusTest, CheckSecurityStatus) {
+ SetupCertificate();
+ EnableSingleCipher();
+ ConnectAndCheckCipherSuite();
+
+ int on;
+ char *cipher;
+ int keySize;
+ int secretKeySize;
+ char *issuer;
+ char *subject;
+ EXPECT_EQ(SECSuccess,
+ SSL_SecurityStatus(client_->ssl_fd(), &on, &cipher, &keySize,
+ &secretKeySize, &issuer, &subject));
+ if (std::string(cipher) == "NULL") {
+ EXPECT_EQ(0, on);
+ } else {
+ EXPECT_NE(0, on);
+ }
+ EXPECT_EQ(GetParam().name, std::string(cipher));
+ // All the ciphers we support have secret key size == key size.
+ EXPECT_EQ(GetParam().keySize, keySize);
+ EXPECT_EQ(GetParam().keySize, secretKeySize);
+ EXPECT_LT(0U, strlen(issuer));
+ EXPECT_LT(0U, strlen(subject));
+
+ PORT_Free(cipher);
+ PORT_Free(issuer);
+ PORT_Free(subject);
+}
+
+static const SecStatusParams kSecStatusTestValuesArr[] = {
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_NULL_SHA, "NULL", 0},
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_RC4_128_SHA, "RC4", 128},
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ "3DES-EDE-CBC", 168},
+ {SSL_LIBRARY_VERSION_TLS_1_0, TLS_RSA_WITH_AES_128_CBC_SHA, "AES-128", 128},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_CBC_SHA256, "AES-256",
+ 256},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_128_GCM_SHA256,
+ "AES-128-GCM", 128},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_RSA_WITH_AES_256_GCM_SHA384,
+ "AES-256-GCM", 256},
+ {SSL_LIBRARY_VERSION_TLS_1_2, TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
+ "ChaCha20-Poly1305", 256}};
+INSTANTIATE_TEST_CASE_P(TestSecurityStatus, SecurityStatusTest,
+ ::testing::ValuesIn(kSecStatusTestValuesArr));
+
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
new file mode 100644
index 0000000..9dadcbd
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -0,0 +1,61 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->StartConnect();
+ client_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ std::cerr << "Damaging HS secret\n";
+ SSLInt_DamageClientHsTrafficSecret(server_->ssl_fd());
+ client_->Handshake();
+ server_->Handshake();
+ // The client thinks it has connected.
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetPacketFilter(new AfterRecordN(
+ server_, client_,
+ 0, // ServerHello.
+ [this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
new file mode 100644
index 0000000..82d5558
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -0,0 +1,609 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include <set>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, ConnectDhe) {
+ EnableOnlyDheCiphers();
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
+ EnsureTlsSetup();
+ client_->ConfigNamedGroups(kAllDHEGroups);
+
+ auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn);
+ auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
+ std::vector<PacketFilter*> captures;
+ captures.push_back(groups_capture);
+ captures.push_back(shares_capture);
+ client_->SetPacketFilter(new ChainedPacketFilter(captures));
+
+ Connect();
+
+ CheckKeys();
+
+ bool ec, dh;
+ auto track_group_type = [&ec, &dh](SSLNamedGroup group) {
+ if ((group & 0xff00U) == 0x100U) {
+ dh = true;
+ } else {
+ ec = true;
+ }
+ };
+ CheckGroups(groups_capture->extension(), track_group_type);
+ CheckShares(shares_capture->extension(), track_group_type);
+ EXPECT_TRUE(ec) << "Should include an EC group and share";
+ EXPECT_TRUE(dh) << "Should include an FFDHE group and share";
+}
+
+TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
+ EnableOnlyDheCiphers();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn);
+ auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
+ std::vector<PacketFilter*> captures;
+ captures.push_back(groups_capture);
+ captures.push_back(shares_capture);
+ client_->SetPacketFilter(new ChainedPacketFilter(captures));
+
+ Connect();
+
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+ auto is_ffdhe = [](SSLNamedGroup group) {
+ // The group has to be in this range.
+ EXPECT_LE(ssl_grp_ffdhe_2048, group);
+ EXPECT_GE(ssl_grp_ffdhe_8192, group);
+ };
+ CheckGroups(groups_capture->extension(), is_ffdhe);
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ CheckShares(shares_capture->extension(), is_ffdhe);
+ } else {
+ EXPECT_EQ(0U, shares_capture->extension().len());
+ }
+}
+
+// Requiring the FFDHE extension on the server alone means that clients won't be
+// able to connect using a DHE suite. They should still connect in TLS 1.3,
+// because the client automatically sends the supported groups extension.
+TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
+ EnableOnlyDheCiphers();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(server_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+ } else {
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ }
+}
+
+class TlsDheServerKeyExchangeDamager : public TlsHandshakeFilter {
+ public:
+ TlsDheServerKeyExchangeDamager() {}
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ // Damage the first octet of dh_p. Anything other than the known prime will
+ // be rejected as "weak" when we have SSL_REQUIRE_DH_NAMED_GROUPS enabled.
+ *output = input;
+ output->data()[3] ^= 73;
+ return CHANGE;
+ }
+};
+
+// Changing the prime in the server's key share results in an error. This will
+// invalidate the signature over the ServerKeyShare. That's ok, NSS won't check
+// the signature until everything else has been checked.
+TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
+ EnableOnlyDheCiphers();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ server_->SetPacketFilter(new TlsDheServerKeyExchangeDamager());
+
+ ConnectExpectFail();
+
+ client_->CheckErrorCode(SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class TlsDheSkeChangeY : public TlsHandshakeFilter {
+ public:
+ enum ChangeYTo {
+ kYZero,
+ kYOne,
+ kYPMinusOne,
+ kYGreaterThanP,
+ kYTooLarge,
+ kYZeroPad
+ };
+
+ TlsDheSkeChangeY(ChangeYTo change) : change_Y_(change) {}
+
+ protected:
+ void ChangeY(const DataBuffer& input, DataBuffer* output, size_t offset,
+ const DataBuffer& prime) {
+ static const uint8_t kExtraZero = 0;
+ static const uint8_t kTooLargeExtra = 1;
+
+ uint32_t dh_Ys_len;
+ EXPECT_TRUE(input.Read(offset, 2, &dh_Ys_len));
+ EXPECT_LT(offset + dh_Ys_len, input.len());
+ offset += 2;
+
+ // This isn't generally true, but our code pads.
+ EXPECT_EQ(prime.len(), dh_Ys_len)
+ << "Length of dh_Ys must equal length of dh_p";
+
+ *output = input;
+ switch (change_Y_) {
+ case kYZero:
+ memset(output->data() + offset, 0, prime.len());
+ break;
+
+ case kYOne:
+ memset(output->data() + offset, 0, prime.len() - 1);
+ output->Write(offset + prime.len() - 1, 1U, 1);
+ break;
+
+ case kYPMinusOne:
+ output->Write(offset, prime);
+ EXPECT_TRUE(output->data()[offset + prime.len() - 1] & 0x01)
+ << "P must at least be odd";
+ --output->data()[offset + prime.len() - 1];
+ break;
+
+ case kYGreaterThanP:
+ // Set the first 32 octets of Y to 0xff, except the first which we set
+ // to p[0]. This will make Y > p. That is, unless p is Mersenne, or
+ // improbably large (but still the same bit length). We currently only
+ // use a fixed prime that isn't a problem for this code.
+ EXPECT_LT(0, prime.data()[0]) << "dh_p should not be zero-padded";
+ offset = output->Write(offset, prime.data()[0], 1);
+ memset(output->data() + offset, 0xff, 31);
+ break;
+
+ case kYTooLarge:
+ // Increase the dh_Ys length.
+ output->Write(offset - 2, prime.len() + sizeof(kTooLargeExtra), 2);
+ // Then insert the octet.
+ output->Splice(&kTooLargeExtra, sizeof(kTooLargeExtra), offset);
+ break;
+
+ case kYZeroPad:
+ output->Write(offset - 2, prime.len() + sizeof(kExtraZero), 2);
+ output->Splice(&kExtraZero, sizeof(kExtraZero), offset);
+ break;
+ }
+ }
+
+ private:
+ ChangeYTo change_Y_;
+};
+
+class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
+ public:
+ TlsDheSkeChangeYServer(ChangeYTo change, bool modify)
+ : TlsDheSkeChangeY(change), modify_(modify), p_() {}
+
+ const DataBuffer& prime() const { return p_; }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ size_t offset = 2;
+ // Read dh_p
+ uint32_t dh_len = 0;
+ EXPECT_TRUE(input.Read(0, 2, &dh_len));
+ EXPECT_GT(input.len(), offset + dh_len);
+ p_.Assign(input.data() + offset, dh_len);
+ offset += dh_len;
+
+ // Skip dh_g to find dh_Ys
+ EXPECT_TRUE(input.Read(offset, 2, &dh_len));
+ offset += 2 + dh_len;
+
+ if (modify_) {
+ ChangeY(input, output, offset, p_);
+ return CHANGE;
+ }
+ return KEEP;
+ }
+
+ private:
+ bool modify_;
+ DataBuffer p_;
+};
+
+class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
+ public:
+ TlsDheSkeChangeYClient(ChangeYTo change,
+ const TlsDheSkeChangeYServer* server_filter)
+ : TlsDheSkeChangeY(change), server_filter_(server_filter) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ if (header.handshake_type() != kTlsHandshakeClientKeyExchange) {
+ return KEEP;
+ }
+
+ ChangeY(input, output, 0, server_filter_->prime());
+ return CHANGE;
+ }
+
+ private:
+ const TlsDheSkeChangeYServer* server_filter_;
+};
+
+/* This matrix includes: mode (stream/datagram), TLS version, what change to
+ * make to dh_Ys, whether the client will be configured to require DH named
+ * groups. Test all combinations. */
+typedef std::tuple<std::string, uint16_t, TlsDheSkeChangeY::ChangeYTo, bool>
+ DamageDHYProfile;
+class TlsDamageDHYTest
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<DamageDHYProfile> {
+ public:
+ TlsDamageDHYTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+};
+
+TEST_P(TlsDamageDHYTest, DamageServerY) {
+ EnableOnlyDheCiphers();
+ if (std::get<3>(GetParam())) {
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ }
+ TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
+ server_->SetPacketFilter(new TlsDheSkeChangeYServer(change, true));
+
+ ConnectExpectFail();
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ // Zero padding Y only manifests in a signature failure.
+ // In TLS 1.0 and 1.1, the client reports a device error.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR);
+ } else {
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ }
+ server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ }
+}
+
+TEST_P(TlsDamageDHYTest, DamageClientY) {
+ EnableOnlyDheCiphers();
+ if (std::get<3>(GetParam())) {
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ }
+ // The filter on the server is required to capture the prime.
+ TlsDheSkeChangeYServer* server_filter =
+ new TlsDheSkeChangeYServer(TlsDheSkeChangeY::kYZero, false);
+ server_->SetPacketFilter(server_filter);
+
+ // The client filter does the damage.
+ TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
+ client_->SetPacketFilter(new TlsDheSkeChangeYClient(change, server_filter));
+
+ ConnectExpectFail();
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ // Zero padding Y only manifests in a finished error.
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
+ }
+}
+
+static const TlsDheSkeChangeY::ChangeYTo kAllYArr[] = {
+ TlsDheSkeChangeY::kYZero, TlsDheSkeChangeY::kYOne,
+ TlsDheSkeChangeY::kYPMinusOne, TlsDheSkeChangeY::kYGreaterThanP,
+ TlsDheSkeChangeY::kYTooLarge, TlsDheSkeChangeY::kYZeroPad};
+static ::testing::internal::ParamGenerator<TlsDheSkeChangeY::ChangeYTo> kAllY =
+ ::testing::ValuesIn(kAllYArr);
+static const bool kTrueFalseArr[] = {true, false};
+static ::testing::internal::ParamGenerator<bool> kTrueFalse =
+ ::testing::ValuesIn(kTrueFalseArr);
+
+INSTANTIATE_TEST_CASE_P(DamageYStream, TlsDamageDHYTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10ToV12,
+ kAllY, kTrueFalse));
+INSTANTIATE_TEST_CASE_P(
+ DamageYDatagram, TlsDamageDHYTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse));
+
+class TlsDheSkeMakePEven : public TlsHandshakeFilter {
+ public:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ // Find the end of dh_p
+ uint32_t dh_len = 0;
+ EXPECT_TRUE(input.Read(0, 2, &dh_len));
+ EXPECT_GT(input.len(), 2 + dh_len) << "enough space for dh_p";
+ size_t offset = 2 + dh_len - 1;
+ EXPECT_TRUE((input.data()[offset] & 0x01) == 0x01) << "p should be odd";
+
+ *output = input;
+ output->data()[offset] &= 0xfe;
+
+ return CHANGE;
+ }
+};
+
+// Even without requiring named groups, an even value for p is bad news.
+TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
+ EnableOnlyDheCiphers();
+ server_->SetPacketFilter(new TlsDheSkeMakePEven());
+
+ ConnectExpectFail();
+
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+}
+
+class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
+ public:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ *output = input;
+ uint32_t dh_len = 0;
+ EXPECT_TRUE(input.Read(0, 2, &dh_len));
+ static const uint8_t kZeroPad = 0;
+ output->Write(0, dh_len + sizeof(kZeroPad), 2); // increment the length
+ output->Splice(&kZeroPad, sizeof(kZeroPad), 2); // insert a zero
+
+ return CHANGE;
+ }
+};
+
+// Zero padding only causes signature failure.
+TEST_P(TlsConnectGenericPre13, PadDheP) {
+ EnableOnlyDheCiphers();
+ server_->SetPacketFilter(new TlsDheSkeZeroPadP());
+
+ ConnectExpectFail();
+
+ // In TLS 1.0 and 1.1, the client reports a device error.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ client_->CheckErrorCode(SEC_ERROR_PKCS11_DEVICE_ERROR);
+ } else {
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ }
+ server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+// The server should not pick the weak DH group if the client includes FFDHE
+// named groups in the supported_groups extension. The server then picks a
+// commonly-supported named DH group and this connects.
+//
+// Note: This test case can take ages to generate the weak DH key.
+TEST_P(TlsConnectGenericPre13, WeakDHGroup) {
+ EnableOnlyDheCiphers();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ EXPECT_EQ(SECSuccess,
+ SSL_EnableWeakDHEPrimeGroup(server_->ssl_fd(), PR_TRUE));
+
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, Ffdhe3072) {
+ EnableOnlyDheCiphers();
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ffdhe_3072};
+ client_->ConfigNamedGroups(groups);
+
+ Connect();
+}
+
+// Even though the client doesn't have DHE groups enabled the server assumes it
+// does. Because the client doesn't require named groups it accepts FF3072 as
+// custom group.
+TEST_P(TlsConnectGenericPre13, NamedGroupMismatchPre13) {
+ EnableOnlyDheCiphers();
+ static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1};
+ server_->ConfigNamedGroups(server_groups);
+ client_->ConfigNamedGroups(client_groups);
+
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_custom, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+// Same test but for TLS 1.3. This has to fail.
+TEST_P(TlsConnectTls13, NamedGroupMismatch13) {
+ EnableOnlyDheCiphers();
+ static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1};
+ server_->ConfigNamedGroups(server_groups);
+ client_->ConfigNamedGroups(client_groups);
+
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// Even though the client doesn't have DHE groups enabled the server assumes it
+// does. The client requires named groups and thus does not accept FF3072 as
+// custom group in contrast to the previous test.
+TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) {
+ EnableOnlyDheCiphers();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ static const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_3072};
+ static const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ffdhe_2048};
+ server_->ConfigNamedGroups(server_groups);
+ client_->ConfigNamedGroups(client_groups);
+
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectGenericPre13, PreferredFfdhe) {
+ EnableOnlyDheCiphers();
+ static const SSLDHEGroupType groups[] = {ssl_ff_dhe_3072_group,
+ ssl_ff_dhe_2048_group};
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), groups,
+ PR_ARRAY_SIZE(groups)));
+
+ Connect();
+ client_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ server_->CheckKEA(ssl_kea_dh, ssl_grp_ffdhe_3072, 3072);
+ client_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+ server_->CheckAuthType(ssl_auth_rsa_sign, ssl_sig_rsa_pss_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, MismatchDHE) {
+ EnableOnlyDheCiphers();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
+ static const SSLDHEGroupType serverGroups[] = {ssl_ff_dhe_3072_group};
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(server_->ssl_fd(), serverGroups,
+ PR_ARRAY_SIZE(serverGroups)));
+ static const SSLDHEGroupType clientGroups[] = {ssl_ff_dhe_2048_group};
+ EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(client_->ssl_fd(), clientGroups,
+ PR_ARRAY_SIZE(clientGroups)));
+
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectTls13, ResumeFfdhe) {
+ EnableOnlyDheCiphers();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableOnlyDheCiphers();
+ TlsExtensionCapture* clientCapture =
+ new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ client_->SetPacketFilter(clientCapture);
+ TlsExtensionCapture* serverCapture =
+ new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ server_->SetPacketFilter(serverCapture);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none);
+ ASSERT_LT(0UL, clientCapture->extension().len());
+ ASSERT_LT(0UL, serverCapture->extension().len());
+}
+
+class TlsDheSkeChangeSignature : public TlsHandshakeFilter {
+ public:
+ TlsDheSkeChangeSignature(uint16_t version, const uint8_t* data, size_t len)
+ : version_(version), data_(data), len_(len) {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.SkipVariable(2)); // dh_p
+ EXPECT_TRUE(parser.SkipVariable(2)); // dh_g
+ EXPECT_TRUE(parser.SkipVariable(2)); // dh_Ys
+
+ // Copy DH params to output.
+ size_t offset = output->Write(0, input.data(), parser.consumed());
+
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) {
+ // Write signature algorithm.
+ offset = output->Write(offset, ssl_sig_dsa_sha256, 2);
+ }
+
+ // Write new signature.
+ offset = output->Write(offset, len_, 2);
+ offset = output->Write(offset, data_, len_);
+
+ return CHANGE;
+ }
+
+ private:
+ uint16_t version_;
+ const uint8_t* data_;
+ size_t len_;
+};
+
+TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
+ const uint8_t kBogusDheSignature[] = {
+ 0x30, 0x69, 0x3c, 0x02, 0x1c, 0x7d, 0x0b, 0x2f, 0x64, 0x00, 0x27,
+ 0xae, 0xcf, 0x1e, 0x28, 0x08, 0x6a, 0x7f, 0xb1, 0xbd, 0x78, 0xb5,
+ 0x3b, 0x8c, 0x8f, 0x59, 0xed, 0x8f, 0xee, 0x78, 0xeb, 0x2c, 0xe9,
+ 0x02, 0x1c, 0x6d, 0x7f, 0x3c, 0x0f, 0xf4, 0x44, 0x35, 0x0b, 0xb2,
+ 0x6d, 0xdc, 0xb8, 0x21, 0x87, 0xdd, 0x0d, 0xb9, 0x46, 0x09, 0x3e,
+ 0xef, 0x81, 0x5b, 0x37, 0x09, 0x39, 0xeb};
+
+ Reset(TlsAgent::kServerDsa);
+
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
+ client_->ConfigNamedGroups(client_groups);
+
+ server_->SetPacketFilter(new TlsDheSkeChangeSignature(
+ version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
+
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
new file mode 100644
index 0000000..89ca28e
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -0,0 +1,133 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) {
+ client_->SetPacketFilter(new SelectiveDropFilter(0x1));
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
+ server_->SetPacketFilter(new SelectiveDropFilter(0x1));
+ Connect();
+ SendReceive();
+}
+
+// This drops the first transmission from both the client and server of all
+// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
+// this will also drop some application data, so we can't call SendReceive().
+TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) {
+ client_->SetPacketFilter(new SelectiveDropFilter(0x15));
+ server_->SetPacketFilter(new SelectiveDropFilter(0x5));
+ Connect();
+}
+
+// This drops the server's first flight three times.
+TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) {
+ server_->SetPacketFilter(new SelectiveDropFilter(0x7));
+ Connect();
+}
+
+// This drops the client's second flight once
+TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) {
+ client_->SetPacketFilter(new SelectiveDropFilter(0x2));
+ Connect();
+}
+
+// This drops the client's second flight three times.
+TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) {
+ client_->SetPacketFilter(new SelectiveDropFilter(0xe));
+ Connect();
+}
+
+// This drops the server's second flight three times.
+TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) {
+ server_->SetPacketFilter(new SelectiveDropFilter(0xe));
+ Connect();
+}
+
+static void GetCipherAndLimit(uint16_t version, uint16_t* cipher,
+ uint64_t* limit = nullptr) {
+ uint64_t l;
+ if (!limit) limit = &l;
+
+ if (version < SSL_LIBRARY_VERSION_TLS_1_2) {
+ *cipher = TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA;
+ *limit = 0x5aULL << 28;
+ } else if (version == SSL_LIBRARY_VERSION_TLS_1_2) {
+ *cipher = TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256;
+ *limit = (1ULL << 48) - 1;
+ } else {
+ *cipher = TLS_CHACHA20_POLY1305_SHA256;
+ *limit = (1ULL << 48) - 1;
+ }
+}
+
+// This simulates a huge number of drops on one side.
+TEST_P(TlsConnectDatagram, MissLotsOfPackets) {
+ uint16_t cipher;
+ uint64_t limit;
+
+ GetCipherAndLimit(version_, &cipher, &limit);
+
+ EnsureTlsSetup();
+ server_->EnableSingleCipher(cipher);
+ Connect();
+
+ // Note that the limit for ChaCha is 2^48-1.
+ EXPECT_EQ(SECSuccess,
+ SSLInt_AdvanceWriteSeqNum(client_->ssl_fd(), limit - 10));
+ SendReceive();
+}
+
+class TlsConnectDatagram12Plus : public TlsConnectDatagram {
+ public:
+ TlsConnectDatagram12Plus() : TlsConnectDatagram() {}
+};
+
+// This simulates missing a window's worth of packets.
+TEST_P(TlsConnectDatagram12Plus, MissAWindow) {
+ EnsureTlsSetup();
+ uint16_t cipher;
+ GetCipherAndLimit(version_, &cipher);
+ server_->EnableSingleCipher(cipher);
+ Connect();
+
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 0));
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram12Plus, MissAWindowAndOne) {
+ EnsureTlsSetup();
+ uint16_t cipher;
+ GetCipherAndLimit(version_, &cipher);
+ server_->EnableSingleCipher(cipher);
+ Connect();
+
+ EXPECT_EQ(SECSuccess, SSLInt_AdvanceWriteSeqByAWindow(client_->ssl_fd(), 1));
+ SendReceive();
+}
+
+INSTANTIATE_TEST_CASE_P(Datagram12Plus, TlsConnectDatagram12Plus,
+ TlsConnectTestBase::kTlsV12Plus);
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
new file mode 100644
index 0000000..43dfcba
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -0,0 +1,532 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGenericPre13, ConnectEcdh) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdhEcdsa);
+ DisableAllCiphers();
+ EnableSomeEcdhCiphers();
+
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa,
+ ssl_sig_none);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectEcdhWithoutDisablingSuites) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdhEcdsa);
+ EnableSomeEcdhCiphers();
+
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_ecdh_ecdsa,
+ ssl_sig_none);
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdhe) {
+ Connect();
+ CheckKeys();
+}
+
+// If we pick a 256-bit cipher suite and use a P-384 certificate, the server
+// should choose P-384 for key exchange too. Only valid for TLS == 1.2 because
+// we don't have 256-bit ciphers before then and 1.3 doesn't try to couple
+// DHE size to symmetric size.
+TEST_P(TlsConnectTls12, ConnectEcdheP384) {
+ Reset(TlsAgent::kServerEcdsa384);
+ ConnectWithCipherSuite(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256);
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp384r1_sha384);
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
+ EnsureTlsSetup();
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care.
+TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) {
+ EnsureTlsSetup();
+ auto hrr_capture =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ server_->SetPacketFilter(hrr_capture);
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ EXPECT_EQ(version_ == SSL_LIBRARY_VERSION_TLS_1_3,
+ hrr_capture->buffer().len() != 0);
+}
+
+// This enables only P-256 on the client and disables it on the server.
+// This test will fail when we add other groups that identify as ECDHE.
+TEST_P(TlsConnectGeneric, ConnectEcdheGroupMismatch) {
+ EnsureTlsSetup();
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ffdhe_2048};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ffdhe_2048};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
+}
+
+TEST_P(TlsKeyExchangeTest, P384Priority) {
+ // P256, P384 and P521 are enabled. Both prefer P384.
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+
+ std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
+ CheckKEXDetails(groups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest, DuplicateGroupConfig) {
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp256r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+
+ std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
+ std::vector<SSLNamedGroup> expectedGroups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp256r1};
+ CheckKEXDetails(expectedGroups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest, P384PriorityDHEnabled) {
+ // P256, P384, P521, and FFDHE2048 are enabled. Both prefer P384.
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1,
+ ssl_grp_ec_secp521r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp384r1};
+ CheckKEXDetails(groups, shares);
+ } else {
+ std::vector<SSLNamedGroup> oldtlsgroups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ CheckKEXDetails(oldtlsgroups, std::vector<SSLNamedGroup>());
+ }
+}
+
+TEST_P(TlsConnectGenericPre13, P384PriorityOnServer) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ // The server prefers P384. It has to win.
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
+ EnsureModelSockets();
+
+ /* Both prefer P384, set on the model socket. */
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1,
+ ssl_grp_ffdhe_2048};
+ client_model_->ConfigNamedGroups(groups);
+ server_model_->ConfigNamedGroups(groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+// If we only have a lame group, we fall back to static RSA.
+TEST_P(TlsConnectGenericPre13, UseLameGroup) {
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp192r1};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+// In TLS 1.3, we can't generate the ClientHello.
+TEST_P(TlsConnectTls13, UseLameGroup) {
+ const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_sect283k1};
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ client_->StartConnect();
+ client_->Handshake();
+ client_->CheckErrorCode(SSL_ERROR_NO_CIPHERS_SUPPORTED);
+}
+
+TEST_P(TlsConnectStreamPre13, ConfiguredGroupsRenegotiate) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_secp256r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ CheckConnected();
+
+ // The renegotiation has to use the same preferences as the original session.
+ server_->PrepareForRenegotiate();
+ client_->StartRenegotiate();
+ Handshake();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+TEST_P(TlsKeyExchangeTest, Curve25519) {
+ Reset(TlsAgent::kServerEcdsa256);
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ EnsureKeyShareSetup();
+ ConfigNamedGroups(groups);
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_ecdsa,
+ ssl_sig_ecdsa_secp256r1_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(groups, shares);
+}
+
+TEST_P(TlsConnectGenericPre13, GroupPreferenceServerPriority) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ // The client prefers P256 while the server prefers 25519.
+ // The server's preference has to win.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+}
+
+#ifndef NSS_DISABLE_TLS_1_3
+TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityClient13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a P256 key share while the server prefers 25519.
+ // We have to accept P256 without retry.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_secp256r1};
+ CheckKEXDetails(client_groups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest13, Curve25519P256EqualPriorityServer13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // We have to accept 25519 without retry.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares);
+}
+
+TEST_P(TlsKeyExchangeTest13, EqualPriorityTestRetryECServer13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // The server prefers P-384 over x25519, so it must not consider P-256 and
+ // x25519 to be equivalent. It will therefore request a P-256 share
+ // with a HelloRetryRequest.
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13, NotEqualPriorityWithIntermediateGroup13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // The server prefers ffdhe_2048 over x25519, so it must not consider the
+ // P-256 and x25519 to be equivalent. It will therefore request a P-256 share
+ // with a HelloRetryRequest.
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13,
+ NotEqualPriorityWithUnsupportedFFIntermediateGroup13) {
+ EnsureKeyShareSetup();
+
+ // As in the previous test, the server prefers ffdhe_2048. Thus, even though
+ // the client doesn't support this group, the server must not regard x25519 as
+ // equivalent to P-256.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ffdhe_2048, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13,
+ NotEqualPriorityWithUnsupportedECIntermediateGroup13) {
+ EnsureKeyShareSetup();
+
+ // As in the previous test, the server prefers P-384. Thus, even though
+ // the client doesn't support this group, the server must not regard x25519 as
+ // equivalent to P-256. The server sends a HelloRetryRequest.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares, ssl_grp_ec_secp256r1);
+}
+
+TEST_P(TlsKeyExchangeTest13, EqualPriority13) {
+ EnsureKeyShareSetup();
+
+ // The client sends a 25519 key share while the server prefers P256.
+ // We have to accept 25519 without retry because it's considered equivalent to
+ // P256 by the server.
+ const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ffdhe_2048, ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ Connect();
+
+ CheckKeys();
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519};
+ CheckKEXDetails(client_groups, shares);
+}
+#endif
+
+TEST_P(TlsConnectGeneric, P256ClientAndCurve25519Server) {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_ecdh);
+
+ // The client sends a P256 key share while the server prefers 25519.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_curve25519};
+
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsKeyExchangeTest13, MultipleClientShares) {
+ EnsureKeyShareSetup();
+
+ // The client sends 25519 and P256 key shares. The server prefers P256,
+ // which must be chosen here.
+ const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ const std::vector<SSLNamedGroup> server_groups = {ssl_grp_ec_secp256r1,
+ ssl_grp_ec_curve25519};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+
+ // Generate a key share on the client for both curves.
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ Connect();
+
+ // The server would accept 25519 but its preferred group (P256) has to win.
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp256r1, ssl_auth_rsa_sign,
+ ssl_sig_rsa_pss_sha256);
+ const std::vector<SSLNamedGroup> shares = {ssl_grp_ec_curve25519,
+ ssl_grp_ec_secp256r1};
+ CheckKEXDetails(client_groups, shares);
+}
+
+// Replace the point in the client key exchange message with an empty one
+class ECCClientKEXFilter : public TlsHandshakeFilter {
+ public:
+ ECCClientKEXFilter() {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ if (header.handshake_type() != kTlsHandshakeClientKeyExchange) {
+ return KEEP;
+ }
+
+ // Replace the client key exchange message with an empty point
+ output->Allocate(1);
+ output->Write(0, 0U, 1); // set point length 0
+ return CHANGE;
+ }
+};
+
+// Replace the point in the server key exchange message with an empty one
+class ECCServerKEXFilter : public TlsHandshakeFilter {
+ public:
+ ECCServerKEXFilter() {}
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ // Replace the server key exchange message with an empty point
+ output->Allocate(4);
+ output->Write(0, 3U, 1); // named curve
+ uint32_t curve;
+ EXPECT_TRUE(input.Read(1, 2, &curve)); // get curve id
+ output->Write(1, curve, 2); // write curve id
+ output->Write(3, 0U, 1); // point length 0
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) {
+ // add packet filter
+ server_->SetPacketFilter(new ECCServerKEXFilter());
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) {
+ // add packet filter
+ client_->SetPacketFilter(new ECCClientKEXFilter());
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH);
+}
+
+INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11Plus));
+
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_ems_unittest.cc b/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
new file mode 100644
index 0000000..b9c725b
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
@@ -0,0 +1,100 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecret) {
+ EnableExtendedMasterSecret();
+ Connect();
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ EnableExtendedMasterSecret();
+ Connect();
+}
+
+TEST_P(TlsConnectTls12, ConnectExtendedMasterSecretSha384) {
+ EnableExtendedMasterSecret();
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384);
+ ConnectWithCipherSuite(TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretStaticRSA) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretECDHE) {
+ EnableExtendedMasterSecret();
+ Connect();
+
+ Reset();
+ EnableExtendedMasterSecret();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ EnableExtendedMasterSecret();
+ Connect();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+
+ EnableExtendedMasterSecret();
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretClientOnly) {
+ client_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(false);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretServerOnly) {
+ server_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(false);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretResumeWithout) {
+ EnableExtendedMasterSecret();
+ Connect();
+
+ Reset();
+ server_->EnableExtendedMasterSecret();
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertHandshakeFailure, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectNormalResumeWithExtendedMasterSecret) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ ExpectExtendedMasterSecret(false);
+ Connect();
+
+ Reset();
+ EnableExtendedMasterSecret();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
new file mode 100644
index 0000000..0a0d9f2
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
@@ -0,0 +1,122 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+static const char* kExporterLabel = "EXPORTER-duck";
+static const uint8_t kExporterContext[] = {0x12, 0x34, 0x56};
+
+static void ExportAndCompare(TlsAgent* client, TlsAgent* server, bool context) {
+ static const size_t exporter_len = 10;
+ uint8_t client_value[exporter_len] = {0};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportKeyingMaterial(
+ client->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ context ? PR_TRUE : PR_FALSE, kExporterContext,
+ sizeof(kExporterContext), client_value, sizeof(client_value)));
+ uint8_t server_value[exporter_len] = {0xff};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportKeyingMaterial(
+ server->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ context ? PR_TRUE : PR_FALSE, kExporterContext,
+ sizeof(kExporterContext), server_value, sizeof(server_value)));
+ EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value)));
+}
+
+TEST_P(TlsConnectGeneric, ExporterBasic) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ } else {
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ }
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, false);
+}
+
+TEST_P(TlsConnectGeneric, ExporterContext) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ } else {
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ }
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, true);
+}
+
+// Bug 1312976 - SHA-384 doesn't work in 1.2 right now.
+TEST_P(TlsConnectTls13, ExporterSha384) {
+ EnsureTlsSetup();
+ client_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, false);
+}
+
+TEST_P(TlsConnectTls13, ExporterContextEmptyIsSameAsNone) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ } else {
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ }
+ Connect();
+ CheckKeys();
+ ExportAndCompare(client_, server_, false);
+}
+
+// This has a weird signature so that it can be passed to the SNI callback.
+int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize) {
+ uint8_t val[10];
+ EXPECT_EQ(SECFailure, SSL_ExportKeyingMaterial(
+ agent->ssl_fd(), kExporterLabel,
+ strlen(kExporterLabel), PR_TRUE, kExporterContext,
+ sizeof(kExporterContext), val, sizeof(val)))
+ << "regular exporter should fail";
+ return 0;
+}
+
+TEST_P(TlsConnectTls13, EarlyExporter) {
+ SetupForZeroRtt();
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->Handshake(); // Send ClientHello.
+ uint8_t client_value[10] = {0};
+ RegularExporterShouldFail(client_, nullptr, 0);
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportEarlyKeyingMaterial(
+ client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ kExporterContext, sizeof(kExporterContext), client_value,
+ sizeof(client_value)));
+
+ server_->SetSniCallback(RegularExporterShouldFail);
+ server_->Handshake(); // Handle ClientHello.
+ uint8_t server_value[10] = {0};
+ EXPECT_EQ(SECSuccess,
+ SSL_ExportEarlyKeyingMaterial(
+ server_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
+ kExporterContext, sizeof(kExporterContext), server_value,
+ sizeof(server_value)));
+ EXPECT_EQ(0, memcmp(client_value, server_value, sizeof(client_value)));
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
new file mode 100644
index 0000000..9200e72
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -0,0 +1,985 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class TlsExtensionTruncator : public TlsExtensionFilter {
+ public:
+ TlsExtensionTruncator(uint16_t extension, size_t length)
+ : extension_(extension), length_(length) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+ if (input.len() <= length_) {
+ return KEEP;
+ }
+
+ output->Assign(input.data(), length_);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t extension_;
+ size_t length_;
+};
+
+class TlsExtensionDamager : public TlsExtensionFilter {
+ public:
+ TlsExtensionDamager(uint16_t extension, size_t index)
+ : extension_(extension), index_(index) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = input;
+ output->data()[index_] += 73; // Increment selected for maximum damage
+ return CHANGE;
+ }
+
+ private:
+ uint16_t extension_;
+ size_t index_;
+};
+
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(uint16_t ext, DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ size_t offset;
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
+ }
+ offset = parser.consumed();
+ } else if (header.handshake_type() == kTlsHandshakeServerHello) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
+ return KEEP;
+ }
+ offset = parser.consumed();
+ } else {
+ return KEEP;
+ }
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+ }
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionAppender : public TlsHandshakeFilter {
+ public:
+ TlsExtensionAppender(uint16_t ext, DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ size_t offset;
+ TlsParser parser(input);
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
+ }
+ } else if (header.handshake_type() == kTlsHandshakeServerHello) {
+ if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
+ return KEEP;
+ }
+ } else {
+ return KEEP;
+ }
+ offset = parser.consumed();
+ *output = input;
+
+ uint32_t ext_len;
+ if (!parser.Read(&ext_len, 2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ ext_len += 4 + data_.len();
+ output->Write(offset, ext_len, 2);
+
+ offset = output->len();
+ offset = output->Write(offset, extension_, 2);
+ WriteVariable(output, offset, data_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionTestBase : public TlsConnectTestBase {
+ protected:
+ TlsExtensionTestBase(Mode mode, uint16_t version)
+ : TlsConnectTestBase(mode, version) {}
+ TlsExtensionTestBase(const std::string& mode, uint16_t version)
+ : TlsConnectTestBase(mode, version) {}
+
+ void ClientHelloErrorTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertDecodeError) {
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ client_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(alert, alert_recorder->description());
+ }
+
+ void ServerHelloErrorTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertDecodeError) {
+ auto alert_recorder = new TlsAlertRecorder();
+ client_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ server_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(alert, alert_recorder->description());
+ }
+
+ static void InitSimpleSni(DataBuffer* extension) {
+ const char* name = "host.name";
+ const size_t namelen = PL_strlen(name);
+ extension->Allocate(namelen + 5);
+ extension->Write(0, namelen + 3, 2);
+ extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname
+ extension->Write(3, namelen, 2);
+ extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen);
+ }
+
+ void HrrThenRemoveExtensionsTest(SSLExtensionType type, PRInt32 client_error,
+ PRInt32 server_error) {
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ EnsureTlsSetup();
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send HRR.
+ client_->SetPacketFilter(new TlsExtensionDropper(type));
+ Handshake();
+ client_->CheckErrorCode(client_error);
+ server_->CheckErrorCode(server_error);
+ }
+};
+
+class TlsExtensionTestDtls : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {}
+};
+
+class TlsExtensionTest12Plus
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTest12Plus()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTest12
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTest12()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTest13 : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsExtensionTest13()
+ : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
+ DataBuffer versions_buf(buf, len);
+ client_->SetPacketFilter(new TlsExtensionReplacer(
+ ssl_tls13_supported_versions_xtn, versions_buf));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+ }
+
+ void ConnectWithReplacementVersionList(uint16_t version) {
+ DataBuffer versions_buf;
+
+ size_t index = versions_buf.Write(0, 2, 1);
+ versions_buf.Write(index, version, 2);
+ client_->SetPacketFilter(new TlsExtensionReplacer(
+ ssl_tls13_supported_versions_xtn, versions_buf));
+ ConnectExpectFail();
+ }
+};
+
+class TlsExtensionTest13Stream : public TlsExtensionTestBase {
+ public:
+ TlsExtensionTest13Stream()
+ : TlsExtensionTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsExtensionTestGeneric
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTestGeneric()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTestPre13
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTestPre13()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
+ ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
+ ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4));
+}
+
+TEST_P(TlsExtensionTestGeneric, TruncateSni) {
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7));
+}
+
+// A valid extension that appears twice will be reported as unsupported.
+TEST_P(TlsExtensionTestGeneric, RepeatSni) {
+ DataBuffer extension;
+ InitSimpleSni(&extension);
+ ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An SNI entry with zero length is considered invalid (strangely, not if it is
+// the last entry, which is probably a bug).
+TEST_P(TlsExtensionTestGeneric, BadSni) {
+ DataBuffer simple;
+ InitSimpleSni(&simple);
+ DataBuffer extension;
+ extension.Allocate(simple.len() + 3);
+ extension.Write(0, static_cast<uint32_t>(0), 3);
+ extension.Write(3, simple);
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptySni) {
+ DataBuffer extension;
+ extension.Allocate(2);
+ extension.Write(0, static_cast<uint32_t>(0), 2);
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
+ EnableAlpn();
+ DataBuffer extension;
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An empty ALPN isn't considered bad, though it does lead to there being no
+// protocol for the server to select.
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertNoApplicationProtocol);
+}
+
+TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
+ EnableAlpn();
+ ClientHelloErrorTest(
+ new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
+ EnableAlpn();
+ // This will leave the length of the second entry, but no value.
+ ClientHelloErrorTest(
+ new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
+ const uint8_t client_alpn[] = {0x01, 0x61};
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ const uint8_t server_alpn[] = {0x02, 0x61, 0x62};
+ server_->EnableAlpn(server_alpn, sizeof(server_alpn));
+
+ ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol);
+}
+
+// Many of these tests fail in TLS 1.3 because the extension is encrypted, which
+// prevents modification of the value from the ServerHello.
+TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpShort) {
+ EnableSrtp();
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpOdd) {
+ EnableSrtp();
+ const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
+ const uint8_t val[] = {0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
+ const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
+ const uint8_t val[] = {0x00, 0x01, 0x04};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
+ ClientHelloErrorTest(new TlsExtensionDropper(ssl_supported_groups_xtn),
+ version_ < SSL_LIBRARY_VERSION_TLS_1_3
+ ? kTlsAlertDecryptError
+ : kTlsAlertMissingExtension);
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
+ const uint8_t val[] = {0x00, 0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
+ const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
+ const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
+ const uint8_t val[] = {0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
+ const uint8_t val[] = {0x99, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
+ const uint8_t val[] = {0x01, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
+ const uint8_t val[] = {0x99};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
+ const uint8_t val[] = {0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+}
+
+// The extension has to contain a length.
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
+ DataBuffer extension;
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+}
+
+// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
+// picks the wrong cipher suite.
+TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
+ const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512,
+ ssl_sig_rsa_pss_sha384};
+
+ TlsExtensionCapture* capture =
+ new TlsExtensionCapture(ssl_signature_algorithms_xtn);
+ client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
+ client_->SetPacketFilter(capture);
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+
+ const DataBuffer& ext = capture->extension();
+ EXPECT_EQ(2 + PR_ARRAY_SIZE(schemes) * 2, ext.len());
+ for (size_t i = 0, cursor = 2;
+ i < PR_ARRAY_SIZE(schemes) && cursor < ext.len(); ++i) {
+ uint32_t v = 0;
+ EXPECT_TRUE(ext.Read(cursor, 2, &v));
+ cursor += 2;
+ EXPECT_EQ(schemes[i], static_cast<SSLSignatureScheme>(v));
+ }
+}
+
+// Temporary test to verify that we choke on an empty ClientKeyShare.
+// This test will fail when we implement HelloRetryRequest.
+TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
+}
+
+// These tests only work in stream mode because the client sends a
+// cleartext alert which causes a MAC error on the server. With
+// stream this causes handshake failure but with datagram, the
+// packet gets dropped.
+TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
+ EnsureTlsSetup();
+ server_->SetPacketFilter(new TlsExtensionDropper(ssl_tls13_key_share_xtn));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
+ const uint16_t wrong_group = ssl_grp_ec_secp384r1;
+
+ static const uint8_t key_share[] = {
+ wrong_group >> 8,
+ wrong_group & 0xff, // Group we didn't offer.
+ 0x00,
+ 0x02, // length = 2
+ 0x01,
+ 0x02};
+ DataBuffer buf(key_share, sizeof(key_share));
+ EnsureTlsSetup();
+ server_->SetPacketFilter(
+ new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+// TODO(ekr@rtfm.com): This is the wrong error code. See bug 1307269.
+TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
+ const uint16_t wrong_group = 0xffff;
+
+ static const uint8_t key_share[] = {
+ wrong_group >> 8,
+ wrong_group & 0xff, // Group we didn't offer.
+ 0x00,
+ 0x02, // length = 2
+ 0x01,
+ 0x02};
+ DataBuffer buf(key_share, sizeof(key_share));
+ EnsureTlsSetup();
+ server_->SetPacketFilter(
+ new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
+ SetupForResume();
+ DataBuffer empty;
+ server_->SetPacketFilter(
+ new TlsExtensionInjector(ssl_signature_algorithms_xtn, empty));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+struct PskIdentity {
+ DataBuffer identity;
+ uint32_t obfuscated_ticket_age;
+};
+
+class TlsPreSharedKeyReplacer;
+
+typedef std::function<void(TlsPreSharedKeyReplacer*)>
+ TlsPreSharedKeyReplacerFunc;
+
+class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
+ public:
+ TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function)
+ : identities_(), binders_(), function_(function) {}
+
+ static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
+ const std::unique_ptr<DataBuffer>& replace,
+ size_t index, DataBuffer* output) {
+ DataBuffer tmp;
+ bool ret = parser->ReadVariable(&tmp, size);
+ EXPECT_EQ(true, ret);
+ if (!ret) return 0;
+ if (replace) {
+ tmp = *replace;
+ }
+
+ return WriteVariable(output, index, tmp, size);
+ }
+
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_pre_shared_key_xtn) {
+ return KEEP;
+ }
+
+ if (!Decode(input)) {
+ return KEEP;
+ }
+
+ // Call the function.
+ function_(this);
+
+ Encode(output);
+
+ return CHANGE;
+ }
+
+ std::vector<PskIdentity> identities_;
+ std::vector<DataBuffer> binders_;
+
+ private:
+ bool Decode(const DataBuffer& input) {
+ std::unique_ptr<TlsParser> parser(new TlsParser(input));
+ DataBuffer identities;
+
+ if (!parser->ReadVariable(&identities, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ DataBuffer binders;
+ if (!parser->ReadVariable(&binders, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+ EXPECT_EQ(0UL, parser->remaining());
+
+ // Now parse the inner sections.
+ parser.reset(new TlsParser(identities));
+ while (parser->remaining()) {
+ PskIdentity identity;
+
+ if (!parser->ReadVariable(&identity.identity, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ if (!parser->Read(&identity.obfuscated_ticket_age, 4)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ identities_.push_back(identity);
+ }
+
+ parser.reset(new TlsParser(binders));
+ while (parser->remaining()) {
+ DataBuffer binder;
+
+ if (!parser->ReadVariable(&binder, 1)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ binders_.push_back(binder);
+ }
+
+ return true;
+ }
+
+ void Encode(DataBuffer* output) {
+ DataBuffer identities;
+ size_t index = 0;
+ for (auto id : identities_) {
+ index = WriteVariable(&identities, index, id.identity, 2);
+ index = identities.Write(index, id.obfuscated_ticket_age, 4);
+ }
+
+ DataBuffer binders;
+ index = 0;
+ for (auto binder : binders_) {
+ index = WriteVariable(&binders, index, binder, 1);
+ }
+
+ output->Truncate(0);
+ index = 0;
+ index = WriteVariable(output, index, identities, 2);
+ index = WriteVariable(output, index, binders, 2);
+ }
+
+ TlsPreSharedKeyReplacerFunc function_;
+};
+
+TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
+ SetupForResume();
+
+ client_->SetPacketFilter(new TlsPreSharedKeyReplacer([](
+ TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Flip the first byte of the binder.
+TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// Extend the binder by one.
+TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Binders must be at least 32 bytes.
+TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
+ SetupForResume();
+
+ client_->SetPacketFilter(new TlsPreSharedKeyReplacer(
+ [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Duplicate the identity and binder. This will fail with an error
+// processing the binder (because we extended the identity list.)
+TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->identities_.push_back(r->identities_[0]);
+ r->binders_.push_back(r->binders_[0]);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// The next two tests have mismatches in the number of identities
+// and binders. This generates an illegal parameter alert.
+TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->identities_.push_back(r->identities_[0]);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
+ SetupForResume();
+
+ client_->SetPacketFilter(new TlsPreSharedKeyReplacer([](
+ TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
+ SetupForResume();
+
+ const uint8_t empty_buf[] = {0};
+ DataBuffer empty(empty_buf, 0);
+ client_->SetPacketFilter(
+ // Inject an unused extension.
+ new TlsExtensionAppender(0xffff, empty));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
+ SetupForResume();
+
+ DataBuffer empty;
+ client_->SetPacketFilter(
+ new TlsExtensionDropper(ssl_tls13_psk_key_exchange_modes_xtn));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
+}
+
+// The following test contains valid but unacceptable PreSharedKey
+// modes and therefore produces non-resumption followed by MAC
+// errors.
+TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
+ SetupForResume();
+ const static uint8_t ke_modes[] = {1, // Length
+ kTls13PskKe};
+
+ DataBuffer modes(ke_modes, sizeof(ke_modes));
+ client_->SetPacketFilter(
+ new TlsExtensionReplacer(ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ auto capture = new TlsExtensionCapture(ssl_tls13_psk_key_exchange_modes_xtn);
+ client_->SetPacketFilter(capture);
+ Connect();
+ EXPECT_FALSE(capture->captured());
+}
+
+// In these tests, we downgrade to TLS 1.2, causing the
+// server to negotiate TLS 1.2.
+// 1. Both sides only support TLS 1.3, so we get a cipher version
+// error.
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) {
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+// 2. Server supports 1.2 and 1.3, client supports 1.2, so we
+// can't negotiate any ciphers.
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) {
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// 3. Server supports 1.2 and 1.3, client supports 1.2 and 1.3
+// but advertises 1.2 (because we changed things).
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+#ifndef TLS_1_3_DRAFT_VERSION
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+#else
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+#endif
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) {
+ HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn,
+ SSL_ERROR_MISSING_EXTENSION_ALERT,
+ SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) {
+ HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT,
+ SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) {
+ HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn,
+ SSL_ERROR_MISSING_EXTENSION_ALERT,
+ SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
+}
+
+TEST_P(TlsExtensionTest13, EmptyVersionList) {
+ static const uint8_t ext[] = {0x00, 0x00};
+ ConnectWithBogusVersionList(ext, sizeof(ext));
+}
+
+TEST_P(TlsExtensionTest13, OddVersionList) {
+ static const uint8_t ext[] = {0x00, 0x01, 0x00};
+ ConnectWithBogusVersionList(ext, sizeof(ext));
+}
+
+INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(ExtensionDatagram, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11Plus));
+INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls,
+ TlsConnectTestBase::kTlsV11Plus);
+
+INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11V12));
+
+INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13,
+ TlsConnectTestBase::kTlsModesAll);
+
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
new file mode 100644
index 0000000..d144cd7
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -0,0 +1,223 @@
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "blapi.h"
+#include "ssl.h"
+#include "sslimpl.h"
+#include "tls_connect.h"
+
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+#ifdef UNSAFE_FUZZER_MODE
+
+const uint8_t kShortEmptyFinished[8] = {0};
+const uint8_t kLongEmptyFinished[128] = {0};
+
+class TlsFuzzTest : public ::testing::Test {};
+
+// Record the application data stream.
+class TlsApplicationDataRecorder : public TlsRecordFilter {
+ public:
+ TlsApplicationDataRecorder() : buffer_() {}
+
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.content_type() == kTlsApplicationDataType) {
+ buffer_.Append(input);
+ }
+
+ return KEEP;
+ }
+
+ const DataBuffer& buffer() const { return buffer_; }
+
+ private:
+ DataBuffer buffer_;
+};
+
+// Damages an SKE or CV signature.
+class TlsSignatureDamager : public TlsHandshakeFilter {
+ public:
+ TlsSignatureDamager(uint8_t type) : type_(type) {}
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() != type_) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ // Modify the last byte of the signature.
+ output->data()[output->len() - 1]++;
+ return CHANGE;
+ }
+
+ private:
+ uint8_t type_;
+};
+
+void ResetState() {
+ // Clear the list of RSA blinding params.
+ BL_Cleanup();
+
+ // Reinit the list of RSA blinding params.
+ EXPECT_EQ(SECSuccess, BL_Init());
+
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_ResetForFuzzing());
+}
+
+// Ensure that ssl_Time() returns a constant value.
+TEST_F(TlsFuzzTest, Fuzz_SSL_Time_Constant) {
+ PRInt32 now = ssl_Time();
+ PR_Sleep(PR_SecondsToInterval(2));
+ EXPECT_EQ(ssl_Time(), now);
+}
+
+// Check that due to the deterministic PRNG we derive
+// the same master secret in two consecutive TLS sessions.
+TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) {
+ const char kLabel[] = "label";
+ std::vector<unsigned char> out1(32), out2(32);
+
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ DisableECDHEServerKeyReuse();
+
+ ResetState();
+ Connect();
+
+ // Export a key derived from the MS and nonces.
+ SECStatus rv =
+ SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel), false,
+ NULL, 0, out1.data(), out1.size());
+ EXPECT_EQ(SECSuccess, rv);
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ DisableECDHEServerKeyReuse();
+
+ ResetState();
+ Connect();
+
+ // Export another key derived from the MS and nonces.
+ rv = SSL_ExportKeyingMaterial(client_->ssl_fd(), kLabel, strlen(kLabel),
+ false, NULL, 0, out2.data(), out2.size());
+ EXPECT_EQ(SECSuccess, rv);
+
+ // The two exported keys should be the same.
+ EXPECT_EQ(out1, out2);
+}
+
+// Check that due to the deterministic RNG two consecutive
+// TLS sessions will have the exact same transcript.
+TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) {
+ // Connect a few times and compare the transcripts byte-by-byte.
+ DataBuffer last;
+ for (size_t i = 0; i < 5; i++) {
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ DisableECDHEServerKeyReuse();
+
+ DataBuffer buffer;
+ client_->SetPacketFilter(new TlsConversationRecorder(buffer));
+ server_->SetPacketFilter(new TlsConversationRecorder(buffer));
+
+ ResetState();
+ Connect();
+
+ // Ensure the filters go away before |buffer| does.
+ client_->SetPacketFilter(nullptr);
+ server_->SetPacketFilter(nullptr);
+
+ if (last.len() > 0) {
+ EXPECT_EQ(last, buffer);
+ }
+
+ last = buffer;
+ }
+}
+
+// Check that we can establish and use a connection
+// with all supported TLS versions, STREAM and DGRAM.
+// Check that records are NOT encrypted.
+// Check that records don't have a MAC.
+TEST_P(TlsConnectGeneric, Fuzz_ConnectSendReceive_NullCipher) {
+ EnsureTlsSetup();
+
+ // Set up app data filters.
+ auto client_recorder = new TlsApplicationDataRecorder();
+ client_->SetPacketFilter(client_recorder);
+ auto server_recorder = new TlsApplicationDataRecorder();
+ server_->SetPacketFilter(server_recorder);
+
+ Connect();
+
+ // Construct the plaintext.
+ DataBuffer buf;
+ buf.Allocate(50);
+ for (size_t i = 0; i < buf.len(); ++i) {
+ buf.data()[i] = i & 0xff;
+ }
+
+ // Send/Receive data.
+ client_->SendBuffer(buf);
+ server_->SendBuffer(buf);
+ Receive(buf.len());
+
+ // Check for plaintext on the wire.
+ EXPECT_EQ(buf, client_recorder->buffer());
+ EXPECT_EQ(buf, server_recorder->buffer());
+}
+
+// Check that an invalid Finished message doesn't abort the connection.
+TEST_P(TlsConnectGeneric, Fuzz_BogusClientFinished) {
+ EnsureTlsSetup();
+
+ auto i1 = new TlsInspectorReplaceHandshakeMessage(
+ kTlsHandshakeFinished,
+ DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished)));
+ client_->SetPacketFilter(i1);
+ Connect();
+ SendReceive();
+}
+
+// Check that an invalid Finished message doesn't abort the connection.
+TEST_P(TlsConnectGeneric, Fuzz_BogusServerFinished) {
+ EnsureTlsSetup();
+
+ auto i1 = new TlsInspectorReplaceHandshakeMessage(
+ kTlsHandshakeFinished,
+ DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished)));
+ server_->SetPacketFilter(i1);
+ Connect();
+ SendReceive();
+}
+
+// Check that an invalid server auth signature doesn't abort the connection.
+TEST_P(TlsConnectGeneric, Fuzz_BogusServerAuthSignature) {
+ EnsureTlsSetup();
+ uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3
+ ? kTlsHandshakeCertificateVerify
+ : kTlsHandshakeServerKeyExchange;
+ server_->SetPacketFilter(new TlsSignatureDamager(msg_type));
+ Connect();
+ SendReceive();
+}
+
+// Check that an invalid client auth signature doesn't abort the connection.
+TEST_P(TlsConnectGeneric, Fuzz_BogusClientAuthSignature) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->SetPacketFilter(
+ new TlsSignatureDamager(kTlsHandshakeCertificateVerify));
+ Connect();
+}
+
+#endif
+}
diff --git a/nss/gtests/ssl_gtest/ssl_gtest.cc b/nss/gtests/ssl_gtest/ssl_gtest.cc
new file mode 100644
index 0000000..2d08dd8
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_gtest.cc
@@ -0,0 +1,44 @@
+#include "nspr.h"
+#include "nss.h"
+#include "prenv.h"
+#include "ssl.h"
+
+#include <cstdlib>
+
+#include "test_io.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+std::string g_working_dir_path;
+bool g_ssl_gtest_verbose;
+
+int main(int argc, char** argv) {
+ // Start the tests
+ ::testing::InitGoogleTest(&argc, argv);
+ g_working_dir_path = ".";
+ g_ssl_gtest_verbose = false;
+
+ char* workdir = PR_GetEnvSecure("NSS_GTEST_WORKDIR");
+ if (workdir) g_working_dir_path = workdir;
+
+ for (int i = 0; i < argc; i++) {
+ if (!strcmp(argv[i], "-d")) {
+ g_working_dir_path = argv[i + 1];
+ ++i;
+ } else if (!strcmp(argv[i], "-v")) {
+ g_ssl_gtest_verbose = true;
+ }
+ }
+
+ NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB,
+ NSS_INIT_READONLY);
+ NSS_SetDomesticPolicy();
+ int rv = RUN_ALL_TESTS();
+
+ NSS_Shutdown();
+
+ nss_test::Poller::Shutdown();
+
+ return rv;
+}
diff --git a/nss/gtests/ssl_gtest/ssl_gtest.gyp b/nss/gtests/ssl_gtest/ssl_gtest.gyp
new file mode 100644
index 0000000..e232a8b
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_gtest.gyp
@@ -0,0 +1,101 @@
+# This Source Code Form is subject to the terms of the Mozilla Public
+# License, v. 2.0. If a copy of the MPL was not distributed with this
+# file, You can obtain one at http://mozilla.org/MPL/2.0/.
+{
+ 'includes': [
+ '../../coreconf/config.gypi',
+ '../common/gtest.gypi',
+ ],
+ 'targets': [
+ {
+ 'target_name': 'ssl_gtest',
+ 'type': 'executable',
+ 'sources': [
+ 'libssl_internals.c',
+ 'ssl_0rtt_unittest.cc',
+ 'ssl_agent_unittest.cc',
+ 'ssl_auth_unittest.cc',
+ 'ssl_cert_ext_unittest.cc',
+ 'ssl_ciphersuite_unittest.cc',
+ 'ssl_damage_unittest.cc',
+ 'ssl_dhe_unittest.cc',
+ 'ssl_drop_unittest.cc',
+ 'ssl_ecdh_unittest.cc',
+ 'ssl_ems_unittest.cc',
+ 'ssl_exporter_unittest.cc',
+ 'ssl_extension_unittest.cc',
+ 'ssl_fuzz_unittest.cc',
+ 'ssl_gtest.cc',
+ 'ssl_hrr_unittest.cc',
+ 'ssl_loopback_unittest.cc',
+ 'ssl_record_unittest.cc',
+ 'ssl_resumption_unittest.cc',
+ 'ssl_skip_unittest.cc',
+ 'ssl_staticrsa_unittest.cc',
+ 'ssl_v2_client_hello_unittest.cc',
+ 'ssl_version_unittest.cc',
+ 'test_io.cc',
+ 'tls_agent.cc',
+ 'tls_connect.cc',
+ 'tls_filter.cc',
+ 'tls_hkdf_unittest.cc',
+ 'tls_parser.cc'
+ ],
+ 'dependencies': [
+ '<(DEPTH)/exports.gyp:nss_exports',
+ '<(DEPTH)/lib/util/util.gyp:nssutil3',
+ '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3',
+ '<(DEPTH)/gtests/google_test/google_test.gyp:gtest',
+ '<(DEPTH)/lib/softoken/softoken.gyp:softokn',
+ '<(DEPTH)/lib/smime/smime.gyp:smime',
+ '<(DEPTH)/lib/ssl/ssl.gyp:ssl',
+ '<(DEPTH)/lib/nss/nss.gyp:nss_static',
+ '<(DEPTH)/cmd/lib/lib.gyp:sectool',
+ '<(DEPTH)/lib/pkcs12/pkcs12.gyp:pkcs12',
+ '<(DEPTH)/lib/pkcs7/pkcs7.gyp:pkcs7',
+ '<(DEPTH)/lib/certhigh/certhigh.gyp:certhi',
+ '<(DEPTH)/lib/cryptohi/cryptohi.gyp:cryptohi',
+ '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap',
+ '<(DEPTH)/lib/softoken/softoken.gyp:softokn',
+ '<(DEPTH)/lib/certdb/certdb.gyp:certdb',
+ '<(DEPTH)/lib/pki/pki.gyp:nsspki',
+ '<(DEPTH)/lib/dev/dev.gyp:nssdev',
+ '<(DEPTH)/lib/base/base.gyp:nssb',
+ '<(DEPTH)/lib/freebl/freebl.gyp:<(freebl_name)',
+ '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib'
+ ],
+ 'conditions': [
+ [ 'disable_dbm==0', {
+ 'dependencies': [
+ '<(DEPTH)/lib/dbm/src/src.gyp:dbm',
+ ],
+ }],
+ [ 'disable_libpkix==0', {
+ 'dependencies': [
+ '<(DEPTH)/lib/libpkix/pkix/certsel/certsel.gyp:pkixcertsel',
+ '<(DEPTH)/lib/libpkix/pkix/checker/checker.gyp:pkixchecker',
+ '<(DEPTH)/lib/libpkix/pkix/crlsel/crlsel.gyp:pkixcrlsel',
+ '<(DEPTH)/lib/libpkix/pkix/params/params.gyp:pkixparams',
+ '<(DEPTH)/lib/libpkix/pkix/results/results.gyp:pkixresults',
+ '<(DEPTH)/lib/libpkix/pkix/store/store.gyp:pkixstore',
+ '<(DEPTH)/lib/libpkix/pkix/top/top.gyp:pkixtop',
+ '<(DEPTH)/lib/libpkix/pkix/util/util.gyp:pkixutil',
+ '<(DEPTH)/lib/libpkix/pkix_pl_nss/system/system.gyp:pkixsystem',
+ '<(DEPTH)/lib/libpkix/pkix_pl_nss/module/module.gyp:pkixmodule',
+ '<(DEPTH)/lib/libpkix/pkix_pl_nss/pki/pki.gyp:pkixpki',
+ ],
+ }],
+ ],
+ }
+ ],
+ 'target_defaults': {
+ 'include_dirs': [
+ '../../gtests/google_test/gtest/include',
+ '../../gtests/common',
+ '../../lib/ssl'
+ ],
+ },
+ 'variables': {
+ 'module': 'nss',
+ }
+}
diff --git a/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
new file mode 100644
index 0000000..5d670fa
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -0,0 +1,285 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+// This is internal, just to get TLS_1_3_DRAFT_VERSION.
+#include "ssl3prot.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
+ const char* k0RttData = "Such is life";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
+ SetupForZeroRtt(); // initial handshake as normal
+
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ // Send first ClientHello and send 0-RTT data
+ auto capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn);
+ client_->SetPacketFilter(capture_early_data);
+ client_->Handshake();
+ EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData,
+ k0RttDataLen)); // 0-RTT write.
+ EXPECT_TRUE(capture_early_data->captured());
+
+ // Send the HelloRetryRequest
+ auto hrr_capture =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ server_->SetPacketFilter(hrr_capture);
+ server_->Handshake();
+ EXPECT_LT(0U, hrr_capture->buffer().len());
+
+ // The server can't read
+ std::vector<uint8_t> buf(k0RttDataLen);
+ EXPECT_EQ(SECFailure, PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen));
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Make a new capture for the early data.
+ capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn);
+ client_->SetPacketFilter(capture_early_data);
+
+ // Complete the handshake successfully
+ Handshake();
+ ExpectEarlyDataAccepted(false); // The server should reject 0-RTT
+ CheckConnected();
+ SendReceive();
+ EXPECT_FALSE(capture_early_data->captured());
+}
+
+class KeyShareReplayer : public TlsExtensionFilter {
+ public:
+ KeyShareReplayer() {}
+
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_key_share_xtn) {
+ return KEEP;
+ }
+
+ if (!data_.len()) {
+ data_ = input;
+ return KEEP;
+ }
+
+ *output = data_;
+ return CHANGE;
+ }
+
+ private:
+ DataBuffer data_;
+};
+
+// This forces a HelloRetryRequest by disabling P-256 on the server. However,
+// the second ClientHello is modified so that it omits the requested share. The
+// server should reject this.
+TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
+ EnsureTlsSetup();
+ client_->SetPacketFilter(new KeyShareReplayer());
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
+}
+
+// This tests that the second attempt at sending a ClientHello (after receiving
+// a HelloRetryRequest) is correctly retransmitted.
+TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(groups);
+ server_->SetPacketFilter(new SelectiveDropFilter(0x2));
+ Connect();
+}
+
+class TlsKeyExchange13 : public TlsKeyExchangeTest {};
+
+// This should work, with an HRR, because the server prefers x25519 and the
+// client generates a share for P-384 on the initial ClientHello.
+TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrr) {
+ EnsureKeyShareSetup();
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ Connect();
+ CheckKeys();
+ static const std::vector<SSLNamedGroup> expectedShares = {
+ ssl_grp_ec_secp384r1};
+ CheckKEXDetails(client_groups, expectedShares, ssl_grp_ec_curve25519);
+}
+
+// This should work, but not use HRR because the key share for x25519 was
+// pre-generated by the client.
+TEST_P(TlsKeyExchange13, ConnectEcdhePreferenceMismatchHrrExtraShares) {
+ EnsureKeyShareSetup();
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ EXPECT_EQ(SECSuccess, SSL_SendAdditionalKeyShares(client_->ssl_fd(), 1));
+
+ Connect();
+ CheckKeys();
+ CheckKEXDetails(client_groups, client_groups);
+}
+
+TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
+ EnsureTlsSetup();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp256r1, ssl_grp_ec_secp521r1};
+ client_->ConfigNamedGroups(client_groups);
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_secp521r1};
+ server_->ConfigNamedGroups(server_groups);
+ client_->StartConnect();
+ server_->StartConnect();
+
+ client_->Handshake();
+ server_->Handshake();
+
+ // Here we replace the TLS server with one that does TLS 1.2 only.
+ // This will happily send the client a TLS 1.2 ServerHello.
+ TlsAgent* replacement_server =
+ new TlsAgent(server_->name(), TlsAgent::SERVER, mode_);
+ delete server_;
+ server_ = replacement_server;
+ server_->Init();
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->StartConnect();
+ Handshake();
+ EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, server_->error_code());
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+class HelloRetryRequestAgentTest : public TlsAgentTestClient {
+ protected:
+ void SetUp() override {
+ TlsAgentTestClient::SetUp();
+ EnsureInit();
+ agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ agent_->StartConnect();
+ }
+
+ void MakeCannedHrr(const uint8_t* body, size_t len, DataBuffer* hrr_record,
+ uint32_t seq_num = 0) const {
+ DataBuffer hrr_data;
+ hrr_data.Allocate(len + 4);
+ size_t i = 0;
+ i = hrr_data.Write(i, 0x7f00 | TLS_1_3_DRAFT_VERSION, 2);
+ i = hrr_data.Write(i, static_cast<uint32_t>(len), 2);
+ if (len) {
+ hrr_data.Write(i, body, len);
+ }
+ DataBuffer hrr;
+ MakeHandshakeMessage(kTlsHandshakeHelloRetryRequest, hrr_data.data(),
+ hrr_data.len(), &hrr, seq_num);
+ MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3, hrr.data(),
+ hrr.len(), hrr_record, seq_num);
+ }
+
+ void MakeGroupHrr(SSLNamedGroup group, DataBuffer* hrr_record,
+ uint32_t seq_num = 0) const {
+ const uint8_t group_hrr[] = {
+ static_cast<uint8_t>(ssl_tls13_key_share_xtn >> 8),
+ static_cast<uint8_t>(ssl_tls13_key_share_xtn),
+ 0,
+ 2, // length of key share extension
+ static_cast<uint8_t>(group >> 8),
+ static_cast<uint8_t>(group)};
+ MakeCannedHrr(group_hrr, sizeof(group_hrr), hrr_record, seq_num);
+ }
+};
+
+// Send two HelloRetryRequest messages in response to the ClientHello. The are
+// constructed to appear legitimate by asking for a new share in each, so that
+// the client has to count to work out that the server is being unreasonable.
+TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) {
+ DataBuffer hrr;
+ MakeGroupHrr(ssl_grp_ec_secp384r1, &hrr, 0);
+ ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
+ MakeGroupHrr(ssl_grp_ec_secp521r1, &hrr, 1);
+ ProcessMessage(hrr, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST);
+}
+
+// Here the client receives a HelloRetryRequest with a group that they already
+// provided a share for.
+TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) {
+ DataBuffer hrr;
+ MakeGroupHrr(ssl_grp_ec_curve25519, &hrr);
+ ProcessMessage(hrr, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
+}
+
+TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) {
+ DataBuffer hrr;
+ MakeCannedHrr(nullptr, 0U, &hrr);
+ ProcessMessage(hrr, TlsAgent::STATE_ERROR,
+ SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
+}
+
+TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) {
+ const uint8_t canned_cookie_hrr[] = {
+ static_cast<uint8_t>(ssl_tls13_cookie_xtn >> 8),
+ static_cast<uint8_t>(ssl_tls13_cookie_xtn),
+ 0,
+ 5, // length of cookie extension
+ 0,
+ 3, // cookie value length
+ 0xc0,
+ 0x0c,
+ 0x13};
+ DataBuffer hrr;
+ MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr);
+ TlsExtensionCapture* capture = new TlsExtensionCapture(ssl_tls13_cookie_xtn);
+ agent_->SetPacketFilter(capture);
+ ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
+ const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len
+ DataBuffer cookie(canned_cookie_hrr + cookie_pos,
+ sizeof(canned_cookie_hrr) - cookie_pos);
+ EXPECT_EQ(cookie, capture->extension());
+}
+
+INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest,
+ TlsConnectTestBase::kTlsModesAll);
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV13));
+#endif
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
new file mode 100644
index 0000000..65c0ca1
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -0,0 +1,274 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectGeneric, SetupOnly) {}
+
+TEST_P(TlsConnectGeneric, Connect) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Connect();
+ CheckKeys();
+}
+
+TEST_P(TlsConnectGeneric, ConnectEcdsa) {
+ SetExpectedVersion(std::get<1>(GetParam()));
+ Reset(TlsAgent::kServerEcdsa256);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+}
+
+TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
+ EnsureTlsSetup();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
+ server_->EnableSingleCipher(TLS_AES_256_GCM_SHA384);
+ } else {
+ client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
+ }
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
+ client_->EnableFalseStart();
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpn) {
+ EnableAlpn();
+ Connect();
+ CheckAlpn("a");
+}
+
+TEST_P(TlsConnectGeneric, ConnectAlpnClone) {
+ EnsureModelSockets();
+ client_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ server_model_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ Connect();
+ CheckAlpn("a");
+}
+
+TEST_P(TlsConnectDatagram, ConnectSrtp) {
+ EnableSrtp();
+ Connect();
+ CheckSrtp();
+ SendReceive();
+}
+
+// 1.3 is disabled in the next few tests because we don't
+// presently support resumption in 1.3.
+TEST_P(TlsConnectStreamPre13, ConnectAndClientRenegotiate) {
+ Connect();
+ server_->PrepareForRenegotiate();
+ client_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectStreamPre13, ConnectAndServerRenegotiate) {
+ Connect();
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+ CheckConnected();
+}
+
+TEST_P(TlsConnectGeneric, ConnectSendReceive) {
+ Connect();
+ SendReceive();
+}
+
+// The next two tests takes advantage of the fact that we
+// automatically read the first 1024 bytes, so if
+// we provide 1200 bytes, they overrun the read buffer
+// provided by the calling test.
+
+// DTLS should return an error.
+TEST_P(TlsConnectDatagram, ShortRead) {
+ Connect();
+ client_->ExpectReadWriteError();
+ server_->SendData(1200, 1200);
+ client_->WaitForErrorCode(SSL_ERROR_RX_SHORT_DTLS_READ, 2000);
+
+ // Now send and receive another packet.
+ server_->ResetSentBytes(); // Reset the counter.
+ SendReceive();
+}
+
+// TLS should get the write in two chunks.
+TEST_P(TlsConnectStream, ShortRead) {
+ // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting,
+ // so skip in that case.
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) return;
+
+ Connect();
+ server_->SendData(1200, 1200);
+ // Read the first tranche.
+ WAIT_(client_->received_bytes() == 1024, 2000);
+ ASSERT_EQ(1024U, client_->received_bytes());
+ // The second tranche should now immediately be available.
+ client_->ReadBytes();
+ ASSERT_EQ(1200U, client_->received_bytes());
+}
+
+TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) {
+ EnsureTlsSetup();
+ client_->EnableCompression();
+ server_->EnableCompression();
+ Connect();
+ EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && mode_ != DGRAM,
+ client_->is_compressed());
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) {
+ Connect();
+ std::cerr << "Expiring holddown timer\n";
+ SSLInt_ForceTimerExpiry(client_->ssl_fd());
+ SSLInt_ForceTimerExpiry(server_->ssl_fd());
+ SendReceive();
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // One for send, one for receive.
+ EXPECT_EQ(2, SSLInt_CountTls13CipherSpecs(client_->ssl_fd()));
+ }
+}
+
+class TlsPreCCSHeaderInjector : public TlsRecordFilter {
+ public:
+ TlsPreCCSHeaderInjector() {}
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
+ const DataBuffer& input,
+ size_t* offset,
+ DataBuffer* output) override {
+ if (record_header.content_type() != kTlsChangeCipherSpecType) return KEEP;
+
+ std::cerr << "Injecting Finished header before CCS\n";
+ const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
+ DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
+ RecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
+ *offset = nhdr.Write(output, *offset, hhdr_buf);
+ *offset = record_header.Write(output, *offset, input);
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
+ client_->SetPacketFilter(new TlsPreCCSHeaderInjector());
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+}
+
+TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
+ server_->SetPacketFilter(new TlsPreCCSHeaderInjector());
+ client_->StartConnect();
+ server_->StartConnect();
+ Handshake();
+ EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+}
+
+TEST_P(TlsConnectTls13, UnknownAlert) {
+ Connect();
+ SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
+ 0xff); // Unknown value.
+ client_->ExpectReadWriteError();
+ client_->WaitForErrorCode(SSL_ERROR_RX_UNKNOWN_ALERT, 2000);
+}
+
+TEST_P(TlsConnectTls13, AlertWrongLevel) {
+ Connect();
+ SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
+ kTlsAlertUnexpectedMessage);
+ client_->ExpectReadWriteError();
+ client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000);
+}
+
+TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) {
+ client_->SetShortHeadersEnabled();
+ server_->SetShortHeadersEnabled();
+ client_->ExpectShortHeaders();
+ server_->ExpectShortHeaders();
+ Connect();
+}
+
+TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
+ EnsureTlsSetup();
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake(); // Send first flight.
+ client_->adapter()->CloseWrites();
+ client_->Handshake(); // This will get an error, but shouldn't crash.
+ client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE);
+}
+
+INSTANTIATE_TEST_CASE_P(GenericStream, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(
+ GenericDatagram, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
+
+INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream,
+ TlsConnectTestBase::kTlsVAll);
+INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram,
+ TlsConnectTestBase::kTlsV11Plus);
+
+INSTANTIATE_TEST_CASE_P(Pre12Stream, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10V11));
+INSTANTIATE_TEST_CASE_P(
+ Pre12Datagram, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ TlsConnectTestBase::kTlsV11));
+
+INSTANTIATE_TEST_CASE_P(Version12Only, TlsConnectTls12,
+ TlsConnectTestBase::kTlsModesAll);
+#ifndef NSS_DISABLE_TLS_1_3
+INSTANTIATE_TEST_CASE_P(Version13Only, TlsConnectTls13,
+ TlsConnectTestBase::kTlsModesAll);
+#endif
+
+INSTANTIATE_TEST_CASE_P(Pre13Stream, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(
+ Pre13Datagram, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+INSTANTIATE_TEST_CASE_P(Pre13StreamOnly, TlsConnectStreamPre13,
+ TlsConnectTestBase::kTlsV10ToV12);
+
+INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_record_unittest.cc b/nss/gtests/ssl_gtest/ssl_record_unittest.cc
new file mode 100644
index 0000000..ef81b22
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_record_unittest.cc
@@ -0,0 +1,111 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "nss.h"
+#include "ssl.h"
+#include "sslimpl.h"
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+
+namespace nss_test {
+
+const static size_t kMacSize = 20;
+
+class TlsPaddingTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<std::tuple<size_t, bool>> {
+ public:
+ TlsPaddingTest() : plaintext_len_(std::get<0>(GetParam())) {
+ size_t extra =
+ (plaintext_len_ + 1) % 16; // Bytes past a block (1 == pad len)
+ // Minimal padding.
+ pad_len_ = extra ? 16 - extra : 0;
+ if (std::get<1>(GetParam())) {
+ // Maximal padding.
+ pad_len_ += 240;
+ }
+ MakePaddedPlaintext();
+ }
+
+ // Makes a plaintext record with correct padding.
+ void MakePaddedPlaintext() {
+ EXPECT_EQ(0UL, (plaintext_len_ + pad_len_ + 1) % 16);
+ size_t i = 0;
+ plaintext_.Allocate(plaintext_len_ + pad_len_ + 1);
+ for (; i < plaintext_len_; ++i) {
+ plaintext_.Write(i, 'A', 1);
+ }
+
+ for (; i < plaintext_len_ + pad_len_ + 1; ++i) {
+ plaintext_.Write(i, pad_len_, 1);
+ }
+ }
+
+ void Unpad(bool expect_success) {
+ std::cerr << "Content length=" << plaintext_len_
+ << " padding length=" << pad_len_
+ << " total length=" << plaintext_.len() << std::endl;
+ std::cerr << "Plaintext: " << plaintext_ << std::endl;
+ sslBuffer s;
+ s.buf = const_cast<unsigned char *>(
+ static_cast<const unsigned char *>(plaintext_.data()));
+ s.len = plaintext_.len();
+ SECStatus rv = ssl_RemoveTLSCBCPadding(&s, kMacSize);
+ if (expect_success) {
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(plaintext_len_, static_cast<size_t>(s.len));
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ }
+ }
+
+ protected:
+ size_t plaintext_len_;
+ size_t pad_len_;
+ DataBuffer plaintext_;
+};
+
+TEST_P(TlsPaddingTest, Correct) {
+ if (plaintext_len_ >= kMacSize) {
+ Unpad(true);
+ } else {
+ Unpad(false);
+ }
+}
+
+TEST_P(TlsPaddingTest, PadTooLong) {
+ if (plaintext_.len() < 255) {
+ plaintext_.Write(plaintext_.len() - 1, plaintext_.len(), 1);
+ Unpad(false);
+ }
+}
+
+TEST_P(TlsPaddingTest, FirstByteOfPadWrong) {
+ if (pad_len_) {
+ plaintext_.Write(plaintext_len_, plaintext_.data()[plaintext_len_] + 1, 1);
+ Unpad(false);
+ }
+}
+
+TEST_P(TlsPaddingTest, LastByteOfPadWrong) {
+ if (pad_len_) {
+ plaintext_.Write(plaintext_.len() - 2,
+ plaintext_.data()[plaintext_.len() - 1] + 1, 1);
+ Unpad(false);
+ }
+}
+
+const static size_t kContentSizesArr[] = {
+ 1, kMacSize - 1, kMacSize, 30, 31, 32, 36, 256, 257, 287, 288};
+
+auto kContentSizes = ::testing::ValuesIn(kContentSizesArr);
+const static bool kTrueFalseArr[] = {true, false};
+auto kTrueFalse = ::testing::ValuesIn(kTrueFalseArr);
+
+INSTANTIATE_TEST_CASE_P(TlsPadding, TlsPaddingTest,
+ ::testing::Combine(kContentSizes, kTrueFalse));
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
new file mode 100644
index 0000000..cfe42cb
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -0,0 +1,582 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class TlsServerKeyExchangeEcdhe {
+ public:
+ bool Parse(const DataBuffer& buffer) {
+ TlsParser parser(buffer);
+
+ uint8_t curve_type;
+ if (!parser.Read(&curve_type)) {
+ return false;
+ }
+
+ if (curve_type != 3) { // named_curve
+ return false;
+ }
+
+ uint32_t named_curve;
+ if (!parser.Read(&named_curve, 2)) {
+ return false;
+ }
+
+ return parser.ReadVariable(&public_key_, 1);
+ }
+
+ DataBuffer public_key_;
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectResumed) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ Connect();
+
+ Reset();
+ ExpectResumption(RESUME_SESSIONID);
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) {
+ // This prefers tickets.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) {
+ // This causes no resumption because the client needs the
+ // session cache to resume even with tickets.
+ ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientServerTicketOnly) {
+ // This causes no resumption because the client needs the
+ // session cache to resume even with tickets.
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientBothServerNone) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_NONE);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientNoneServerBoth) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_BOTH);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectGenericPre13, ConnectResumeWithHigherVersion) {
+ ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+ Connect();
+
+ Reset();
+ EnsureTlsSetup();
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) {
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ClearServerCache();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ Connect();
+ SendReceive();
+}
+
+// This callback switches out the "server" cert used on the server with
+// the "client" certificate, which should be the same type.
+static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr,
+ uint32_t srvNameArrSize) {
+ bool ok = agent->ConfigServerCert("client");
+ if (!ok) return SSL_SNI_SEND_ALERT;
+
+ return 0; // first config
+};
+
+TEST_P(TlsConnectGeneric, ServerSNICertSwitch) {
+ Connect();
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd()));
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ server_->SetSniCallback(SwitchCertificates);
+
+ Connect();
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
+ CheckKeys();
+ EXPECT_FALSE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ Connect();
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd()));
+
+ Reset();
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+
+ // Because we configure an RSA certificate here, it only adds a second, unused
+ // certificate, which has no effect on what the server uses.
+ server_->SetSniCallback(SwitchCertificates);
+
+ Connect();
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
+ CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
+ EXPECT_TRUE(SECITEM_ItemsAreEqual(&cert1->derCert, &cert2->derCert));
+}
+
+// Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that
+TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
+ TlsInspectorRecordHandshakeMessage* i1 =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ server_->SetPacketFilter(i1);
+ Connect();
+ CheckKeys();
+ TlsServerKeyExchangeEcdhe dhe1;
+ EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+
+ // Restart
+ Reset();
+ TlsInspectorRecordHandshakeMessage* i2 =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ server_->SetPacketFilter(i2);
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ Connect();
+ CheckKeys();
+
+ TlsServerKeyExchangeEcdhe dhe2;
+ EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+
+ // Make sure they are the same.
+ EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len());
+ EXPECT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
+ dhe1.public_key_.len()));
+}
+
+// This test parses the ServerKeyExchange, which isn't in 1.3
+TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
+ server_->EnsureTlsSetup();
+ SECStatus rv =
+ SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ TlsInspectorRecordHandshakeMessage* i1 =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ server_->SetPacketFilter(i1);
+ Connect();
+ CheckKeys();
+ TlsServerKeyExchangeEcdhe dhe1;
+ EXPECT_TRUE(dhe1.Parse(i1->buffer()));
+
+ // Restart
+ Reset();
+ server_->EnsureTlsSetup();
+ rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ TlsInspectorRecordHandshakeMessage* i2 =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ server_->SetPacketFilter(i2);
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ Connect();
+ CheckKeys();
+
+ TlsServerKeyExchangeEcdhe dhe2;
+ EXPECT_TRUE(dhe2.Parse(i2->buffer()));
+
+ // Make sure they are different.
+ EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) &&
+ (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(),
+ dhe1.public_key_.len())));
+}
+
+// Verify that TLS 1.3 reports an accurate group on resumption.
+TEST_P(TlsConnectTls13, TestTls13ResumeDifferentGroup) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ client_->ConfigNamedGroups(kFFDHEGroups);
+ server_->ConfigNamedGroups(kFFDHEGroups);
+ Connect();
+ CheckKeys(ssl_kea_dh, ssl_grp_ffdhe_2048, ssl_auth_rsa_sign, ssl_sig_none);
+}
+
+// We need to enable different cipher suites at different times in the following
+// tests. Those cipher suites need to be suited to the version.
+static uint16_t ChooseOneCipher(uint16_t version) {
+ if (version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return TLS_AES_128_GCM_SHA256;
+ }
+ return TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA;
+}
+
+static uint16_t ChooseAnotherCipher(uint16_t version) {
+ if (version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return TLS_AES_256_GCM_SHA384;
+ }
+ return TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA;
+}
+
+// Test that we don't resume when we can't negotiate the same cipher.
+TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ client_->EnableSingleCipher(ChooseAnotherCipher(version_));
+ uint16_t ticket_extension;
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ticket_extension = ssl_tls13_pre_shared_key_xtn;
+ } else {
+ ticket_extension = ssl_session_ticket_xtn;
+ }
+ auto ticket_capture = new TlsExtensionCapture(ticket_extension);
+ client_->SetPacketFilter(ticket_capture);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+ EXPECT_EQ(0U, ticket_capture->extension().len());
+}
+
+// Test that we don't resume when we can't negotiate the same cipher.
+TEST_P(TlsConnectGeneric, TestResumeServerDifferentCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+ server_->EnableSingleCipher(ChooseAnotherCipher(version_));
+ Connect();
+ CheckKeys();
+}
+
+class SelectedCipherSuiteReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedCipherSuiteReplacer(uint16_t suite) : cipher_suite_(suite) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (header.handshake_type() != kTlsHandshakeServerHello) {
+ return KEEP;
+ }
+
+ *output = input;
+ uint32_t temp = 0;
+ EXPECT_TRUE(input.Read(0, 2, &temp));
+ // Cipher suite is after version(2) and random(32).
+ size_t pos = 34;
+ if (temp < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In old versions, we have to skip a session_id too.
+ EXPECT_TRUE(input.Read(pos, 1, &temp));
+ pos += 1 + temp;
+ }
+ output->Write(pos, static_cast<uint32_t>(cipher_suite_), 2);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t cipher_suite_;
+};
+
+// Test that the client doesn't tolerate the server picking a different cipher
+// suite for resumption.
+TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->EnableSingleCipher(ChooseOneCipher(version_));
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->SetPacketFilter(
+ new SelectedCipherSuiteReplacer(ChooseAnotherCipher(version_)));
+
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // The reason this test is stream only: the server is unable to decrypt
+ // the alert that the client sends, see bug 1304603.
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+ }
+}
+
+class SelectedVersionReplacer : public TlsHandshakeFilter {
+ public:
+ SelectedVersionReplacer(uint16_t version) : version_(version) {}
+
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (header.handshake_type() != kTlsHandshakeServerHello) {
+ return KEEP;
+ }
+
+ *output = input;
+ output->Write(0, static_cast<uint32_t>(version_), 2);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t version_;
+};
+
+// Test how the client handles the case where the server picks a
+// lower version number on resumption.
+TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
+ uint16_t override_version = 0;
+ if (mode_ == STREAM) {
+ switch (version_) {
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return; // Skip the test.
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ override_version = SSL_LIBRARY_VERSION_TLS_1_0;
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ override_version = SSL_LIBRARY_VERSION_TLS_1_1;
+ break;
+ default:
+ ASSERT_TRUE(false) << "unknown version";
+ }
+ } else {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_2) {
+ override_version = SSL_LIBRARY_VERSION_DTLS_1_0_WIRE;
+ } else {
+ ASSERT_EQ(SSL_LIBRARY_VERSION_TLS_1_1, version_);
+ return; // Skip the test.
+ }
+ }
+
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ // Need to use a cipher that is plausible for the lower version.
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ Connect();
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ // Enable the lower version on the client.
+ client_->SetVersionRange(version_ - 1, version_);
+ server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ server_->SetPacketFilter(new SelectedVersionReplacer(override_version));
+
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
+}
+
+// Test that two TLS resumptions work and produce the same ticket.
+// This will change after bug 1257047 is fixed.
+TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+ uint16_t original_suite;
+ EXPECT_TRUE(client_->cipher_suite(&original_suite));
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+ TlsExtensionCapture* c1 =
+ new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ client_->SetPacketFilter(c1);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_none);
+ // The filter will go away when we reset, so save the captured extension.
+ DataBuffer initialTicket(c1->extension());
+ ASSERT_LT(0U, initialTicket.len());
+
+ ScopedCERTCertificate cert1(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_TRUE(!!cert1.get());
+
+ Reset();
+ ClearStats();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ TlsExtensionCapture* c2 =
+ new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ client_->SetPacketFilter(c2);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+ CheckKeys(ssl_kea_ecdh, ssl_grp_ec_curve25519, ssl_auth_rsa_sign,
+ ssl_sig_none);
+ ASSERT_LT(0U, c2->extension().len());
+
+ ScopedCERTCertificate cert2(SSL_PeerCertificate(client_->ssl_fd()));
+ ASSERT_TRUE(!!cert2.get());
+
+ // Check that the cipher suite is reported the same on both sides, though in
+ // TLS 1.3 resumption actually negotiates a different cipher suite.
+ uint16_t resumed_suite;
+ EXPECT_TRUE(server_->cipher_suite(&resumed_suite));
+ EXPECT_EQ(original_suite, resumed_suite);
+ EXPECT_TRUE(client_->cipher_suite(&resumed_suite));
+ EXPECT_EQ(original_suite, resumed_suite);
+
+ ASSERT_NE(initialTicket, c2->extension());
+}
+
+// Check that resumption works after receiving two NST messages.
+TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ // Clear the session ticket keys to invalidate the old ticket.
+ SSLInt_ClearSessionTicketKey();
+ SSLInt_SendNewSessionTicket(server_->ssl_fd());
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Resume the connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
new file mode 100644
index 0000000..523a374
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -0,0 +1,158 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "sslerr.h"
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+/*
+ * The tests in this file test that the TLS state machine is robust against
+ * attacks that alter the order of handshake messages.
+ *
+ * See <https://www.smacktls.com/smack.pdf> for a description of the problems
+ * that this sort of attack can enable.
+ */
+namespace nss_test {
+
+class TlsHandshakeSkipFilter : public TlsRecordFilter {
+ public:
+ // A TLS record filter that skips handshake messages of the identified type.
+ TlsHandshakeSkipFilter(uint8_t handshake_type)
+ : handshake_type_(handshake_type), skipped_(false) {}
+
+ protected:
+ // Takes a record; if it is a handshake record, it removes the first handshake
+ // message that is of handshake_type_ type.
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (record_header.content_type() != kTlsHandshakeType) {
+ return KEEP;
+ }
+
+ size_t output_offset = 0U;
+ output->Allocate(input.len());
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ size_t start = parser.consumed();
+ TlsHandshakeFilter::HandshakeHeader header;
+ DataBuffer ignored;
+ if (!header.Parse(&parser, record_header, &ignored)) {
+ return KEEP;
+ }
+
+ if (skipped_ || header.handshake_type() != handshake_type_) {
+ size_t entire_length = parser.consumed() - start;
+ output->Write(output_offset, input.data() + start, entire_length);
+ // DTLS sequence numbers need to be rewritten
+ if (skipped_ && header.is_dtls()) {
+ output->data()[start + 5] -= 1;
+ }
+ output_offset += entire_length;
+ } else {
+ std::cerr << "Dropping handshake: "
+ << static_cast<unsigned>(handshake_type_) << std::endl;
+ // We only need to report that the output contains changed data if we
+ // drop a handshake message. But once we've skipped one message, we
+ // have to modify all subsequent handshake messages so that they include
+ // the correct DTLS sequence numbers.
+ skipped_ = true;
+ }
+ }
+ output->Truncate(output_offset);
+ return skipped_ ? CHANGE : KEEP;
+ }
+
+ private:
+ // The type of handshake message to drop.
+ uint8_t handshake_type_;
+ // Whether this filter has ever skipped a handshake message. Track this so
+ // that sequence numbers on DTLS handshake messages can be rewritten in
+ // subsequent calls.
+ bool skipped_;
+};
+
+class TlsSkipTest
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ protected:
+ TlsSkipTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+ void ServerSkipTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertUnexpectedMessage) {
+ auto alert_recorder = new TlsAlertRecorder();
+ client_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ server_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(alert, alert_recorder->description());
+ }
+};
+
+TEST_P(TlsSkipTest, SkipCertificateRsa) {
+ EnableOnlyStaticRsaCiphers();
+ ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateDhe) {
+ ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
+ ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
+}
+
+TEST_P(TlsSkipTest, SkipServerKeyExchange) {
+ ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
+ auto chain = new ChainedPacketFilter();
+ chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(chain);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
+ Reset(TlsAgent::kServerEcdsa256);
+ auto chain = new ChainedPacketFilter();
+ chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(chain);
+ client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
+}
+
+INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10));
+INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11V12));
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
new file mode 100644
index 0000000..baf24ed
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -0,0 +1,123 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <functional>
+#include <memory>
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+const uint8_t kBogusClientKeyExchange[] = {
+ 0x01, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+ 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
+};
+
+TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) {
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+ CheckKeys(ssl_kea_rsa, ssl_grp_none, ssl_auth_rsa_decrypt, ssl_sig_none);
+}
+
+// Test that a totally bogus EPMS is handled correctly.
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
+ EnableOnlyStaticRsaCiphers();
+ TlsInspectorReplaceHandshakeMessage* i1 =
+ new TlsInspectorReplaceHandshakeMessage(
+ kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
+ client_->SetPacketFilter(i1);
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+}
+
+// Test that a PMS with a bogus version number is handled correctly.
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
+ EnableOnlyStaticRsaCiphers();
+ client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+}
+
+// Test that a PMS with a bogus version number is ignored when
+// rollback detection is disabled. This is a positive control for
+// ConnectStaticRSABogusPMSVersionDetect.
+TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
+ EnableOnlyStaticRsaCiphers();
+ client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
+ server_->DisableRollbackDetection();
+ Connect();
+}
+
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ TlsInspectorReplaceHandshakeMessage* inspect =
+ new TlsInspectorReplaceHandshakeMessage(
+ kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
+ client_->SetPacketFilter(inspect);
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+}
+
+// This test is stream so we can catch the bad_record_mac alert.
+TEST_P(TlsConnectStreamPre13,
+ ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+}
+
+TEST_P(TlsConnectStreamPre13,
+ ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
+ EnableOnlyStaticRsaCiphers();
+ EnableExtendedMasterSecret();
+ client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
+ server_->DisableRollbackDetection();
+ Connect();
+}
+
+} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
new file mode 100644
index 0000000..8b586be
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -0,0 +1,351 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "pk11pub.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+
+namespace nss_test {
+
+// Replaces the client hello with an SSLv2 version once.
+class SSLv2ClientHelloFilter : public PacketFilter {
+ public:
+ SSLv2ClientHelloFilter(TlsAgent* client, uint16_t version)
+ : replaced_(false),
+ client_(client),
+ version_(version),
+ pad_len_(0),
+ reported_pad_len_(0),
+ client_random_len_(16),
+ ciphers_(0),
+ send_escape_(false) {}
+
+ void SetVersion(uint16_t version) { version_ = version; }
+
+ void SetCipherSuites(const std::vector<uint16_t>& ciphers) {
+ ciphers_ = ciphers;
+ }
+
+ // Set a padding length and announce it correctly.
+ void SetPadding(uint8_t pad_len) { SetPadding(pad_len, pad_len); }
+
+ // Set a padding length and allow to lie about its length.
+ void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) {
+ pad_len_ = pad_len;
+ reported_pad_len_ = reported_pad_len;
+ }
+
+ void SetClientRandomLength(uint16_t client_random_len) {
+ client_random_len_ = client_random_len;
+ }
+
+ void SetSendEscape(bool send_escape) { send_escape_ = send_escape; }
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ if (replaced_) {
+ return KEEP;
+ }
+
+ // Replace only the very first packet.
+ replaced_ = true;
+
+ // The SSLv2 client hello size.
+ size_t packet_len = SSL_HL_CLIENT_HELLO_HBYTES + (ciphers_.size() * 3) +
+ client_random_len_ + pad_len_;
+
+ size_t idx = 0;
+ *output = input;
+ output->Allocate(packet_len);
+ output->Truncate(packet_len);
+
+ // Write record length.
+ if (pad_len_ > 0) {
+ size_t masked_len = 0x3fff & packet_len;
+ if (send_escape_) {
+ masked_len |= 0x4000;
+ }
+
+ idx = output->Write(idx, masked_len, 2);
+ idx = output->Write(idx, reported_pad_len_, 1);
+ } else {
+ PR_ASSERT(!send_escape_);
+ idx = output->Write(idx, 0x8000 | packet_len, 2);
+ }
+
+ // Remember header length.
+ size_t hdr_len = idx;
+
+ // Write client hello.
+ idx = output->Write(idx, SSL_MT_CLIENT_HELLO, 1);
+ idx = output->Write(idx, version_, 2);
+
+ // Cipher list length.
+ idx = output->Write(idx, (ciphers_.size() * 3), 2);
+
+ // Session ID length.
+ idx = output->Write(idx, static_cast<uint32_t>(0), 2);
+
+ // ClientRandom length.
+ idx = output->Write(idx, client_random_len_, 2);
+
+ // Cipher suites.
+ for (auto cipher : ciphers_) {
+ idx = output->Write(idx, static_cast<uint32_t>(cipher), 3);
+ }
+
+ // Challenge.
+ std::vector<uint8_t> challenge(client_random_len_);
+ PK11_GenerateRandom(challenge.data(), challenge.size());
+ idx = output->Write(idx, challenge.data(), challenge.size());
+
+ // Add padding if any.
+ if (pad_len_ > 0) {
+ std::vector<uint8_t> pad(pad_len_);
+ idx = output->Write(idx, pad.data(), pad.size());
+ }
+
+ // Update the client random so that the handshake succeeds.
+ SECStatus rv = SSLInt_UpdateSSLv2ClientRandom(
+ client_->ssl_fd(), challenge.data(), challenge.size(),
+ output->data() + hdr_len, output->len() - hdr_len);
+ EXPECT_EQ(SECSuccess, rv);
+
+ return CHANGE;
+ }
+
+ private:
+ bool replaced_;
+ TlsAgent* client_;
+ uint16_t version_;
+ uint8_t pad_len_;
+ uint8_t reported_pad_len_;
+ uint16_t client_random_len_;
+ std::vector<uint16_t> ciphers_;
+ bool send_escape_;
+};
+
+class SSLv2ClientHelloTestF : public TlsConnectTestBase {
+ public:
+ SSLv2ClientHelloTestF() : TlsConnectTestBase(STREAM, 0), filter_(nullptr) {}
+
+ SSLv2ClientHelloTestF(Mode mode, uint16_t version)
+ : TlsConnectTestBase(mode, version), filter_(nullptr) {}
+
+ void SetUp() {
+ TlsConnectTestBase::SetUp();
+ filter_ = new SSLv2ClientHelloFilter(client_, version_);
+ client_->SetPacketFilter(filter_);
+ }
+
+ void RequireSafeRenegotiation() {
+ server_->EnsureTlsSetup();
+ SECStatus rv =
+ SSL_OptionSet(server_->ssl_fd(), SSL_REQUIRE_SAFE_NEGOTIATION, PR_TRUE);
+ EXPECT_EQ(rv, SECSuccess);
+ }
+
+ void SetExpectedVersion(uint16_t version) {
+ TlsConnectTestBase::SetExpectedVersion(version);
+ filter_->SetVersion(version);
+ }
+
+ void SetAvailableCipherSuite(uint16_t cipher) {
+ filter_->SetCipherSuites(std::vector<uint16_t>(1, cipher));
+ }
+
+ void SetAvailableCipherSuites(const std::vector<uint16_t>& ciphers) {
+ filter_->SetCipherSuites(ciphers);
+ }
+
+ void SetPadding(uint8_t pad_len) { filter_->SetPadding(pad_len); }
+
+ void SetPadding(uint8_t pad_len, uint8_t reported_pad_len) {
+ filter_->SetPadding(pad_len, reported_pad_len);
+ }
+
+ void SetClientRandomLength(uint16_t client_random_len) {
+ filter_->SetClientRandomLength(client_random_len);
+ }
+
+ void SetSendEscape(bool send_escape) { filter_->SetSendEscape(send_escape); }
+
+ private:
+ SSLv2ClientHelloFilter* filter_;
+};
+
+// Parameterized version of SSLv2ClientHelloTestF we can
+// use with TEST_P to test multiple TLS versions easily.
+class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ SSLv2ClientHelloTest() : SSLv2ClientHelloTestF(STREAM, GetParam()) {}
+};
+
+// Test negotiating TLS 1.0 - 1.2.
+TEST_P(SSLv2ClientHelloTest, Connect) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ Connect();
+}
+
+// Test negotiating TLS 1.3.
+TEST_F(SSLv2ClientHelloTestF, Connect13) {
+ EnsureTlsSetup();
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+
+ std::vector<uint16_t> cipher_suites = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256};
+ SetAvailableCipherSuites(cipher_suites);
+
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Test negotiating an EC suite.
+TEST_P(SSLv2ClientHelloTest, NegotiateECSuite) {
+ SetAvailableCipherSuite(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
+ Connect();
+}
+
+// Test negotiating TLS 1.0 - 1.2 with a padded client hello.
+TEST_P(SSLv2ClientHelloTest, AddPadding) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ SetPadding(255);
+ Connect();
+}
+
+// Test that sending a security escape fails the handshake.
+TEST_P(SSLv2ClientHelloTest, SendSecurityEscape) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Send a security escape.
+ SetSendEscape(true);
+
+ // Set a big padding so that the server fails instead of timing out.
+ SetPadding(255);
+
+ ConnectExpectFail();
+}
+
+// Invalid SSLv2 client hello padding must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, AddErroneousPadding) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Append 5 bytes of padding but say it's only 4.
+ SetPadding(5, 4);
+
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Invalid SSLv2 client hello padding must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, AddErroneousPadding2) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Append 5 bytes of padding but say it's 6.
+ SetPadding(5, 6);
+
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Wrong amount of bytes for the ClientRandom must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, SmallClientRandom) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Send a ClientRandom that's too small.
+ SetClientRandomLength(15);
+
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Test sending the maximum accepted number of ClientRandom bytes.
+TEST_P(SSLv2ClientHelloTest, MaxClientRandom) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ SetClientRandomLength(32);
+ Connect();
+}
+
+// Wrong amount of bytes for the ClientRandom must fail the handshake.
+TEST_P(SSLv2ClientHelloTest, BigClientRandom) {
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+
+ // Send a ClientRandom that's too big.
+ SetClientRandomLength(33);
+
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
+}
+
+// Connection must fail if we require safe renegotiation but the client doesn't
+// include TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
+TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
+ RequireSafeRenegotiation();
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code());
+}
+
+// Connection must succeed when requiring safe renegotiation and the client
+// includes TLS_EMPTY_RENEGOTIATION_INFO_SCSV in the list of cipher suites.
+TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiationWithSCSV) {
+ RequireSafeRenegotiation();
+ std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_EMPTY_RENEGOTIATION_INFO_SCSV};
+ SetAvailableCipherSuites(cipher_suites);
+ Connect();
+}
+
+// Connect to the server with TLS 1.1, signalling that this is a fallback from
+// a higher version. As the server doesn't support anything higher than TLS 1.1
+// it must accept the connection.
+TEST_F(SSLv2ClientHelloTestF, FallbackSCSV) {
+ EnsureTlsSetup();
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+
+ std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_FALLBACK_SCSV};
+ SetAvailableCipherSuites(cipher_suites);
+ Connect();
+}
+
+// Connect to the server with TLS 1.1, signalling that this is a fallback from
+// a higher version. As the server supports TLS 1.2 though it must reject the
+// connection due to a possible downgrade attack.
+TEST_F(SSLv2ClientHelloTestF, InappropriateFallbackSCSV) {
+ SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+
+ std::vector<uint16_t> cipher_suites = {TLS_DHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_FALLBACK_SCSV};
+ SetAvailableCipherSuites(cipher_suites);
+
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT, server_->error_code());
+}
+
+INSTANTIATE_TEST_CASE_P(VersionsStream10Pre13, SSLv2ClientHelloTest,
+ TlsConnectTestBase::kTlsV10);
+INSTANTIATE_TEST_CASE_P(VersionsStreamPre13, SSLv2ClientHelloTest,
+ TlsConnectTestBase::kTlsV11V12);
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/nss/gtests/ssl_gtest/ssl_version_unittest.cc
new file mode 100644
index 0000000..b353849
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -0,0 +1,300 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+TEST_P(TlsConnectStream, ServerNegotiateTls10) {
+ uint16_t minver, maxver;
+ client_->GetVersionRange(&minver, &maxver);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, maxver);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, ServerNegotiateTls11) {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) return;
+
+ uint16_t minver, maxver;
+ client_->GetVersionRange(&minver, &maxver);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, maxver);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ Connect();
+}
+
+TEST_P(TlsConnectGeneric, ServerNegotiateTls12) {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) return;
+
+ uint16_t minver, maxver;
+ client_->GetVersionRange(&minver, &maxver);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, maxver);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+}
+#ifndef TLS_1_3_DRAFT_VERSION
+
+// Test the ServerRandom version hack from
+// [draft-ietf-tls-tls13-11 Section 6.3.1.1].
+// The first three tests test for active tampering. The next
+// two validate that we can also detect fallback using the
+// SSL_SetDowngradeCheckVersion() API.
+TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
+ client_->SetPacketFilter(
+ new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_1));
+ ConnectExpectFail();
+ ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+/* Attempt to negotiate the bogus DTLS 1.1 version. */
+TEST_F(DtlsConnectTest, TestDtlsVersion11) {
+ client_->SetPacketFilter(
+ new TlsInspectorClientHelloVersionSetter(((~0x0101) & 0xffff)));
+ ConnectExpectFail();
+ // It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is
+ // what is returned here, but this is deliberate in ssl3_HandleAlert().
+ EXPECT_EQ(SSL_ERROR_NO_CYPHER_OVERLAP, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_UNSUPPORTED_VERSION, server_->error_code());
+}
+
+// Disabled as long as we have draft version.
+TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
+ EnsureTlsSetup();
+ client_->SetPacketFilter(
+ new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_2));
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectExpectFail();
+ ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+// TLS 1.1 clients do not check the random values, so we should
+// instead get a handshake failure alert from the server.
+TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) {
+ client_->SetPacketFilter(
+ new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_0));
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ ConnectExpectFail();
+ ASSERT_EQ(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE, server_->error_code());
+ ASSERT_EQ(SSL_ERROR_DECRYPT_ERROR_ALERT, client_->error_code());
+}
+
+TEST_F(TlsConnectTest, TestFallbackFromTls12) {
+ EnsureTlsSetup();
+ client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_1);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ ConnectExpectFail();
+ ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+
+TEST_F(TlsConnectTest, TestFallbackFromTls13) {
+ EnsureTlsSetup();
+ client_->SetDowngradeCheckVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectExpectFail();
+ ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
+}
+#endif
+
+// The TLS v1.3 spec section C.4 states that 'Implementations MUST NOT send or
+// accept any records with a version less than { 3, 0 }'. Thus we will not
+// allow version ranges including both SSL v3 and TLS v1.3.
+TEST_F(TlsConnectTest, DisallowSSLv3HelloWithTLSv13Enabled) {
+ SECStatus rv;
+ SSLVersionRange vrange = {SSL_LIBRARY_VERSION_3_0,
+ SSL_LIBRARY_VERSION_TLS_1_3};
+
+ EnsureTlsSetup();
+ rv = SSL_VersionRangeSet(client_->ssl_fd(), &vrange);
+ EXPECT_EQ(SECFailure, rv);
+
+ rv = SSL_VersionRangeSet(server_->ssl_fd(), &vrange);
+ EXPECT_EQ(SECFailure, rv);
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ client_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ server_->StartRenegotiate();
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ }
+}
+
+TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_0) {
+ return;
+ }
+ // Set the client so it will accept any version from 1.0
+ // to |version_|.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, version_);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_0);
+ // Reset version so that the checks succeed.
+ uint16_t test_version = version_;
+ version_ = SSL_LIBRARY_VERSION_TLS_1_0;
+ Connect();
+
+ // Now renegotiate, with the server being set to do
+ // |version_|.
+ server_->PrepareForRenegotiate();
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0, test_version);
+ // Reset version and cipher suite so that the preinfo callback
+ // doesn't fail.
+ server_->ResetPreliminaryInfo();
+ client_->StartRenegotiate();
+ Handshake();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // In TLS 1.3, the server detects this problem.
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED);
+ } else {
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ }
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeClient) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(client_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
+ EnsureTlsSetup();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+ SECStatus rv = SSL_ReHandshake(server_->ssl_fd(), PR_TRUE);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(SSL_ERROR_RENEGOTIATION_NOT_ALLOWED, PORT_GetError());
+}
+
+TEST_P(TlsConnectGeneric, AlertBeforeServerHello) {
+ EnsureTlsSetup();
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake(); // Send ClientHello.
+ static const uint8_t kWarningAlert[] = {kTlsAlertWarning,
+ kTlsAlertUnrecognizedName};
+ DataBuffer alert;
+ TlsAgentTestBase::MakeRecord(mode_, kTlsAlertType,
+ SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert,
+ PR_ARRAY_SIZE(kWarningAlert), &alert);
+ client_->adapter()->PacketReceived(alert);
+ Handshake();
+ CheckConnected();
+}
+
+class Tls13NoSupportedVersions : public TlsConnectStreamTls12 {
+ protected:
+ void Run(uint16_t overwritten_client_version, uint16_t max_server_version) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version);
+ client_->SetPacketFilter(
+ new TlsInspectorClientHelloVersionSetter(overwritten_client_version));
+ auto capture =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello);
+ server_->SetPacketFilter(capture);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+ const DataBuffer& server_hello = capture->buffer();
+ ASSERT_GT(server_hello.len(), 2U);
+ uint32_t ver;
+ ASSERT_TRUE(server_hello.Read(0, 2, &ver));
+ ASSERT_EQ(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver);
+ }
+};
+
+// If we offer a 1.3 ClientHello w/o supported_versions, the server should
+// negotiate 1.2.
+TEST_F(Tls13NoSupportedVersions,
+ Tls13ClientHelloWithoutSupportedVersionsServer12) {
+ Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_2);
+}
+
+TEST_F(Tls13NoSupportedVersions,
+ Tls13ClientHelloWithoutSupportedVersionsServer13) {
+ Run(SSL_LIBRARY_VERSION_TLS_1_3, SSL_LIBRARY_VERSION_TLS_1_3);
+}
+
+TEST_F(Tls13NoSupportedVersions,
+ Tls14ClientHelloWithoutSupportedVersionsServer13) {
+ Run(SSL_LIBRARY_VERSION_TLS_1_3 + 1, SSL_LIBRARY_VERSION_TLS_1_3);
+}
+
+// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This
+// causes a bad MAC error when we read EncryptedExtensions.
+TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
+ client_->SetPacketFilter(new TlsInspectorClientHelloVersionSetter(
+ SSL_LIBRARY_VERSION_TLS_1_3 + 1));
+ auto capture =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello);
+ server_->SetPacketFilter(capture);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ const DataBuffer& server_hello = capture->buffer();
+ ASSERT_GT(server_hello.len(), 2U);
+ uint32_t ver;
+ ASSERT_TRUE(server_hello.Read(0, 2, &ver));
+ // This way we don't need to change with new draft version.
+ ASSERT_LT(static_cast<uint32_t>(SSL_LIBRARY_VERSION_TLS_1_2), ver);
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/test_io.cc b/nss/gtests/ssl_gtest/test_io.cc
new file mode 100644
index 0000000..f3fd0b2
--- /dev/null
+++ b/nss/gtests/ssl_gtest/test_io.cc
@@ -0,0 +1,536 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "test_io.h"
+
+#include <algorithm>
+#include <cassert>
+#include <iostream>
+#include <memory>
+
+#include "prerror.h"
+#include "prlog.h"
+#include "prthread.h"
+
+#include "databuffer.h"
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER;
+
+#define UNIMPLEMENTED() \
+ std::cerr << "Call to unimplemented function " << __FUNCTION__ << std::endl; \
+ PR_ASSERT(PR_FALSE); \
+ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
+
+#define LOG(a) std::cerr << name_ << ": " << a << std::endl
+#define LOGV(a) \
+ do { \
+ if (g_ssl_gtest_verbose) LOG(a); \
+ } while (false)
+
+class Packet : public DataBuffer {
+ public:
+ Packet(const DataBuffer &buf) : DataBuffer(buf), offset_(0) {}
+
+ void Advance(size_t delta) {
+ PR_ASSERT(offset_ + delta <= len());
+ offset_ = std::min(len(), offset_ + delta);
+ }
+
+ size_t offset() const { return offset_; }
+ size_t remaining() const { return len() - offset_; }
+
+ private:
+ size_t offset_;
+};
+
+// Implementation of NSPR methods
+static PRStatus DummyClose(PRFileDesc *f) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ f->secret = nullptr;
+ f->dtor(f);
+ delete io;
+ return PR_SUCCESS;
+}
+
+static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ return io->Read(buf, length);
+}
+
+static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ return io->Write(buf, length);
+}
+
+static int32_t DummyAvailable(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+int64_t DummyAvailable64(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummySync(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return nullptr;
+}
+
+static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyListen(PRFileDesc *f, int32_t depth) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) {
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+ io->Reset();
+ return PR_SUCCESS;
+}
+
+// This function does not support peek.
+static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen,
+ int32_t flags, PRIntervalTime to) {
+ PR_ASSERT(flags == 0);
+ if (flags != 0) {
+ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
+ return -1;
+ }
+
+ DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
+
+ if (io->mode() == DGRAM) {
+ return io->Recv(buf, buflen);
+ } else {
+ return io->Read(buf, buflen);
+ }
+}
+
+// Note: this is always nonblocking and assumes a zero timeout.
+static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount,
+ int32_t flags, PRIntervalTime to) {
+ int32_t written = DummyWrite(f, buf, amount);
+ return written;
+}
+
+static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount,
+ int32_t flags, PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount,
+ int32_t flags, const PRNetAddr *addr,
+ PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
+ PRNetAddr **raddr, void *buf, int32_t amount,
+ PRIntervalTime t) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f,
+ const void *headers, int32_t hlen,
+ PRTransmitFileFlags flags, PRIntervalTime t) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) {
+ // TODO: Modify to return unique names for each channel
+ // somehow, as opposed to always the same static address. The current
+ // implementation messes up the session cache, which is why it's off
+ // elsewhere
+ addr->inet.family = PR_AF_INET;
+ addr->inet.port = 0;
+ addr->inet.ip = 0;
+
+ return PR_SUCCESS;
+}
+
+static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) {
+ switch (opt->option) {
+ case PR_SockOpt_Nonblocking:
+ opt->value.non_blocking = PR_TRUE;
+ return PR_SUCCESS;
+ default:
+ UNIMPLEMENTED();
+ break;
+ }
+
+ return PR_FAILURE;
+}
+
+// Imitate setting socket options. These are mostly noops.
+static PRStatus DummySetsockoption(PRFileDesc *f,
+ const PRSocketOptionData *opt) {
+ switch (opt->option) {
+ case PR_SockOpt_Nonblocking:
+ return PR_SUCCESS;
+ case PR_SockOpt_NoDelay:
+ return PR_SUCCESS;
+ default:
+ UNIMPLEMENTED();
+ break;
+ }
+
+ return PR_FAILURE;
+}
+
+static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in,
+ PRTransmitFileFlags flags, PRIntervalTime to) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) {
+ UNIMPLEMENTED();
+ return PR_FAILURE;
+}
+
+static int32_t DummyReserved(PRFileDesc *f) {
+ UNIMPLEMENTED();
+ return -1;
+}
+
+DummyPrSocket::~DummyPrSocket() { Reset(); }
+
+void DummyPrSocket::SetPacketFilter(PacketFilter *filter) {
+ if (filter_) {
+ delete filter_;
+ }
+ filter_ = filter;
+}
+
+void DummyPrSocket::Reset() {
+ delete filter_;
+ if (peer_) {
+ peer_->SetPeer(nullptr);
+ peer_ = nullptr;
+ }
+ while (!input_.empty()) {
+ Packet *front = input_.front();
+ input_.pop();
+ delete front;
+ }
+}
+
+static const struct PRIOMethods DummyMethods = {
+ PR_DESC_LAYERED, DummyClose,
+ DummyRead, DummyWrite,
+ DummyAvailable, DummyAvailable64,
+ DummySync, DummySeek,
+ DummySeek64, DummyFileInfo,
+ DummyFileInfo64, DummyWritev,
+ DummyConnect, DummyAccept,
+ DummyBind, DummyListen,
+ DummyShutdown, DummyRecv,
+ DummySend, DummyRecvfrom,
+ DummySendto, DummyPoll,
+ DummyAcceptRead, DummyTransmitFile,
+ DummyGetsockname, DummyGetpeername,
+ DummyReserved, DummyReserved,
+ DummyGetsockoption, DummySetsockoption,
+ DummySendfile, DummyConnectContinue,
+ DummyReserved, DummyReserved,
+ DummyReserved, DummyReserved};
+
+PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) {
+ if (test_fd_identity == PR_INVALID_IO_LAYER) {
+ test_fd_identity = PR_GetUniqueIdentity("testtransportadapter");
+ }
+
+ PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods));
+ fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode));
+
+ return fd;
+}
+
+DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
+ return reinterpret_cast<DummyPrSocket *>(fd->secret);
+}
+
+void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
+ input_.push(new Packet(packet));
+}
+
+int32_t DummyPrSocket::Read(void *data, int32_t len) {
+ PR_ASSERT(mode_ == STREAM);
+
+ if (mode_ != STREAM) {
+ PR_SetError(PR_INVALID_METHOD_ERROR, 0);
+ return -1;
+ }
+
+ if (input_.empty()) {
+ LOGV("Read --> wouldblock " << len);
+ PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ return -1;
+ }
+
+ Packet *front = input_.front();
+ size_t to_read =
+ std::min(static_cast<size_t>(len), front->len() - front->offset());
+ memcpy(data, static_cast<const void *>(front->data() + front->offset()),
+ to_read);
+ front->Advance(to_read);
+
+ if (!front->remaining()) {
+ input_.pop();
+ delete front;
+ }
+
+ return static_cast<int32_t>(to_read);
+}
+
+int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
+ if (input_.empty()) {
+ PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
+ return -1;
+ }
+
+ Packet *front = input_.front();
+ if (static_cast<size_t>(buflen) < front->len()) {
+ PR_ASSERT(false);
+ PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
+ return -1;
+ }
+
+ size_t count = front->len();
+ memcpy(buf, front->data(), count);
+
+ input_.pop();
+ delete front;
+
+ return static_cast<int32_t>(count);
+}
+
+int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
+ if (!peer_ || !writeable_) {
+ PR_SetError(PR_IO_ERROR, 0);
+ return -1;
+ }
+
+ DataBuffer packet(static_cast<const uint8_t *>(buf),
+ static_cast<size_t>(length));
+ DataBuffer filtered;
+ PacketFilter::Action action = PacketFilter::KEEP;
+ if (filter_) {
+ action = filter_->Filter(packet, &filtered);
+ }
+ switch (action) {
+ case PacketFilter::CHANGE:
+ LOG("Original packet: " << packet);
+ LOG("Filtered packet: " << filtered);
+ peer_->PacketReceived(filtered);
+ break;
+ case PacketFilter::DROP:
+ LOG("Droppped packet: " << packet);
+ break;
+ case PacketFilter::KEEP:
+ LOGV("Packet: " << packet);
+ peer_->PacketReceived(packet);
+ break;
+ }
+ // libssl can't handle it if this reports something other than the length
+ // of what was passed in (or less, but we're not doing partial writes).
+ return static_cast<int32_t>(packet.len());
+}
+
+Poller *Poller::instance;
+
+Poller *Poller::Instance() {
+ if (!instance) instance = new Poller();
+
+ return instance;
+}
+
+void Poller::Shutdown() {
+ delete instance;
+ instance = nullptr;
+}
+
+Poller::~Poller() {
+ while (!timers_.empty()) {
+ Timer *timer = timers_.top();
+ timers_.pop();
+ delete timer;
+ }
+}
+
+void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target,
+ PollCallback cb) {
+ auto it = waiters_.find(adapter);
+ Waiter *waiter;
+
+ if (it == waiters_.end()) {
+ waiter = new Waiter(adapter);
+ } else {
+ waiter = it->second;
+ }
+
+ assert(event < TIMER_EVENT);
+ if (event >= TIMER_EVENT) return;
+
+ waiter->targets_[event] = target;
+ waiter->callbacks_[event] = cb;
+ waiters_[adapter] = waiter;
+}
+
+void Poller::Cancel(Event event, DummyPrSocket *adapter) {
+ auto it = waiters_.find(adapter);
+ Waiter *waiter;
+
+ if (it == waiters_.end()) {
+ return;
+ }
+
+ waiter = it->second;
+
+ waiter->targets_[event] = nullptr;
+ waiter->callbacks_[event] = nullptr;
+
+ // Clean up if there are no callbacks.
+ for (size_t i = 0; i < TIMER_EVENT; ++i) {
+ if (waiter->callbacks_[i]) return;
+ }
+
+ delete waiter;
+ waiters_.erase(adapter);
+}
+
+void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb,
+ Timer **timer) {
+ Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb);
+ timers_.push(t);
+ if (timer) *timer = t;
+}
+
+bool Poller::Poll() {
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Poll() waiters = " << waiters_.size()
+ << " timers = " << timers_.size() << std::endl;
+ }
+ PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT;
+ PRTime now = PR_Now();
+ bool fired = false;
+
+ // Figure out the timer for the select.
+ if (!timers_.empty()) {
+ Timer *first_timer = timers_.top();
+ if (now >= first_timer->deadline_) {
+ // Timer expired.
+ timeout = PR_INTERVAL_NO_WAIT;
+ } else {
+ timeout =
+ PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000);
+ }
+ }
+
+ for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
+ Waiter *waiter = it->second;
+
+ if (waiter->callbacks_[READABLE_EVENT]) {
+ if (waiter->io_->readable()) {
+ PollCallback callback = waiter->callbacks_[READABLE_EVENT];
+ PollTarget *target = waiter->targets_[READABLE_EVENT];
+ waiter->callbacks_[READABLE_EVENT] = nullptr;
+ waiter->targets_[READABLE_EVENT] = nullptr;
+ callback(target, READABLE_EVENT);
+ fired = true;
+ }
+ }
+ }
+
+ if (fired) timeout = PR_INTERVAL_NO_WAIT;
+
+ // Can't wait forever and also have nothing readable now.
+ if (timeout == PR_INTERVAL_NO_TIMEOUT) return false;
+
+ // Sleep.
+ if (timeout != PR_INTERVAL_NO_WAIT) {
+ PR_Sleep(timeout);
+ }
+
+ // Now process anything that timed out.
+ now = PR_Now();
+ while (!timers_.empty()) {
+ if (now < timers_.top()->deadline_) break;
+
+ Timer *timer = timers_.top();
+ timers_.pop();
+ if (timer->callback_) {
+ timer->callback_(timer->target_, TIMER_EVENT);
+ }
+ delete timer;
+ }
+
+ return true;
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/test_io.h b/nss/gtests/ssl_gtest/test_io.h
new file mode 100644
index 0000000..b78db0d
--- /dev/null
+++ b/nss/gtests/ssl_gtest/test_io.h
@@ -0,0 +1,152 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef test_io_h_
+#define test_io_h_
+
+#include <string.h>
+#include <map>
+#include <memory>
+#include <ostream>
+#include <queue>
+#include <string>
+
+#include "prio.h"
+
+namespace nss_test {
+
+class DataBuffer;
+class Packet;
+class DummyPrSocket; // Fwd decl.
+
+// Allow us to inspect a packet before it is written.
+class PacketFilter {
+ public:
+ enum Action {
+ KEEP, // keep the original packet unmodified
+ CHANGE, // change the packet to a different value
+ DROP // drop the packet
+ };
+
+ virtual ~PacketFilter() {}
+
+ // The packet filter takes input and has the option of mutating it.
+ //
+ // A filter that modifies the data places the modified data in *output and
+ // returns CHANGE. A filter that does not modify data returns LEAVE, in which
+ // case the value in *output is ignored. A Filter can return DROP, in which
+ // case the packet is dropped (and *output is ignored).
+ virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
+};
+
+enum Mode { STREAM, DGRAM };
+
+inline std::ostream& operator<<(std::ostream& os, Mode m) {
+ return os << ((m == STREAM) ? "TLS" : "DTLS");
+}
+
+class DummyPrSocket {
+ public:
+ ~DummyPrSocket();
+
+ static PRFileDesc* CreateFD(const std::string& name,
+ Mode mode); // Returns an FD.
+ static DummyPrSocket* GetAdapter(PRFileDesc* fd);
+
+ DummyPrSocket* peer() const { return peer_; }
+ void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
+ void SetPacketFilter(PacketFilter* filter);
+ // Drops peer, packet filter and any outstanding packets.
+ void Reset();
+
+ void PacketReceived(const DataBuffer& data);
+ int32_t Read(void* data, int32_t len);
+ int32_t Recv(void* buf, int32_t buflen);
+ int32_t Write(const void* buf, int32_t length);
+ void CloseWrites() { writeable_ = false; }
+
+ Mode mode() const { return mode_; }
+ bool readable() const { return !input_.empty(); }
+
+ private:
+ DummyPrSocket(const std::string& name, Mode mode)
+ : name_(name),
+ mode_(mode),
+ peer_(nullptr),
+ input_(),
+ filter_(nullptr),
+ writeable_(true) {}
+
+ const std::string name_;
+ Mode mode_;
+ DummyPrSocket* peer_;
+ std::queue<Packet*> input_;
+ PacketFilter* filter_;
+ bool writeable_;
+};
+
+// Marker interface.
+class PollTarget {};
+
+enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ };
+
+typedef void (*PollCallback)(PollTarget*, Event);
+
+class Poller {
+ public:
+ static Poller* Instance(); // Get a singleton.
+ static void Shutdown(); // Shut it down.
+
+ class Timer {
+ public:
+ Timer(PRTime deadline, PollTarget* target, PollCallback callback)
+ : deadline_(deadline), target_(target), callback_(callback) {}
+ void Cancel() { callback_ = nullptr; }
+
+ PRTime deadline_;
+ PollTarget* target_;
+ PollCallback callback_;
+ };
+
+ void Wait(Event event, DummyPrSocket* adapter, PollTarget* target,
+ PollCallback cb);
+ void Cancel(Event event, DummyPrSocket* adapter);
+ void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
+ Timer** handle);
+ bool Poll();
+
+ private:
+ Poller() : waiters_(), timers_() {}
+ ~Poller();
+
+ class Waiter {
+ public:
+ Waiter(DummyPrSocket* io) : io_(io) {
+ memset(&callbacks_[0], 0, sizeof(callbacks_));
+ }
+
+ void WaitFor(Event event, PollCallback callback);
+
+ DummyPrSocket* io_;
+ PollTarget* targets_[TIMER_EVENT];
+ PollCallback callbacks_[TIMER_EVENT];
+ };
+
+ class TimerComparator {
+ public:
+ bool operator()(const Timer* lhs, const Timer* rhs) {
+ return lhs->deadline_ > rhs->deadline_;
+ }
+ };
+
+ static Poller* instance;
+ std::map<DummyPrSocket*, Waiter*> waiters_;
+ std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_;
+};
+
+} // end of namespace
+
+#endif
diff --git a/nss/gtests/ssl_gtest/tls_agent.cc b/nss/gtests/ssl_gtest/tls_agent.cc
new file mode 100644
index 0000000..b75bba5
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_agent.cc
@@ -0,0 +1,992 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_agent.h"
+#include "databuffer.h"
+#include "keyhi.h"
+#include "pk11func.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+#include "tls_parser.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"};
+
+const std::string TlsAgent::kClient = "client"; // both sign and encrypt
+const std::string TlsAgent::kRsa2048 = "rsa2048"; // bigger
+const std::string TlsAgent::kServerRsa = "rsa"; // both sign and encrypt
+const std::string TlsAgent::kServerRsaSign = "rsa_sign";
+const std::string TlsAgent::kServerRsaPss = "rsa_pss";
+const std::string TlsAgent::kServerRsaDecrypt = "rsa_decrypt";
+const std::string TlsAgent::kServerRsaChain = "rsa_chain";
+const std::string TlsAgent::kServerEcdsa256 = "ecdsa256";
+const std::string TlsAgent::kServerEcdsa384 = "ecdsa384";
+const std::string TlsAgent::kServerEcdsa521 = "ecdsa521";
+const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
+const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
+const std::string TlsAgent::kServerDsa = "dsa";
+
+TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
+ : name_(name),
+ mode_(mode),
+ server_key_bits_(0),
+ pr_fd_(nullptr),
+ adapter_(nullptr),
+ ssl_fd_(nullptr),
+ role_(role),
+ state_(STATE_INIT),
+ timer_handle_(nullptr),
+ falsestart_enabled_(false),
+ expected_version_(0),
+ expected_cipher_suite_(0),
+ expect_resumption_(false),
+ expect_client_auth_(false),
+ can_falsestart_hook_called_(false),
+ sni_hook_called_(false),
+ auth_certificate_hook_called_(false),
+ handshake_callback_called_(false),
+ error_code_(0),
+ send_ctr_(0),
+ recv_ctr_(0),
+ expect_readwrite_error_(false),
+ handshake_callback_(),
+ auth_certificate_callback_(),
+ sni_callback_(),
+ expect_short_headers_(false) {
+ memset(&info_, 0, sizeof(info_));
+ memset(&csinfo_, 0, sizeof(csinfo_));
+ SECStatus rv = SSL_VersionRangeGetDefault(
+ mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+TlsAgent::~TlsAgent() {
+ if (adapter_) {
+ Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
+ // The adapter is closed when the FD closes.
+ }
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ }
+
+ if (pr_fd_) {
+ PR_Close(pr_fd_);
+ }
+
+ if (ssl_fd_) {
+ PR_Close(ssl_fd_);
+ }
+}
+
+void TlsAgent::SetState(State state) {
+ if (state_ == state) return;
+
+ LOG("Changing state from " << state_ << " to " << state);
+ state_ = state;
+}
+
+bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
+ const SSLExtraServerCertData* serverCertData) {
+ ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr));
+ EXPECT_NE(nullptr, cert.get());
+ if (!cert.get()) return false;
+
+ ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
+ EXPECT_NE(nullptr, pub.get());
+ if (!pub.get()) return false;
+ if (updateKeyBits) {
+ server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
+ }
+
+ ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr));
+ EXPECT_NE(nullptr, priv.get());
+ if (!priv.get()) return false;
+
+ SECStatus rv =
+ SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null);
+ EXPECT_EQ(SECFailure, rv);
+ rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData,
+ serverCertData ? sizeof(*serverCertData) : 0);
+ return rv == SECSuccess;
+}
+
+bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
+ // Don't set up twice
+ if (ssl_fd_) return true;
+
+ if (adapter_->mode() == STREAM) {
+ ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
+ } else {
+ ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
+ }
+
+ EXPECT_NE(nullptr, ssl_fd_);
+ if (!ssl_fd_) return false;
+ pr_fd_ = nullptr;
+
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ if (role_ == SERVER) {
+ EXPECT_TRUE(ConfigServerCert(name_, true));
+
+ rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ ScopedCERTCertList anchors(CERT_NewCertList());
+ rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get());
+ if (rv != SECSuccess) return false;
+ } else {
+ rv = SSL_SetURL(ssl_fd_, "server");
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
+
+ rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ return true;
+}
+
+void TlsAgent::SetupClientAuth() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ ASSERT_EQ(CLIENT, role_);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
+ reinterpret_cast<void*>(this)));
+}
+
+bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert,
+ SECKEYPrivateKey** priv) const {
+ *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
+ EXPECT_NE(nullptr, *cert);
+ if (!*cert) return false;
+
+ *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
+ EXPECT_NE(nullptr, *priv);
+ if (!*priv) return false; // Leak cert.
+
+ return true;
+}
+
+SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** cert,
+ SECKEYPrivateKey** privKey) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
+ ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
+ EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
+ if (agent->GetClientAuthCredentials(cert, privKey)) {
+ return SECSuccess;
+ }
+ return SECFailure;
+}
+
+bool TlsAgent::GetPeerChainLength(size_t* count) {
+ CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_);
+ if (!chain) return false;
+ *count = 0;
+
+ for (PRCList* cursor = PR_NEXT_LINK(&chain->list); cursor != &chain->list;
+ cursor = PR_NEXT_LINK(cursor)) {
+ CERTCertListNode* node = (CERTCertListNode*)cursor;
+ std::cerr << node->cert->subjectName << std::endl;
+ ++(*count);
+ }
+
+ CERT_DestroyCertList(chain);
+
+ return true;
+}
+
+void TlsAgent::RequestClientAuth(bool requireAuth) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ ASSERT_EQ(SERVER, role_);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
+ requireAuth ? PR_TRUE : PR_FALSE));
+
+ EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
+ ssl_fd_, &TlsAgent::ClientAuthenticated, this));
+ expect_client_auth_ = true;
+}
+
+void TlsAgent::StartConnect(PRFileDesc* model) {
+ EXPECT_TRUE(EnsureTlsSetup(model));
+
+ SECStatus rv;
+ rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::DisableAllCiphers() {
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SECStatus rv =
+ SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+// Not actually all groups, just the onece that we are actually willing
+// to use.
+const std::vector<SSLNamedGroup> kAllDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1, ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072,
+ ssl_grp_ffdhe_4096, ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+const std::vector<SSLNamedGroup> kECDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+
+const std::vector<SSLNamedGroup> kFFDHEGroups = {
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072, ssl_grp_ffdhe_4096,
+ ssl_grp_ffdhe_6144, ssl_grp_ffdhe_8192};
+
+// Defined because the big DHE groups are ridiculously slow.
+const std::vector<SSLNamedGroup> kFasterDHEGroups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048, ssl_grp_ffdhe_3072};
+
+void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo), csinfo.length);
+
+ if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableGroupsByKeyExchange(SSLKEAType kea) {
+ switch (kea) {
+ case ssl_kea_dh:
+ ConfigNamedGroups(kFFDHEGroups);
+ break;
+ case ssl_kea_ecdh:
+ ConfigNamedGroups(kECDHEGroups);
+ break;
+ default:
+ break;
+ }
+}
+
+void TlsAgent::EnableGroupsByAuthType(SSLAuthType authType) {
+ if (authType == ssl_auth_ecdh_rsa || authType == ssl_auth_ecdh_ecdsa ||
+ authType == ssl_auth_ecdsa || authType == ssl_auth_tls13_any) {
+ ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
+ SSLCipherSuiteInfo csinfo;
+
+ SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], &csinfo,
+ sizeof(csinfo));
+ ASSERT_EQ(SECSuccess, rv);
+
+ if ((csinfo.authType == authType) ||
+ (csinfo.keaType == ssl_kea_tls13_any)) {
+ rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+ }
+}
+
+void TlsAgent::EnableSingleCipher(uint16_t cipher) {
+ DisableAllCiphers();
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size());
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetSessionTicketsEnabled(bool en) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ en ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetSessionCacheEnabled(bool en) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Set0RttEnabled(bool en) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv =
+ SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetShortHeadersEnabled() {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
+ vrange_.min = minver;
+ vrange_.max = maxver;
+
+ if (ssl_fd_) {
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ }
+}
+
+void TlsAgent::GetVersionRange(uint16_t* minver, uint16_t* maxver) {
+ *minver = vrange_.min;
+ *maxver = vrange_.max;
+}
+
+void TlsAgent::SetExpectedVersion(uint16_t version) {
+ expected_version_ = version;
+}
+
+void TlsAgent::SetServerKeyBits(uint16_t bits) { server_key_bits_ = bits; }
+
+void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
+
+void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; }
+
+void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
+ size_t count) {
+ EXPECT_TRUE(EnsureTlsSetup());
+ EXPECT_LE(count, SSL_SignatureMaxCount());
+ EXPECT_EQ(SECSuccess,
+ SSL_SignatureSchemePrefSet(ssl_fd_, schemes,
+ static_cast<unsigned int>(count)));
+ EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0))
+ << "setting no schemes should fail and do nothing";
+
+ std::vector<SSLSignatureScheme> configuredSchemes(count);
+ unsigned int configuredCount;
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
+ << "get schemes, schemes is nullptr";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0],
+ &configuredCount, 0))
+ << "get schemes, too little space";
+ EXPECT_EQ(SECFailure,
+ SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr,
+ configuredSchemes.size()))
+ << "get schemes, countOut is nullptr";
+
+ EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
+ ssl_fd_, &configuredSchemes[0], &configuredCount,
+ configuredSchemes.size()));
+ // SignatureSchemePrefSet drops unsupported algorithms silently, so the
+ // number that are configured might be fewer.
+ EXPECT_LE(configuredCount, count);
+ unsigned int i = 0;
+ for (unsigned int j = 0; j < count && i < configuredCount; ++j) {
+ if (i < configuredCount && schemes[j] == configuredSchemes[i]) {
+ ++i;
+ }
+ }
+ EXPECT_EQ(i, configuredCount) << "schemes in use were all set";
+}
+
+void TlsAgent::CheckKEA(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ size_t kea_size) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(kea_type, info_.keaType);
+ if (kea_size == 0) {
+ switch (kea_group) {
+ case ssl_grp_ec_curve25519:
+ kea_size = 255;
+ break;
+ case ssl_grp_ec_secp256r1:
+ kea_size = 256;
+ break;
+ case ssl_grp_ec_secp384r1:
+ kea_size = 384;
+ break;
+ case ssl_grp_ffdhe_2048:
+ kea_size = 2048;
+ break;
+ case ssl_grp_ffdhe_3072:
+ kea_size = 3072;
+ break;
+ case ssl_grp_ffdhe_custom:
+ break;
+ default:
+ if (kea_type == ssl_kea_rsa) {
+ kea_size = server_key_bits_;
+ } else {
+ EXPECT_TRUE(false) << "need to update group sizes";
+ }
+ }
+ }
+ if (kea_group != ssl_grp_ffdhe_custom) {
+ EXPECT_EQ(kea_size, info_.keaKeyBits);
+ EXPECT_EQ(kea_group, info_.keaGroup);
+ }
+}
+
+void TlsAgent::CheckAuthType(SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ EXPECT_EQ(auth_type, info_.authType);
+ EXPECT_EQ(server_key_bits_, info_.authKeyBits);
+ if (expected_version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
+ switch (auth_type) {
+ case ssl_auth_rsa_sign:
+ sig_scheme = ssl_sig_rsa_pkcs1_sha1md5;
+ break;
+ case ssl_auth_ecdsa:
+ sig_scheme = ssl_sig_ecdsa_sha1;
+ break;
+ default:
+ break;
+ }
+ }
+ EXPECT_EQ(sig_scheme, info_.signatureScheme);
+
+ if (info_.protocolVersion >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ return;
+ }
+
+ // Check authAlgorithm, which is the old value for authType. This is a second
+ // switch
+ // statement because default label is different.
+ switch (auth_type) {
+ case ssl_auth_rsa_sign:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for RSA is always decrypt";
+ break;
+ case ssl_auth_ecdh_rsa:
+ EXPECT_EQ(ssl_auth_rsa_decrypt, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_RSA is RSA decrypt (i.e., wrong)";
+ break;
+ case ssl_auth_ecdh_ecdsa:
+ EXPECT_EQ(ssl_auth_ecdsa, csinfo_.authAlgorithm)
+ << "authAlgorithm for ECDH_ECDSA is ECDSA (i.e., wrong)";
+ break;
+ default:
+ EXPECT_EQ(auth_type, csinfo_.authAlgorithm)
+ << "authAlgorithm is (usually) the same as authType";
+ break;
+ }
+}
+
+void TlsAgent::EnableFalseStart() {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ falsestart_enabled_ = true;
+ EXPECT_EQ(SECSuccess,
+ SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
+ EXPECT_EQ(SECSuccess,
+ SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
+}
+
+void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
+
+void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
+}
+
+void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected) const {
+ SSLNextProtoState state;
+ char chosen[10];
+ unsigned int chosen_len;
+ SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
+ reinterpret_cast<unsigned char*>(chosen),
+ &chosen_len, sizeof(chosen));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_state, state);
+ if (state == SSL_NEXT_PROTO_NO_SUPPORT) {
+ EXPECT_EQ("", expected);
+ } else {
+ EXPECT_NE("", expected);
+ EXPECT_EQ(expected, std::string(chosen, chosen_len));
+ }
+}
+
+void TlsAgent::EnableSrtp() {
+ EXPECT_TRUE(EnsureTlsSetup());
+ const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
+ SRTP_AES128_CM_HMAC_SHA1_32};
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers)));
+}
+
+void TlsAgent::CheckSrtp() const {
+ uint16_t actual;
+ EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
+ EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
+}
+
+void TlsAgent::CheckErrorCode(int32_t expected) const {
+ EXPECT_EQ(STATE_ERROR, state_);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
+ ASSERT_EQ(0, error_code_);
+ WAIT_(error_code_ != 0, delay);
+ EXPECT_EQ(expected, error_code_)
+ << "Got error code " << PORT_ErrorToName(error_code_) << " expecting "
+ << PORT_ErrorToName(expected) << std::endl;
+}
+
+void TlsAgent::CheckPreliminaryInfo() {
+ SSLPreliminaryChannelInfo info;
+ EXPECT_EQ(SECSuccess,
+ SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
+ EXPECT_EQ(sizeof(info), info.length);
+ EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
+ EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);
+
+ // A version of 0 is invalid and indicates no expectation. This value is
+ // initialized to 0 so that tests that don't explicitly set an expected
+ // version can negotiate a version.
+ if (!expected_version_) {
+ expected_version_ = info.protocolVersion;
+ }
+ EXPECT_EQ(expected_version_, info.protocolVersion);
+
+ // As with the version; 0 is the null cipher suite (and also invalid).
+ if (!expected_cipher_suite_) {
+ expected_cipher_suite_ = info.cipherSuite;
+ }
+ EXPECT_EQ(expected_cipher_suite_, info.cipherSuite);
+}
+
+// Check that all the expected callbacks have been called.
+void TlsAgent::CheckCallbacks() const {
+ // If false start happens, the handshake is reported as being complete at the
+ // point that false start happens.
+ if (expect_resumption_ || !falsestart_enabled_) {
+ EXPECT_TRUE(handshake_callback_called_);
+ }
+
+ // These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
+ if (role_ == SERVER) {
+ PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn);
+ EXPECT_EQ(((!expect_resumption_ && have_sni) ||
+ expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
+ sni_hook_called_);
+ } else {
+ EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_);
+ // Note that this isn't unconditionally called, even with false start on.
+ // But the callback is only skipped if a cipher that is ridiculously weak
+ // (80 bits) is chosen. Don't test that: plan to remove bad ciphers.
+ EXPECT_EQ(falsestart_enabled_ && !expect_resumption_,
+ can_falsestart_hook_called_);
+ }
+}
+
+void TlsAgent::ResetPreliminaryInfo() {
+ expected_version_ = 0;
+ expected_cipher_suite_ = 0;
+}
+
+void TlsAgent::Connected() {
+ LOG("Handshake success");
+ CheckPreliminaryInfo();
+ CheckCallbacks();
+
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(info_), info_.length);
+
+ // Preliminary values are exposed through callbacks during the handshake.
+ // If either expected values were set or the callbacks were called, check
+ // that the final values are correct.
+ EXPECT_EQ(expected_version_, info_.protocolVersion);
+ EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite);
+
+ rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
+
+ if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_);
+ // We use one ciphersuite in each direction, plus one that's kept around
+ // by DTLS for retransmission.
+ PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2;
+ EXPECT_EQ(expected, cipherSuites);
+ if (expected != cipherSuites) {
+ SSLInt_PrintTls13CipherSpecs(ssl_fd_);
+ }
+ }
+
+ PRBool short_headers;
+ rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers);
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ((PRBool)expect_short_headers_, short_headers);
+ SetState(STATE_CONNECTED);
+}
+
+void TlsAgent::EnableExtendedMasterSecret() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv =
+ SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
+
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::CheckExtendedMasterSecret(bool expected) {
+ if (version() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = PR_TRUE;
+ }
+ ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE)
+ << "unexpected extended master secret state for " << name_;
+}
+
+void TlsAgent::CheckEarlyDataAccepted(bool expected) {
+ if (version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ expected = false;
+ }
+ ASSERT_EQ(expected, info_.earlyDataAccepted != PR_FALSE)
+ << "unexpected early data state for " << name_;
+}
+
+void TlsAgent::CheckSecretsDestroyed() {
+ ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_));
+}
+
+void TlsAgent::DisableRollbackDetection() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE);
+
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::EnableCompression() {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
+ ASSERT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version);
+ ASSERT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::Handshake() {
+ LOGV("Handshake");
+ SECStatus rv = SSL_ForceHandshake(ssl_fd_);
+ if (rv == SECSuccess) {
+ Connected();
+
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ int32_t err = PR_GetError();
+ if (err == PR_WOULD_BLOCK_ERROR) {
+ LOGV("Would have blocked");
+ if (mode_ == DGRAM) {
+ if (timer_handle_) {
+ timer_handle_->Cancel();
+ timer_handle_ = nullptr;
+ }
+
+ PRIntervalTime timeout;
+ rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
+ if (rv == SECSuccess) {
+ Poller::Instance()->SetTimer(
+ timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
+ }
+ }
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ return;
+ }
+
+ LOG("Handshake failed with error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ error_code_ = err;
+ SetState(STATE_ERROR);
+}
+
+void TlsAgent::PrepareForRenegotiate() {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+
+ SetState(STATE_CONNECTING);
+}
+
+void TlsAgent::StartRenegotiate() {
+ PrepareForRenegotiate();
+
+ SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SendDirect(const DataBuffer& buf) {
+ LOG("Send Direct " << buf);
+ adapter_->peer()->PacketReceived(buf);
+}
+
+static bool ErrorIsNonFatal(PRErrorCode code) {
+ return code == PR_WOULD_BLOCK_ERROR || code == SSL_ERROR_RX_SHORT_DTLS_READ;
+}
+
+void TlsAgent::SendData(size_t bytes, size_t blocksize) {
+ uint8_t block[4096];
+
+ ASSERT_LT(blocksize, sizeof(block));
+
+ while (bytes) {
+ size_t tosend = std::min(blocksize, bytes);
+
+ for (size_t i = 0; i < tosend; ++i) {
+ block[i] = 0xff & send_ctr_;
+ ++send_ctr_;
+ }
+
+ SendBuffer(DataBuffer(block, tosend));
+ bytes -= tosend;
+ }
+}
+
+void TlsAgent::SendBuffer(const DataBuffer& buf) {
+ LOGV("Writing " << buf.len() << " bytes");
+ int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len());
+ if (expect_readwrite_error_) {
+ EXPECT_GT(0, rv);
+ EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
+ error_code_ = PR_GetError();
+ expect_readwrite_error_ = false;
+ } else {
+ ASSERT_EQ(buf.len(), static_cast<size_t>(rv));
+ }
+}
+
+void TlsAgent::ReadBytes() {
+ uint8_t block[1024];
+
+ int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
+ LOGV("ReadBytes " << rv);
+ int32_t err;
+
+ if (rv >= 0) {
+ size_t count = static_cast<size_t>(rv);
+ for (size_t i = 0; i < count; ++i) {
+ ASSERT_EQ(recv_ctr_ & 0xff, block[i]);
+ recv_ctr_++;
+ }
+ } else {
+ err = PR_GetError();
+ LOG("Read error " << PORT_ErrorToName(err) << ": "
+ << PORT_ErrorToString(err));
+ if (err != PR_WOULD_BLOCK_ERROR && expect_readwrite_error_) {
+ error_code_ = err;
+ expect_readwrite_error_ = false;
+ }
+ }
+
+ // If closed, then don't bother waiting around.
+ if (rv > 0 || (rv < 0 && ErrorIsNonFatal(err))) {
+ LOGV("Re-arming");
+ Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
+ &TlsAgent::ReadableCallback);
+ }
+}
+
+void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }
+
+void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
+ EXPECT_TRUE(EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
+ mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
+ EXPECT_EQ(SECSuccess, rv);
+
+ rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::DisableECDHEServerKeyReuse() {
+ ASSERT_EQ(TlsAgent::SERVER, role_);
+ SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"};
+::testing::internal::ParamGenerator<std::string>
+ TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr);
+
+void TlsAgentTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsAgentTestBase::TearDown() {
+ delete agent_;
+ SSL_ClearSessionCache();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsAgentTestBase::Reset(const std::string& server_name) {
+ delete agent_;
+ Init(server_name);
+}
+
+void TlsAgentTestBase::Init(const std::string& server_name) {
+ agent_ =
+ new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
+ role_, mode_);
+ agent_->Init();
+ fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
+ agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
+ agent_->StartConnect();
+}
+
+void TlsAgentTestBase::EnsureInit() {
+ if (!agent_) {
+ Init();
+ }
+ const std::vector<SSLNamedGroup> groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
+ ssl_grp_ffdhe_2048};
+ agent_->ConfigNamedGroups(groups);
+}
+
+void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
+ TlsAgent::State expected_state,
+ int32_t error_code) {
+ std::cerr << "Process message: " << buffer << std::endl;
+ EnsureInit();
+ agent_->adapter()->PacketReceived(buffer);
+ agent_->Handshake();
+
+ ASSERT_EQ(expected_state, agent_->state());
+
+ if (expected_state == TlsAgent::STATE_ERROR) {
+ ASSERT_EQ(error_code, agent_->error_code());
+ }
+}
+
+void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num) {
+ size_t index = 0;
+ index = out->Write(index, type, 1);
+ index = out->Write(
+ index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2);
+ if (mode == DGRAM) {
+ index = out->Write(index, seq_num >> 32, 4);
+ index = out->Write(index, seq_num & PR_UINT32_MAX, 4);
+ }
+ index = out->Write(index, len, 2);
+ out->Write(index, buf, len);
+}
+
+void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num) const {
+ MakeRecord(mode_, type, version, buf, len, out, seq_num);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
+ const uint8_t* data, size_t hs_len,
+ DataBuffer* out,
+ uint64_t seq_num) const {
+ return MakeHandshakeMessageFragment(hs_type, data, hs_len, out, seq_num, 0,
+ 0);
+}
+
+void TlsAgentTestBase::MakeHandshakeMessageFragment(
+ uint8_t hs_type, const uint8_t* data, size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const {
+ size_t index = 0;
+ if (!fragment_length) fragment_length = hs_len;
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ if (mode_ == DGRAM) {
+ index = out->Write(index, seq_num, 2);
+ index = out->Write(index, fragment_offset, 3);
+ index = out->Write(index, fragment_length, 3);
+ }
+ if (data) {
+ index = out->Write(index, data, fragment_length);
+ } else {
+ for (size_t i = 0; i < fragment_length; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+ }
+}
+
+void TlsAgentTestBase::MakeTrivialHandshakeRecord(uint8_t hs_type,
+ size_t hs_len,
+ DataBuffer* out) {
+ size_t index = 0;
+ index = out->Write(index, kTlsHandshakeType, 1); // Content Type
+ index = out->Write(index, 3, 1); // Version high
+ index = out->Write(index, 1, 1); // Version low
+ index = out->Write(index, 4 + hs_len, 2); // Length
+
+ index = out->Write(index, hs_type, 1); // Handshake record type.
+ index = out->Write(index, hs_len, 3); // Handshake length
+ for (size_t i = 0; i < hs_len; ++i) {
+ index = out->Write(index, 1, 1);
+ }
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_agent.h b/nss/gtests/ssl_gtest/tls_agent.h
new file mode 100644
index 0000000..78923c9
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_agent.h
@@ -0,0 +1,457 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_agent_h_
+#define tls_agent_h_
+
+#include "prio.h"
+#include "ssl.h"
+
+#include <functional>
+#include <iostream>
+
+#include "test_io.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+extern bool g_ssl_gtest_verbose;
+
+namespace nss_test {
+
+#define LOG(msg) std::cerr << role_str() << ": " << msg << std::endl
+#define LOGV(msg) \
+ do { \
+ if (g_ssl_gtest_verbose) LOG(msg); \
+ } while (false)
+
+enum SessionResumptionMode {
+ RESUME_NONE = 0,
+ RESUME_SESSIONID = 1,
+ RESUME_TICKET = 2,
+ RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET
+};
+
+class TlsAgent;
+
+const extern std::vector<SSLNamedGroup> kAllDHEGroups;
+const extern std::vector<SSLNamedGroup> kECDHEGroups;
+const extern std::vector<SSLNamedGroup> kFFDHEGroups;
+const extern std::vector<SSLNamedGroup> kFasterDHEGroups;
+
+typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>
+ AuthCertificateCallbackFunction;
+
+typedef std::function<void(TlsAgent* agent)> HandshakeCallbackFunction;
+
+typedef std::function<int32_t(TlsAgent* agent, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize)>
+ SniCallbackFunction;
+
+class TlsAgent : public PollTarget {
+ public:
+ enum Role { CLIENT, SERVER };
+ enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR };
+
+ static const std::string kClient; // the client key is sign only
+ static const std::string kRsa2048; // bigger sign and encrypt for either
+ static const std::string kServerRsa; // both sign and encrypt
+ static const std::string kServerRsaSign;
+ static const std::string kServerRsaPss;
+ static const std::string kServerRsaDecrypt;
+ static const std::string kServerRsaChain; // A cert that requires a chain.
+ static const std::string kServerEcdsa256;
+ static const std::string kServerEcdsa384;
+ static const std::string kServerEcdsa521;
+ static const std::string kServerEcdhEcdsa;
+ static const std::string kServerEcdhRsa;
+ static const std::string kServerDsa;
+
+ TlsAgent(const std::string& name, Role role, Mode mode);
+ virtual ~TlsAgent();
+
+ bool Init() {
+ pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_);
+ if (!pr_fd_) return false;
+
+ adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
+ if (!adapter_) return false;
+
+ return true;
+ }
+
+ void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
+
+ void SetPacketFilter(PacketFilter* filter) {
+ adapter_->SetPacketFilter(filter);
+ }
+
+ void StartConnect(PRFileDesc* model = nullptr);
+ void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
+ size_t kea_size = 0) const;
+ void CheckAuthType(SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const;
+
+ void DisableAllCiphers();
+ void EnableCiphersByAuthType(SSLAuthType authType);
+ void EnableCiphersByKeyExchange(SSLKEAType kea);
+ void EnableGroupsByKeyExchange(SSLKEAType kea);
+ void EnableGroupsByAuthType(SSLAuthType authType);
+ void EnableSingleCipher(uint16_t cipher);
+
+ void Handshake();
+ // Marks the internal state as CONNECTING in anticipation of renegotiation.
+ void PrepareForRenegotiate();
+ // Prepares for renegotiation, then actually triggers it.
+ void StartRenegotiate();
+ bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
+ const SSLExtraServerCertData* serverCertData = nullptr);
+ bool ConfigServerCertWithChain(const std::string& name);
+ bool EnsureTlsSetup(PRFileDesc* modelSocket = nullptr);
+
+ void SetupClientAuth();
+ void RequestClientAuth(bool requireAuth);
+ bool GetClientAuthCredentials(CERTCertificate** cert,
+ SECKEYPrivateKey** priv) const;
+
+ void ConfigureSessionCache(SessionResumptionMode mode);
+ void SetSessionTicketsEnabled(bool en);
+ void SetSessionCacheEnabled(bool en);
+ void Set0RttEnabled(bool en);
+ void SetShortHeadersEnabled();
+ void SetVersionRange(uint16_t minver, uint16_t maxver);
+ void GetVersionRange(uint16_t* minver, uint16_t* maxver);
+ void CheckPreliminaryInfo();
+ void ResetPreliminaryInfo();
+ void SetExpectedVersion(uint16_t version);
+ void SetServerKeyBits(uint16_t bits);
+ void ExpectReadWriteError();
+ void EnableFalseStart();
+ void ExpectResumption();
+ void ExpectShortHeaders();
+ void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
+ void EnableAlpn(const uint8_t* val, size_t len);
+ void CheckAlpn(SSLNextProtoState expected_state,
+ const std::string& expected = "") const;
+ void EnableSrtp();
+ void CheckSrtp() const;
+ void CheckErrorCode(int32_t expected) const;
+ void WaitForErrorCode(int32_t expected, uint32_t delay) const;
+ // Send data on the socket, encrypting it.
+ void SendData(size_t bytes, size_t blocksize = 1024);
+ void SendBuffer(const DataBuffer& buf);
+ // Send data directly to the underlying socket, skipping the TLS layer.
+ void SendDirect(const DataBuffer& buf);
+ void ReadBytes();
+ void ResetSentBytes(); // Hack to test drops.
+ void EnableExtendedMasterSecret();
+ void CheckExtendedMasterSecret(bool expected);
+ void CheckEarlyDataAccepted(bool expected);
+ void DisableRollbackDetection();
+ void EnableCompression();
+ void SetDowngradeCheckVersion(uint16_t version);
+ void CheckSecretsDestroyed();
+ void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
+ void DisableECDHEServerKeyReuse();
+ bool GetPeerChainLength(size_t* count);
+
+ const std::string& name() const { return name_; }
+
+ Role role() const { return role_; }
+ std::string role_str() const { return role_ == SERVER ? "server" : "client"; }
+
+ State state() const { return state_; }
+
+ const CERTCertificate* peer_cert() const {
+ return SSL_PeerCertificate(ssl_fd_);
+ }
+
+ const char* state_str() const { return state_str(state()); }
+
+ static const char* state_str(State state) { return states[state]; }
+
+ PRFileDesc* ssl_fd() { return ssl_fd_; }
+ DummyPrSocket* adapter() { return adapter_; }
+
+ bool is_compressed() const {
+ return info_.compressionMethod != ssl_compression_null;
+ }
+ uint16_t server_key_bits() const { return server_key_bits_; }
+ uint16_t min_version() const { return vrange_.min; }
+ uint16_t max_version() const { return vrange_.max; }
+ uint16_t version() const {
+ EXPECT_EQ(STATE_CONNECTED, state_);
+ return info_.protocolVersion;
+ }
+
+ bool cipher_suite(uint16_t* cipher_suite) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *cipher_suite = info_.cipherSuite;
+ return true;
+ }
+
+ std::string cipher_suite_name() const {
+ if (state_ != STATE_CONNECTED) return "UNKNOWN";
+
+ return csinfo_.cipherSuiteName;
+ }
+
+ std::vector<uint8_t> session_id() const {
+ return std::vector<uint8_t>(info_.sessionID,
+ info_.sessionID + info_.sessionIDLength);
+ }
+
+ bool auth_type(SSLAuthType* auth_type) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *auth_type = info_.authType;
+ return true;
+ }
+
+ bool kea_type(SSLKEAType* kea_type) const {
+ if (state_ != STATE_CONNECTED) return false;
+
+ *kea_type = info_.keaType;
+ return true;
+ }
+
+ size_t received_bytes() const { return recv_ctr_; }
+ PRErrorCode error_code() const { return error_code_; }
+
+ bool can_falsestart_hook_called() const {
+ return can_falsestart_hook_called_;
+ }
+
+ void SetHandshakeCallback(HandshakeCallbackFunction handshake_callback) {
+ handshake_callback_ = handshake_callback;
+ }
+
+ void SetAuthCertificateCallback(
+ AuthCertificateCallbackFunction auth_certificate_callback) {
+ auth_certificate_callback_ = auth_certificate_callback;
+ }
+
+ void SetSniCallback(SniCallbackFunction sni_callback) {
+ sni_callback_ = sni_callback;
+ }
+
+ private:
+ const static char* states[];
+
+ void SetState(State state);
+
+ // Dummy auth certificate hook.
+ static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ agent->auth_certificate_hook_called_ = true;
+ if (agent->auth_certificate_callback_) {
+ return agent->auth_certificate_callback_(agent, checksig ? true : false,
+ isServer ? true : false);
+ }
+ return SECSuccess;
+ }
+
+ // Client auth certificate hook.
+ static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd,
+ PRBool checksig, PRBool isServer) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ EXPECT_TRUE(agent->expect_client_auth_);
+ EXPECT_EQ(PR_TRUE, isServer);
+ if (agent->auth_certificate_callback_) {
+ return agent->auth_certificate_callback_(agent, checksig ? true : false,
+ isServer ? true : false);
+ }
+ return SECSuccess;
+ }
+
+ static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd,
+ CERTDistNames* caNames,
+ CERTCertificate** cert,
+ SECKEYPrivateKey** privKey);
+
+ static void ReadableCallback(PollTarget* self, Event event) {
+ TlsAgent* agent = static_cast<TlsAgent*>(self);
+ if (event == TIMER_EVENT) {
+ agent->timer_handle_ = nullptr;
+ }
+ agent->ReadableCallback_int();
+ }
+
+ void ReadableCallback_int() {
+ LOGV("Readable");
+ switch (state_) {
+ case STATE_CONNECTING:
+ Handshake();
+ break;
+ case STATE_CONNECTED:
+ ReadBytes();
+ break;
+ default:
+ break;
+ }
+ }
+
+ static PRInt32 SniHook(PRFileDesc* fd, const SECItem* srvNameArr,
+ PRUint32 srvNameArrSize, void* arg) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ agent->sni_hook_called_ = true;
+ EXPECT_EQ(1UL, srvNameArrSize);
+ if (agent->sni_callback_) {
+ return agent->sni_callback_(agent, srvNameArr, srvNameArrSize);
+ }
+ return 0; // First configuration.
+ }
+
+ static SECStatus CanFalseStartCallback(PRFileDesc* fd, void* arg,
+ PRBool* canFalseStart) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->CheckPreliminaryInfo();
+ EXPECT_TRUE(agent->falsestart_enabled_);
+ EXPECT_FALSE(agent->can_falsestart_hook_called_);
+ agent->can_falsestart_hook_called_ = true;
+ *canFalseStart = true;
+ return SECSuccess;
+ }
+
+ static void HandshakeCallback(PRFileDesc* fd, void* arg) {
+ TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
+ agent->handshake_callback_called_ = true;
+ agent->Connected();
+ if (agent->handshake_callback_) {
+ agent->handshake_callback_(agent);
+ }
+ }
+
+ void DisableLameGroups();
+ void ConfigStrongECGroups(bool en);
+ void ConfigAllDHGroups(bool en);
+ void CheckCallbacks() const;
+ void Connected();
+
+ const std::string name_;
+ Mode mode_;
+ uint16_t server_key_bits_;
+ PRFileDesc* pr_fd_;
+ DummyPrSocket* adapter_;
+ PRFileDesc* ssl_fd_;
+ Role role_;
+ State state_;
+ Poller::Timer* timer_handle_;
+ bool falsestart_enabled_;
+ uint16_t expected_version_;
+ uint16_t expected_cipher_suite_;
+ bool expect_resumption_;
+ bool expect_client_auth_;
+ bool can_falsestart_hook_called_;
+ bool sni_hook_called_;
+ bool auth_certificate_hook_called_;
+ bool handshake_callback_called_;
+ SSLChannelInfo info_;
+ SSLCipherSuiteInfo csinfo_;
+ SSLVersionRange vrange_;
+ PRErrorCode error_code_;
+ size_t send_ctr_;
+ size_t recv_ctr_;
+ bool expect_readwrite_error_;
+ HandshakeCallbackFunction handshake_callback_;
+ AuthCertificateCallbackFunction auth_certificate_callback_;
+ SniCallbackFunction sni_callback_;
+ bool expect_short_headers_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const TlsAgent::State& state) {
+ return stream << TlsAgent::state_str(state);
+}
+
+class TlsAgentTestBase : public ::testing::Test {
+ public:
+ static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
+
+ TlsAgentTestBase(TlsAgent::Role role, Mode mode)
+ : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {}
+ ~TlsAgentTestBase() {
+ if (fd_) {
+ PR_Close(fd_);
+ }
+ }
+
+ void SetUp();
+ void TearDown();
+
+ static void MakeRecord(Mode mode, uint8_t type, uint16_t version,
+ const uint8_t* buf, size_t len, DataBuffer* out,
+ uint64_t seq_num = 0);
+ void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
+ void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
+ DataBuffer* out, uint64_t seq_num = 0) const;
+ void MakeHandshakeMessageFragment(uint8_t hs_type, const uint8_t* data,
+ size_t hs_len, DataBuffer* out,
+ uint64_t seq_num, uint32_t fragment_offset,
+ uint32_t fragment_length) const;
+ static void MakeTrivialHandshakeRecord(uint8_t hs_type, size_t hs_len,
+ DataBuffer* out);
+ static inline TlsAgent::Role ToRole(const std::string& str) {
+ return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
+ }
+
+ static inline Mode ToMode(const std::string& str) {
+ return str == "TLS" ? STREAM : DGRAM;
+ }
+
+ void Init(const std::string& server_name = TlsAgent::kServerRsa);
+ void Reset(const std::string& server_name = TlsAgent::kServerRsa);
+
+ protected:
+ void EnsureInit();
+ void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
+ int32_t error_code = 0);
+
+ TlsAgent* agent_;
+ PRFileDesc* fd_;
+ TlsAgent::Role role_;
+ Mode mode_;
+};
+
+class TlsAgentTest : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<std::string, std::string>> {
+ public:
+ TlsAgentTest()
+ : TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
+ ToMode(std::get<1>(GetParam()))) {}
+};
+
+class TlsAgentTestClient : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsAgentTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {}
+};
+
+class TlsAgentStreamTestClient : public TlsAgentTestBase {
+ public:
+ TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {}
+};
+
+class TlsAgentStreamTestServer : public TlsAgentTestBase {
+ public:
+ TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {}
+};
+
+class TlsAgentDgramTestClient : public TlsAgentTestBase {
+ public:
+ TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {}
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/nss/gtests/ssl_gtest/tls_connect.cc b/nss/gtests/ssl_gtest/tls_connect.cc
new file mode 100644
index 0000000..d025499
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_connect.cc
@@ -0,0 +1,708 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_connect.h"
+extern "C" {
+#include "libssl_internals.h"
+}
+
+#include <iostream>
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "sslproto.h"
+
+extern std::string g_working_dir_path;
+
+namespace nss_test {
+
+static const std::string kTlsModesStreamArr[] = {"TLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesStream =
+ ::testing::ValuesIn(kTlsModesStreamArr);
+static const std::string kTlsModesDatagramArr[] = {"DTLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesDatagram =
+ ::testing::ValuesIn(kTlsModesDatagramArr);
+static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"};
+::testing::internal::ParamGenerator<std::string>
+ TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr);
+
+static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
+ ::testing::ValuesIn(kTlsV10Arr);
+static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11 =
+ ::testing::ValuesIn(kTlsV11Arr);
+static const uint16_t kTlsV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12 =
+ ::testing::ValuesIn(kTlsV12Arr);
+static const uint16_t kTlsV10V11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10V11 =
+ ::testing::ValuesIn(kTlsV10V11Arr);
+static const uint16_t kTlsV10ToV12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10ToV12 =
+ ::testing::ValuesIn(kTlsV10ToV12Arr);
+static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11V12 =
+ ::testing::ValuesIn(kTlsV11V12Arr);
+
+static const uint16_t kTlsV11PlusArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV11Plus =
+ ::testing::ValuesIn(kTlsV11PlusArr);
+static const uint16_t kTlsV12PlusArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV12Plus =
+ ::testing::ValuesIn(kTlsV12PlusArr);
+static const uint16_t kTlsV13Arr[] = {SSL_LIBRARY_VERSION_TLS_1_3};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV13 =
+ ::testing::ValuesIn(kTlsV13Arr);
+static const uint16_t kTlsVAllArr[] = {
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_TLS_1_2, SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_0};
+::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsVAll =
+ ::testing::ValuesIn(kTlsVAllArr);
+
+std::string VersionString(uint16_t version) {
+ switch (version) {
+ case 0:
+ return "(no version)";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "1.2";
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ return "1.3";
+ default:
+ std::cerr << "Invalid version: " << version << std::endl;
+ EXPECT_TRUE(false);
+ return "";
+ }
+}
+
+TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version)
+ : mode_(mode),
+ client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)),
+ server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)),
+ client_model_(nullptr),
+ server_model_(nullptr),
+ version_(version),
+ expected_resumption_mode_(RESUME_NONE),
+ session_ids_(),
+ expect_extended_master_secret_(false),
+ expect_early_data_accepted_(false) {
+ std::string v;
+ if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
+ v = "1.0";
+ } else {
+ v = VersionString(version_);
+ }
+ std::cerr << "Version: " << mode_ << " " << v << std::endl;
+}
+
+TlsConnectTestBase::TlsConnectTestBase(const std::string& mode,
+ uint16_t version)
+ : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {}
+
+TlsConnectTestBase::~TlsConnectTestBase() {}
+
+// Check the group of each of the supported groups
+void TlsConnectTestBase::CheckGroups(
+ const DataBuffer& groups, std::function<void(SSLNamedGroup)> check_group) {
+ DuplicateGroupChecker group_set;
+ uint32_t tmp = 0;
+ EXPECT_TRUE(groups.Read(0, 2, &tmp));
+ EXPECT_EQ(groups.len() - 2, static_cast<size_t>(tmp));
+ for (size_t i = 2; i < groups.len(); i += 2) {
+ EXPECT_TRUE(groups.Read(i, 2, &tmp));
+ SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
+ group_set.AddAndCheckGroup(group);
+ check_group(group);
+ }
+}
+
+// Check the group of each of the shares
+void TlsConnectTestBase::CheckShares(
+ const DataBuffer& shares, std::function<void(SSLNamedGroup)> check_group) {
+ DuplicateGroupChecker group_set;
+ uint32_t tmp = 0;
+ EXPECT_TRUE(shares.Read(0, 2, &tmp));
+ EXPECT_EQ(shares.len() - 2, static_cast<size_t>(tmp));
+ size_t i;
+ for (i = 2; i < shares.len(); i += 4 + tmp) {
+ ASSERT_TRUE(shares.Read(i, 2, &tmp));
+ SSLNamedGroup group = static_cast<SSLNamedGroup>(tmp);
+ group_set.AddAndCheckGroup(group);
+ check_group(group);
+ ASSERT_TRUE(shares.Read(i + 2, 2, &tmp));
+ }
+ EXPECT_EQ(shares.len(), i);
+}
+
+void TlsConnectTestBase::ClearStats() {
+ // Clear statistics.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ memset(stats, 0, sizeof(*stats));
+}
+
+void TlsConnectTestBase::ClearServerCache() {
+ SSL_ShutdownServerSessionIDCache();
+ SSLInt_ClearSessionTicketKey();
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+}
+
+void TlsConnectTestBase::SetUp() {
+ SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
+ SSLInt_ClearSessionTicketKey();
+ ClearStats();
+ Init();
+}
+
+void TlsConnectTestBase::TearDown() {
+ delete client_;
+ delete server_;
+ if (client_model_) {
+ ASSERT_NE(server_model_, nullptr);
+ delete client_model_;
+ delete server_model_;
+ }
+
+ SSL_ClearSessionCache();
+ SSLInt_ClearSessionTicketKey();
+ SSL_ShutdownServerSessionIDCache();
+}
+
+void TlsConnectTestBase::Init() {
+ EXPECT_TRUE(client_->Init());
+ EXPECT_TRUE(server_->Init());
+
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+
+ if (version_) {
+ ConfigureVersion(version_);
+ }
+}
+
+void TlsConnectTestBase::Reset() {
+ // Take a copy of the names because they are about to disappear.
+ std::string server_name = server_->name();
+ std::string client_name = client_->name();
+ Reset(server_name, client_name);
+}
+
+void TlsConnectTestBase::Reset(const std::string& server_name,
+ const std::string& client_name) {
+ delete client_;
+ delete server_;
+
+ client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_);
+ server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_);
+
+ Init();
+}
+
+void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) {
+ expected_resumption_mode_ = expected;
+ if (expected != RESUME_NONE) {
+ client_->ExpectResumption();
+ server_->ExpectResumption();
+ }
+}
+
+void TlsConnectTestBase::EnsureTlsSetup() {
+ EXPECT_TRUE(server_->EnsureTlsSetup(server_model_ ? server_model_->ssl_fd()
+ : nullptr));
+ EXPECT_TRUE(client_->EnsureTlsSetup(client_model_ ? client_model_->ssl_fd()
+ : nullptr));
+}
+
+void TlsConnectTestBase::Handshake() {
+ EnsureTlsSetup();
+ client_->SetServerKeyBits(server_->server_key_bits());
+ client_->Handshake();
+ server_->Handshake();
+
+ ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) &&
+ (server_->state() != TlsAgent::STATE_CONNECTING),
+ 5000);
+}
+
+void TlsConnectTestBase::EnableExtendedMasterSecret() {
+ client_->EnableExtendedMasterSecret();
+ server_->EnableExtendedMasterSecret();
+ ExpectExtendedMasterSecret(true);
+}
+
+void TlsConnectTestBase::Connect() {
+ server_->StartConnect(server_model_ ? server_model_->ssl_fd() : nullptr);
+ client_->StartConnect(client_model_ ? client_model_->ssl_fd() : nullptr);
+ Handshake();
+ CheckConnected();
+}
+
+void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
+ EnsureTlsSetup();
+ client_->EnableSingleCipher(cipher_suite);
+
+ Connect();
+ SendReceive();
+
+ // Check that we used the right cipher suite.
+ uint16_t actual;
+ EXPECT_TRUE(client_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite, actual);
+ EXPECT_TRUE(server_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite, actual);
+}
+
+void TlsConnectTestBase::CheckConnected() {
+ // Check the version is as expected
+ EXPECT_EQ(client_->version(), server_->version());
+ EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
+ client_->version());
+
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+ EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+
+ uint16_t cipher_suite1, cipher_suite2;
+ bool ret = client_->cipher_suite(&cipher_suite1);
+ EXPECT_TRUE(ret);
+ ret = server_->cipher_suite(&cipher_suite2);
+ EXPECT_TRUE(ret);
+ EXPECT_EQ(cipher_suite1, cipher_suite2);
+
+ std::cerr << "Connected with version " << client_->version()
+ << " cipher suite " << client_->cipher_suite_name() << std::endl;
+
+ if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Check and store session ids.
+ std::vector<uint8_t> sid_c1 = client_->session_id();
+ EXPECT_EQ(32U, sid_c1.size());
+ std::vector<uint8_t> sid_s1 = server_->session_id();
+ EXPECT_EQ(32U, sid_s1.size());
+ EXPECT_EQ(sid_c1, sid_s1);
+ session_ids_.push_back(sid_c1);
+ }
+
+ CheckExtendedMasterSecret();
+ CheckEarlyDataAccepted();
+ CheckResumption(expected_resumption_mode_);
+ client_->CheckSecretsDestroyed();
+ server_->CheckSecretsDestroyed();
+}
+
+void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLAuthType auth_type,
+ SSLSignatureScheme sig_scheme) const {
+ client_->CheckKEA(kea_type, kea_group);
+ server_->CheckKEA(kea_type, kea_group);
+ client_->CheckAuthType(auth_type, sig_scheme);
+ server_->CheckAuthType(auth_type, sig_scheme);
+}
+
+void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
+ SSLAuthType auth_type) const {
+ SSLNamedGroup group;
+ switch (kea_type) {
+ case ssl_kea_ecdh:
+ group = ssl_grp_ec_curve25519;
+ break;
+ case ssl_kea_dh:
+ group = ssl_grp_ffdhe_2048;
+ break;
+ case ssl_kea_rsa:
+ group = ssl_grp_none;
+ break;
+ default:
+ EXPECT_TRUE(false) << "unexpected KEA";
+ group = ssl_grp_none;
+ break;
+ }
+
+ SSLSignatureScheme scheme;
+ switch (auth_type) {
+ case ssl_auth_rsa_decrypt:
+ scheme = ssl_sig_none;
+ break;
+ case ssl_auth_rsa_sign:
+ scheme = ssl_sig_rsa_pss_sha256;
+ break;
+ case ssl_auth_ecdsa:
+ scheme = ssl_sig_ecdsa_secp256r1_sha256;
+ break;
+ case ssl_auth_dsa:
+ scheme = ssl_sig_dsa_sha1;
+ break;
+ default:
+ EXPECT_TRUE(false) << "unexpected auth type";
+ scheme = static_cast<SSLSignatureScheme>(0x0100);
+ break;
+ }
+ CheckKeys(kea_type, group, auth_type, scheme);
+}
+
+void TlsConnectTestBase::CheckKeys() const {
+ CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
+}
+
+void TlsConnectTestBase::ConnectExpectFail() {
+ server_->StartConnect();
+ client_->StartConnect();
+ Handshake();
+ ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state());
+ ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
+}
+
+void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
+ client_->SetVersionRange(version, version);
+ server_->SetVersionRange(version, version);
+}
+
+void TlsConnectTestBase::SetExpectedVersion(uint16_t version) {
+ client_->SetExpectedVersion(version);
+ server_->SetExpectedVersion(version);
+}
+
+void TlsConnectTestBase::DisableAllCiphers() {
+ EnsureTlsSetup();
+ client_->DisableAllCiphers();
+ server_->DisableAllCiphers();
+}
+
+void TlsConnectTestBase::EnableOnlyStaticRsaCiphers() {
+ DisableAllCiphers();
+
+ client_->EnableCiphersByKeyExchange(ssl_kea_rsa);
+ server_->EnableCiphersByKeyExchange(ssl_kea_rsa);
+}
+
+void TlsConnectTestBase::EnableOnlyDheCiphers() {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ DisableAllCiphers();
+ client_->EnableCiphersByKeyExchange(ssl_kea_dh);
+ server_->EnableCiphersByKeyExchange(ssl_kea_dh);
+ } else {
+ client_->ConfigNamedGroups(kFFDHEGroups);
+ server_->ConfigNamedGroups(kFFDHEGroups);
+ }
+}
+
+void TlsConnectTestBase::EnableSomeEcdhCiphers() {
+ if (version_ < SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
+ client_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
+ server_->EnableCiphersByAuthType(ssl_auth_ecdh_rsa);
+ server_->EnableCiphersByAuthType(ssl_auth_ecdh_ecdsa);
+ } else {
+ client_->ConfigNamedGroups(kECDHEGroups);
+ server_->ConfigNamedGroups(kECDHEGroups);
+ }
+}
+
+void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server) {
+ client_->ConfigureSessionCache(client);
+ server_->ConfigureSessionCache(server);
+ if ((server & RESUME_TICKET) != 0) {
+ // This is an abomination. NSS encrypts session tickets with the server's
+ // RSA public key. That means we need the server to have an RSA certificate
+ // even if it won't be used for the connection.
+ server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt);
+ }
+}
+
+void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) {
+ EXPECT_NE(RESUME_BOTH, expected);
+
+ int resume_count = expected ? 1 : 0;
+ int stateless_count = (expected & RESUME_TICKET) ? 1 : 0;
+
+ // Note: hch == server counter; hsh == client counter.
+ SSL3Statistics* stats = SSL_GetStatistics();
+ EXPECT_EQ(resume_count, stats->hch_sid_cache_hits);
+ EXPECT_EQ(resume_count, stats->hsh_sid_cache_hits);
+
+ EXPECT_EQ(stateless_count, stats->hch_sid_stateless_resumes);
+ EXPECT_EQ(stateless_count, stats->hsh_sid_stateless_resumes);
+
+ if (expected != RESUME_NONE) {
+ if (client_->version() < SSL_LIBRARY_VERSION_TLS_1_3) {
+ // Check that the last two session ids match.
+ ASSERT_EQ(2U, session_ids_.size());
+ EXPECT_EQ(session_ids_[session_ids_.size() - 1],
+ session_ids_[session_ids_.size() - 2]);
+ } else {
+ // TLS 1.3 only uses tickets.
+ EXPECT_TRUE(expected & RESUME_TICKET);
+ }
+ }
+}
+
+void TlsConnectTestBase::EnableAlpn() {
+ client_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+ server_->EnableAlpn(alpn_dummy_val_, sizeof(alpn_dummy_val_));
+}
+
+void TlsConnectTestBase::EnableAlpn(const uint8_t* val, size_t len) {
+ client_->EnableAlpn(val, len);
+ server_->EnableAlpn(val, len);
+}
+
+void TlsConnectTestBase::EnsureModelSockets() {
+ // Make sure models agents are available.
+ if (!client_model_) {
+ ASSERT_EQ(server_model_, nullptr);
+ client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_);
+ server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_);
+ }
+
+ // Initialise agents.
+ ASSERT_TRUE(client_model_->Init());
+ ASSERT_TRUE(server_model_->Init());
+}
+
+void TlsConnectTestBase::CheckAlpn(const std::string& val) {
+ client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, val);
+ server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, val);
+}
+
+void TlsConnectTestBase::EnableSrtp() {
+ client_->EnableSrtp();
+ server_->EnableSrtp();
+}
+
+void TlsConnectTestBase::CheckSrtp() const {
+ client_->CheckSrtp();
+ server_->CheckSrtp();
+}
+
+void TlsConnectTestBase::SendReceive() {
+ client_->SendData(50);
+ server_->SendData(50);
+ Receive(50);
+}
+
+// Do a first connection so we can do 0-RTT on the second one.
+void TlsConnectTestBase::SetupForZeroRtt() {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->Set0RttEnabled(true); // So we signal that we allow 0-RTT.
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->StartConnect();
+ client_->StartConnect();
+}
+
+// Do a first connection so we can do resumption
+void TlsConnectTestBase::SetupForResume() {
+ EnsureTlsSetup();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive(); // Need to read so that we absorb the session ticket.
+ CheckKeys();
+
+ Reset();
+}
+
+void TlsConnectTestBase::ZeroRttSendReceive(
+ bool expect_writable, bool expect_readable,
+ std::function<bool()> post_clienthello_check) {
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+
+ client_->Handshake(); // Send ClientHello.
+ if (post_clienthello_check) {
+ if (!post_clienthello_check()) return;
+ }
+ PRInt32 rv =
+ PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen); // 0-RTT write.
+ if (expect_writable) {
+ EXPECT_EQ(k0RttDataLen, rv);
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ }
+ server_->Handshake(); // Consume ClientHello, EE, Finished.
+
+ std::vector<uint8_t> buf(k0RttDataLen);
+ rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen); // 0-RTT read
+ if (expect_readable) {
+ std::cerr << "0-RTT read " << rv << " bytes\n";
+ EXPECT_EQ(k0RttDataLen, rv);
+ } else {
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+ }
+
+ // Do a second read. this should fail.
+ rv = PR_Read(server_->ssl_fd(), buf.data(), k0RttDataLen);
+ EXPECT_EQ(SECFailure, rv);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+}
+
+void TlsConnectTestBase::Receive(size_t amount) {
+ WAIT_(client_->received_bytes() == amount &&
+ server_->received_bytes() == amount,
+ 2000);
+ ASSERT_EQ(amount, client_->received_bytes());
+ ASSERT_EQ(amount, server_->received_bytes());
+}
+
+void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) {
+ expect_extended_master_secret_ = expected;
+}
+
+void TlsConnectTestBase::CheckExtendedMasterSecret() {
+ client_->CheckExtendedMasterSecret(expect_extended_master_secret_);
+ server_->CheckExtendedMasterSecret(expect_extended_master_secret_);
+}
+
+void TlsConnectTestBase::ExpectEarlyDataAccepted(bool expected) {
+ expect_early_data_accepted_ = expected;
+}
+
+void TlsConnectTestBase::CheckEarlyDataAccepted() {
+ client_->CheckEarlyDataAccepted(expect_early_data_accepted_);
+ server_->CheckEarlyDataAccepted(expect_early_data_accepted_);
+}
+
+void TlsConnectTestBase::DisableECDHEServerKeyReuse() {
+ server_->DisableECDHEServerKeyReuse();
+}
+
+TlsConnectGeneric::TlsConnectGeneric()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectPre12::TlsConnectPre12()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectTls12::TlsConnectTls12()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_2) {}
+
+TlsConnectTls12Plus::TlsConnectTls12Plus()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+TlsConnectTls13::TlsConnectTls13()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+void TlsKeyExchangeTest::EnsureKeyShareSetup() {
+ EnsureTlsSetup();
+ groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn);
+ shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
+ shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true);
+ std::vector<PacketFilter*> captures;
+ captures.push_back(groups_capture_);
+ captures.push_back(shares_capture_);
+ captures.push_back(shares_capture2_);
+ client_->SetPacketFilter(new ChainedPacketFilter(captures));
+ capture_hrr_ =
+ new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ server_->SetPacketFilter(capture_hrr_);
+}
+
+void TlsKeyExchangeTest::ConfigNamedGroups(
+ const std::vector<SSLNamedGroup>& groups) {
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+}
+
+std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetGroupDetails(
+ const DataBuffer& ext) {
+ uint32_t tmp = 0;
+ EXPECT_TRUE(ext.Read(0, 2, &tmp));
+ EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+ EXPECT_TRUE(ext.len() % 2 == 0);
+ std::vector<SSLNamedGroup> groups;
+ for (size_t i = 1; i < ext.len() / 2; i += 1) {
+ EXPECT_TRUE(ext.Read(2 * i, 2, &tmp));
+ groups.push_back(static_cast<SSLNamedGroup>(tmp));
+ }
+ return groups;
+}
+
+std::vector<SSLNamedGroup> TlsKeyExchangeTest::GetShareDetails(
+ const DataBuffer& ext) {
+ uint32_t tmp = 0;
+ EXPECT_TRUE(ext.Read(0, 2, &tmp));
+ EXPECT_EQ(ext.len() - 2, static_cast<size_t>(tmp));
+ std::vector<SSLNamedGroup> shares;
+ size_t i = 2;
+ while (i < ext.len()) {
+ EXPECT_TRUE(ext.Read(i, 2, &tmp));
+ shares.push_back(static_cast<SSLNamedGroup>(tmp));
+ EXPECT_TRUE(ext.Read(i + 2, 2, &tmp));
+ i += 4 + tmp;
+ }
+ EXPECT_EQ(ext.len(), i);
+ return shares;
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares, bool expect_hrr) {
+ std::vector<SSLNamedGroup> groups =
+ GetGroupDetails(groups_capture_->extension());
+ EXPECT_EQ(expected_groups, groups);
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ASSERT_LT(0U, expected_shares.size());
+ std::vector<SSLNamedGroup> shares =
+ GetShareDetails(shares_capture_->extension());
+ EXPECT_EQ(expected_shares, shares);
+ } else {
+ EXPECT_EQ(0U, shares_capture_->extension().len());
+ }
+
+ EXPECT_EQ(expect_hrr, capture_hrr_->buffer().len() != 0);
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares) {
+ CheckKEXDetails(expected_groups, expected_shares, false);
+}
+
+void TlsKeyExchangeTest::CheckKEXDetails(
+ const std::vector<SSLNamedGroup>& expected_groups,
+ const std::vector<SSLNamedGroup>& expected_shares,
+ SSLNamedGroup expected_share2) {
+ CheckKEXDetails(expected_groups, expected_shares, true);
+
+ for (auto it : expected_shares) {
+ EXPECT_NE(expected_share2, it);
+ }
+ std::vector<SSLNamedGroup> expected_shares2 = {expected_share2};
+ std::vector<SSLNamedGroup> shares =
+ GetShareDetails(shares_capture2_->extension());
+ EXPECT_EQ(expected_shares2, shares);
+}
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_connect.h b/nss/gtests/ssl_gtest/tls_connect.h
new file mode 100644
index 0000000..aa4a32d
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_connect.h
@@ -0,0 +1,274 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_connect_h_
+#define tls_connect_h_
+
+#include <tuple>
+
+#include "sslproto.h"
+#include "sslt.h"
+
+#include "tls_agent.h"
+#include "tls_filter.h"
+
+#define GTEST_HAS_RTTI 0
+#include "gtest/gtest.h"
+
+namespace nss_test {
+
+extern std::string VersionString(uint16_t version);
+
+// A generic TLS connection test base.
+class TlsConnectTestBase : public ::testing::Test {
+ public:
+ static ::testing::internal::ParamGenerator<std::string> kTlsModesStream;
+ static ::testing::internal::ParamGenerator<std::string> kTlsModesDatagram;
+ static ::testing::internal::ParamGenerator<std::string> kTlsModesAll;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV13;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
+ static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
+
+ TlsConnectTestBase(Mode mode, uint16_t version);
+ TlsConnectTestBase(const std::string& mode, uint16_t version);
+ virtual ~TlsConnectTestBase();
+
+ void SetUp();
+ void TearDown();
+
+ // Initialize client and server.
+ void Init();
+ // Clear the statistics.
+ void ClearStats();
+ // Clear the server session cache.
+ void ClearServerCache();
+ // Make sure TLS is configured for a connection.
+ void EnsureTlsSetup();
+ // Reset and keep the same certificate names
+ void Reset();
+ // Reset, and update the certificate names on both peers
+ void Reset(const std::string& server_name,
+ const std::string& client_name = "client");
+
+ // Run the handshake.
+ void Handshake();
+ // Connect and check that it works.
+ void Connect();
+ // Check that the connection was successfully established.
+ void CheckConnected();
+ // Connect and expect it to fail.
+ void ConnectExpectFail();
+ void ConnectWithCipherSuite(uint16_t cipher_suite);
+ // Check that the keys used in the handshake match expectations.
+ void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
+ SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
+ // This version guesses some of the values.
+ void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
+ // This version assumes defaults.
+ void CheckKeys() const;
+ void CheckGroups(const DataBuffer& groups,
+ std::function<void(SSLNamedGroup)> check_group);
+ void CheckShares(const DataBuffer& shares,
+ std::function<void(SSLNamedGroup)> check_group);
+
+ void ConfigureVersion(uint16_t version);
+ void SetExpectedVersion(uint16_t version);
+ // Expect resumption of a particular type.
+ void ExpectResumption(SessionResumptionMode expected);
+ void DisableAllCiphers();
+ void EnableOnlyStaticRsaCiphers();
+ void EnableOnlyDheCiphers();
+ void EnableSomeEcdhCiphers();
+ void EnableExtendedMasterSecret();
+ void ConfigureSessionCache(SessionResumptionMode client,
+ SessionResumptionMode server);
+ void EnableAlpn();
+ void EnableAlpn(const uint8_t* val, size_t len);
+ void EnsureModelSockets();
+ void CheckAlpn(const std::string& val);
+ void EnableSrtp();
+ void CheckSrtp() const;
+ void SendReceive();
+ void SetupForZeroRtt();
+ void SetupForResume();
+ void ZeroRttSendReceive(
+ bool expect_writable, bool expect_readable,
+ std::function<bool()> post_clienthello_check = nullptr);
+ void Receive(size_t amount);
+ void ExpectExtendedMasterSecret(bool expected);
+ void ExpectEarlyDataAccepted(bool expected);
+ void DisableECDHEServerKeyReuse();
+
+ protected:
+ Mode mode_;
+ TlsAgent* client_;
+ TlsAgent* server_;
+ TlsAgent* client_model_;
+ TlsAgent* server_model_;
+ uint16_t version_;
+ SessionResumptionMode expected_resumption_mode_;
+ std::vector<std::vector<uint8_t>> session_ids_;
+
+ // A simple value of "a", "b". Note that the preferred value of "a" is placed
+ // at the end, because the NSS API follows the now defunct NPN specification,
+ // which places the preferred (and default) entry at the end of the list.
+ // NSS will move this final entry to the front when used with ALPN.
+ const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
+
+ private:
+ static inline Mode ToMode(const std::string& str) {
+ return str == "TLS" ? STREAM : DGRAM;
+ }
+
+ void CheckResumption(SessionResumptionMode expected);
+ void CheckExtendedMasterSecret();
+ void CheckEarlyDataAccepted();
+
+ bool expect_extended_master_secret_;
+ bool expect_early_data_accepted_;
+
+ // Track groups and make sure that there are no duplicates.
+ class DuplicateGroupChecker {
+ public:
+ void AddAndCheckGroup(SSLNamedGroup group) {
+ EXPECT_EQ(groups_.end(), groups_.find(group))
+ << "Group " << group << " should not be duplicated";
+ groups_.insert(group);
+ }
+
+ private:
+ std::set<SSLNamedGroup> groups_;
+ };
+};
+
+// A non-parametrized TLS test base.
+class TlsConnectTest : public TlsConnectTestBase {
+ public:
+ TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {}
+};
+
+// A non-parametrized DTLS-only test base.
+class DtlsConnectTest : public TlsConnectTestBase {
+ public:
+ DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {}
+};
+
+// A TLS-only test base.
+class TlsConnectStream : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {}
+};
+
+// A TLS-only test base for tests before 1.3
+class TlsConnectStreamPre13 : public TlsConnectStream {};
+
+// A DTLS-only test base.
+class TlsConnectDatagram : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {}
+};
+
+// A generic test class that can be either STREAM or DGRAM and a single version
+// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this
+// should use TEST_P().
+class TlsConnectGeneric
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsConnectGeneric();
+};
+
+// A Pre TLS 1.2 generic test.
+class TlsConnectPre12
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsConnectPre12();
+};
+
+// A TLS 1.2 only generic test.
+class TlsConnectTls12 : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsConnectTls12();
+};
+
+// A TLS 1.2 only stream test.
+class TlsConnectStreamTls12 : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls12()
+ : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {}
+};
+
+// A TLS 1.2+ generic test.
+class TlsConnectTls12Plus
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsConnectTls12Plus();
+};
+
+// A TLS 1.3 only generic test.
+class TlsConnectTls13 : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsConnectTls13();
+};
+
+// A TLS 1.3 only stream test.
+class TlsConnectStreamTls13 : public TlsConnectTestBase {
+ public:
+ TlsConnectStreamTls13()
+ : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsConnectDatagram13 : public TlsConnectTestBase {
+ public:
+ TlsConnectDatagram13()
+ : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+// A variant that is used only with Pre13.
+class TlsConnectGenericPre13 : public TlsConnectGeneric {};
+
+class TlsKeyExchangeTest : public TlsConnectGeneric {
+ protected:
+ TlsExtensionCapture* groups_capture_;
+ TlsExtensionCapture* shares_capture_;
+ TlsExtensionCapture* shares_capture2_;
+ TlsInspectorRecordHandshakeMessage* capture_hrr_;
+
+ void EnsureKeyShareSetup();
+ void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
+ std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext);
+ std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext);
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares);
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares,
+ SSLNamedGroup expectedShare2);
+
+ private:
+ void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
+ const std::vector<SSLNamedGroup>& expectedShares,
+ bool expect_hrr);
+};
+
+class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {};
+class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {};
+
+} // namespace nss_test
+
+#endif
diff --git a/nss/gtests/ssl_gtest/tls_filter.cc b/nss/gtests/ssl_gtest/tls_filter.cc
new file mode 100644
index 0000000..4f7d195
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_filter.cc
@@ -0,0 +1,503 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_filter.h"
+#include "sslproto.h"
+
+extern "C" {
+// This is not something that should make you happy.
+#include "libssl_internals.h"
+}
+
+#include <iostream>
+#include "gtest_utils.h"
+#include "tls_agent.h"
+
+namespace nss_test {
+
+PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ bool changed = false;
+ size_t offset = 0U;
+ output->Allocate(input.len());
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ RecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(&parser, &record)) {
+ return KEEP;
+ }
+
+ if (FilterRecord(header, record, &offset, output) != KEEP) {
+ changed = true;
+ } else {
+ offset = header.Write(output, offset, record);
+ }
+ }
+ output->Truncate(offset);
+
+ // Record how many packets we actually touched.
+ if (changed) {
+ ++count_;
+ return (offset == 0) ? DROP : CHANGE;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header,
+ const DataBuffer& record,
+ size_t* offset,
+ DataBuffer* output) {
+ DataBuffer filtered;
+ PacketFilter::Action action = FilterRecord(header, record, &filtered);
+ if (action == KEEP) {
+ return KEEP;
+ }
+
+ if (action == DROP) {
+ std::cerr << "record drop: " << record << std::endl;
+ return DROP;
+ }
+
+ const DataBuffer* source = &record;
+ if (action == CHANGE) {
+ EXPECT_GT(0x10000U, filtered.len());
+ std::cerr << "record old: " << record << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ *offset = header.Write(output, *offset, *source);
+ return CHANGE;
+}
+
+bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+ if (!parser->Read(&content_type_)) {
+ return false;
+ }
+
+ uint32_t version;
+ if (!parser->Read(&version, 2)) {
+ return false;
+ }
+ version_ = version;
+
+ sequence_number_ = 0;
+ if (IsDtls(version)) {
+ uint32_t tmp;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ = static_cast<uint64_t>(tmp) << 32;
+ if (!parser->Read(&tmp, 4)) {
+ return false;
+ }
+ sequence_number_ |= static_cast<uint64_t>(tmp);
+ }
+ return parser->ReadVariable(body, 2);
+}
+
+size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const {
+ offset = buffer->Write(offset, content_type_, 1);
+ offset = buffer->Write(offset, version_, 2);
+ if (is_dtls()) {
+ // write epoch (2 octet), and seqnum (6 octet)
+ offset = buffer->Write(offset, sequence_number_ >> 32, 4);
+ offset = buffer->Write(offset, sequence_number_ & 0xffffffff, 4);
+ }
+ offset = buffer->Write(offset, body.len(), 2);
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+PacketFilter::Action TlsHandshakeFilter::FilterRecord(
+ const RecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Check that the first byte is as requested.
+ if (record_header.content_type() != kTlsHandshakeType) {
+ return KEEP;
+ }
+
+ bool changed = false;
+ size_t offset = 0U;
+ output->Allocate(input.len()); // Preallocate a little.
+
+ TlsParser parser(input);
+ while (parser.remaining()) {
+ HandshakeHeader header;
+ DataBuffer handshake;
+ if (!header.Parse(&parser, record_header, &handshake)) {
+ return KEEP;
+ }
+
+ DataBuffer filtered;
+ PacketFilter::Action action = FilterHandshake(header, handshake, &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "handshake drop: " << handshake << std::endl;
+ continue;
+ }
+
+ const DataBuffer* source = &handshake;
+ if (action == CHANGE) {
+ EXPECT_GT(0x1000000U, filtered.len());
+ changed = true;
+ std::cerr << "handshake old: " << handshake << std::endl;
+ std::cerr << "handshake new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ offset = header.Write(output, offset, *source);
+ }
+ output->Truncate(offset);
+ return changed ? (offset ? CHANGE : DROP) : KEEP;
+}
+
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
+ const RecordHeader& header,
+ uint32_t* length) {
+ if (!parser->Read(length, 3)) {
+ return false; // malformed
+ }
+
+ if (!header.is_dtls()) {
+ return true; // nothing left to do
+ }
+
+ // Read and check DTLS parameters
+ uint32_t message_seq_tmp;
+ if (!parser->Read(&message_seq_tmp, 2)) { // sequence number
+ return false;
+ }
+ message_seq_ = message_seq_tmp;
+
+ uint32_t fragment_offset;
+ if (!parser->Read(&fragment_offset, 3)) {
+ return false;
+ }
+
+ uint32_t fragment_length;
+ if (!parser->Read(&fragment_length, 3)) {
+ return false;
+ }
+
+ // All current tests where we are using this code don't fragment.
+ return (fragment_offset == 0 && fragment_length == *length);
+}
+
+bool TlsHandshakeFilter::HandshakeHeader::Parse(
+ TlsParser* parser, const RecordHeader& record_header, DataBuffer* body) {
+ version_ = record_header.version();
+ if (!parser->Read(&handshake_type_)) {
+ return false; // malformed
+ }
+ uint32_t length;
+ if (!ReadLength(parser, record_header, &length)) {
+ return false;
+ }
+
+ return parser->Read(body, length);
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
+ if (is_dtls()) {
+ offset = buffer->Write(offset, message_seq_, 2);
+ offset = buffer->Write(offset, 0U, 3); // fragment_offset
+ offset = buffer->Write(offset, body.len(), 3);
+ }
+ offset = buffer->Write(offset, body);
+ return offset;
+}
+
+PacketFilter::Action TlsInspectorRecordHandshakeMessage::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ // Only do this once.
+ if (buffer_.len()) {
+ return KEEP;
+ }
+
+ if (header.handshake_type() == handshake_type_) {
+ buffer_ = input;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.handshake_type() == handshake_type_) {
+ *output = buffer_;
+ return CHANGE;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsConversationRecorder::FilterRecord(
+ const RecordHeader& header, const DataBuffer& input, DataBuffer* output) {
+ buffer_.Append(input);
+ return KEEP;
+}
+
+PacketFilter::Action TlsAlertRecorder::FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (level_ == kTlsAlertFatal) { // already fatal
+ return KEEP;
+ }
+ if (header.content_type() != kTlsAlertType) {
+ return KEEP;
+ }
+
+ std::cerr << "Alert: " << input << std::endl;
+
+ TlsParser parser(input);
+ uint8_t lvl;
+ if (!parser.Read(&lvl)) {
+ return KEEP;
+ }
+ if (lvl == kTlsAlertWarning) { // not strong enough
+ return KEEP;
+ }
+ level_ = lvl;
+ (void)parser.Read(&description_);
+ return KEEP;
+}
+
+ChainedPacketFilter::~ChainedPacketFilter() {
+ for (auto it = filters_.begin(); it != filters_.end(); ++it) {
+ delete *it;
+ }
+}
+
+PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ DataBuffer in(input);
+ bool changed = false;
+ for (auto it = filters_.begin(); it != filters_.end(); ++it) {
+ PacketFilter::Action action = (*it)->Filter(in, output);
+ if (action == DROP) {
+ return DROP;
+ }
+ if (action == CHANGE) {
+ in = *output;
+ changed = true;
+ }
+ }
+ return changed ? CHANGE : KEEP;
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ TlsParser parser(input);
+ if (!FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
+ }
+ return FilterExtensions(&parser, input, output);
+ }
+ if (header.handshake_type() == kTlsHandshakeServerHello) {
+ TlsParser parser(input);
+ if (!FindServerHelloExtensions(&parser)) {
+ return KEEP;
+ }
+ return FilterExtensions(&parser, input, output);
+ }
+ return KEEP;
+}
+
+bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser,
+ const Versioned& header) {
+ if (!parser->Skip(2 + 32)) { // version + random
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // session ID
+ return false;
+ }
+ if (header.is_dtls() && !parser->SkipVariable(1)) { // DTLS cookie
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // cipher suites
+ return false;
+ }
+ if (!parser->SkipVariable(1)) { // compression methods
+ return false;
+ }
+ return true;
+}
+
+bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) {
+ uint32_t vtmp;
+ if (!parser->Read(&vtmp, 2)) {
+ return false;
+ }
+ uint16_t version = static_cast<uint16_t>(vtmp);
+ if (!parser->Skip(32)) { // random
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser->SkipVariable(1)) { // session ID
+ return false;
+ }
+ }
+ if (!parser->Skip(2)) { // cipher suite
+ return false;
+ }
+ if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) {
+ if (!parser->Skip(1)) { // compression method
+ return false;
+ }
+ }
+ return true;
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterExtensions(
+ TlsParser* parser, const DataBuffer& input, DataBuffer* output) {
+ size_t length_offset = parser->consumed();
+ uint32_t all_extensions;
+ if (!parser->Read(&all_extensions, 2)) {
+ return KEEP; // no extensions, odd but OK
+ }
+ if (all_extensions != parser->remaining()) {
+ return KEEP; // malformed
+ }
+
+ bool changed = false;
+
+ // Write out the start of the message.
+ output->Allocate(input.len());
+ size_t offset = output->Write(0, input.data(), parser->consumed());
+
+ while (parser->remaining()) {
+ uint32_t extension_type;
+ if (!parser->Read(&extension_type, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer extension;
+ if (!parser->ReadVariable(&extension, 2)) {
+ return KEEP; // malformed
+ }
+
+ DataBuffer filtered;
+ PacketFilter::Action action =
+ FilterExtension(extension_type, extension, &filtered);
+ if (action == DROP) {
+ changed = true;
+ std::cerr << "extension drop: " << extension << std::endl;
+ continue;
+ }
+
+ const DataBuffer* source = &extension;
+ if (action == CHANGE) {
+ EXPECT_GT(0x10000U, filtered.len());
+ changed = true;
+ std::cerr << "extension old: " << extension << std::endl;
+ std::cerr << "extension new: " << filtered << std::endl;
+ source = &filtered;
+ }
+
+ // Write out extension.
+ offset = output->Write(offset, extension_type, 2);
+ offset = output->Write(offset, source->len(), 2);
+ if (source->len() > 0) {
+ offset = output->Write(offset, *source);
+ }
+ }
+ output->Truncate(offset);
+
+ if (changed) {
+ size_t newlen = output->len() - length_offset - 2;
+ EXPECT_GT(0x10000U, newlen);
+ if (newlen >= 0x10000) {
+ return KEEP; // bad: size increased too much
+ }
+ output->Write(length_offset, newlen, 2);
+ return CHANGE;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionCapture::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type == extension_ && (last_ || !captured_)) {
+ data_.Assign(input);
+ captured_ = true;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action TlsExtensionReplacer::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = data_;
+ return CHANGE;
+}
+
+PacketFilter::Action TlsExtensionDropper::FilterExtension(
+ uint16_t extension_type, const DataBuffer& input, DataBuffer* output) {
+ if (extension_type == extension_) {
+ return DROP;
+ }
+ return KEEP;
+}
+
+PacketFilter::Action AfterRecordN::FilterRecord(const RecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) {
+ if (counter_++ == record_) {
+ DataBuffer buf;
+ header.Write(&buf, 0, body);
+ src_->SendDirect(buf);
+ dest_->Handshake();
+ func_();
+ return DROP;
+ }
+
+ return KEEP;
+}
+
+PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
+ EXPECT_EQ(SECSuccess,
+ SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd()));
+ }
+ return KEEP;
+}
+
+PacketFilter::Action SelectiveDropFilter::Filter(const DataBuffer& input,
+ DataBuffer* output) {
+ if (counter_ >= 32) {
+ return KEEP;
+ }
+ return ((1 << counter_++) & pattern_) ? DROP : KEEP;
+}
+
+PacketFilter::Action TlsInspectorClientHelloVersionSetter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ *output = input;
+ output->Write(0, version_, 2);
+ return CHANGE;
+ }
+ return KEEP;
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_filter.h b/nss/gtests/ssl_gtest/tls_filter.h
new file mode 100644
index 0000000..fa2e387
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_filter.h
@@ -0,0 +1,343 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_filter_h_
+#define tls_filter_h_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "test_io.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// Abstract filter that operates on entire (D)TLS records.
+class TlsRecordFilter : public PacketFilter {
+ public:
+ TlsRecordFilter() : count_(0) {}
+
+ // External interface. Overrides PacketFilter.
+ PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
+
+ // Report how many packets were altered by the filter.
+ size_t filtered_packets() const { return count_; }
+
+ class Versioned {
+ public:
+ Versioned() : version_(0) {}
+ explicit Versioned(uint16_t version) : version_(version) {}
+
+ bool is_dtls() const { return IsDtls(version_); }
+ uint16_t version() const { return version_; }
+
+ protected:
+ uint16_t version_;
+ };
+
+ class RecordHeader : public Versioned {
+ public:
+ RecordHeader() : Versioned(), content_type_(0), sequence_number_(0) {}
+ RecordHeader(uint16_t version, uint8_t content_type,
+ uint64_t sequence_number)
+ : Versioned(version),
+ content_type_(content_type),
+ sequence_number_(sequence_number) {}
+
+ uint8_t content_type() const { return content_type_; }
+ uint64_t sequence_number() const { return sequence_number_; }
+ size_t header_length() const { return is_dtls() ? 11 : 3; }
+
+ // Parse the header; return true if successful; body in an outparam if OK.
+ bool Parse(TlsParser* parser, DataBuffer* body);
+ // Write the header and body to a buffer at the given offset.
+ // Return the offset of the end of the write.
+ size_t Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const;
+
+ private:
+ uint8_t content_type_;
+ uint64_t sequence_number_;
+ };
+
+ protected:
+ // There are two filter functions which can be overriden. Both are
+ // called with the header and the record but the outer one is called
+ // with a raw pointer to let you write into the buffer and lets you
+ // do anything with this section of the stream. The inner one
+ // just lets you change the record contents. By default, the
+ // outer one calls the inner one, so if you override the outer
+ // one, the inner one is never called unless you call it yourself.
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& record,
+ size_t* offset, DataBuffer* output);
+
+ // The record filter receives the record contentType, version and DTLS
+ // sequence number (which is zero for TLS), plus the existing record payload.
+ // It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
+ // outparam with the new record contents if it chooses to CHANGE the record.
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& data,
+ DataBuffer* changed) {
+ return KEEP;
+ }
+
+ private:
+ size_t count_;
+};
+
+// Abstract filter that operates on handshake messages rather than records.
+// This assumes that the handshake messages are written in a block as entire
+// records and that they don't span records or anything crazy like that.
+class TlsHandshakeFilter : public TlsRecordFilter {
+ public:
+ TlsHandshakeFilter() {}
+
+ class HandshakeHeader : public Versioned {
+ public:
+ HandshakeHeader() : Versioned(), handshake_type_(0), message_seq_(0) {}
+
+ uint8_t handshake_type() const { return handshake_type_; }
+ bool Parse(TlsParser* parser, const RecordHeader& record_header,
+ DataBuffer* body);
+ size_t Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const;
+
+ private:
+ // Reads the length from the record header.
+ // This also reads the DTLS fragment information and checks it.
+ bool ReadLength(TlsParser* parser, const RecordHeader& header,
+ uint32_t* length);
+
+ uint8_t handshake_type_;
+ uint16_t message_seq_;
+ // fragment_offset is always zero in these tests.
+ };
+
+ protected:
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
+
+ private:
+};
+
+// Make a copy of the first instance of a handshake message.
+class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter {
+ public:
+ TlsInspectorRecordHandshakeMessage(uint8_t handshake_type)
+ : handshake_type_(handshake_type), buffer_() {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ const DataBuffer& buffer() const { return buffer_; }
+
+ private:
+ uint8_t handshake_type_;
+ DataBuffer buffer_;
+};
+
+// Replace all instances of a handshake message.
+class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter {
+ public:
+ TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type,
+ const DataBuffer& replacement)
+ : handshake_type_(handshake_type), buffer_(replacement) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ uint8_t handshake_type_;
+ DataBuffer buffer_;
+};
+
+// Make a copy of the complete conversation.
+class TlsConversationRecorder : public TlsRecordFilter {
+ public:
+ TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
+
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ DataBuffer& buffer_;
+};
+
+// Records an alert. If an alert has already been recorded, it won't save the
+// new alert unless the old alert is a warning and the new one is fatal.
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+ TlsAlertRecorder() : level_(255), description_(255) {}
+
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ uint8_t level() const { return level_; }
+ uint8_t description() const { return description_; }
+
+ private:
+ uint8_t level_;
+ uint8_t description_;
+};
+
+// Runs multiple packet filters in series.
+class ChainedPacketFilter : public PacketFilter {
+ public:
+ ChainedPacketFilter() {}
+ ChainedPacketFilter(const std::vector<PacketFilter*> filters)
+ : filters_(filters.begin(), filters.end()) {}
+ virtual ~ChainedPacketFilter();
+
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output);
+
+ // Takes ownership of the filter.
+ void Add(PacketFilter* filter) { filters_.push_back(filter); }
+
+ private:
+ std::vector<PacketFilter*> filters_;
+};
+
+class TlsExtensionFilter : public TlsHandshakeFilter {
+ protected:
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) = 0;
+
+ public:
+ static bool FindClientHelloExtensions(TlsParser* parser,
+ const Versioned& header);
+ static bool FindServerHelloExtensions(TlsParser* parser);
+
+ private:
+ PacketFilter::Action FilterExtensions(TlsParser* parser,
+ const DataBuffer& input,
+ DataBuffer* output);
+};
+
+class TlsExtensionCapture : public TlsExtensionFilter {
+ public:
+ TlsExtensionCapture(uint16_t ext, bool last = false)
+ : extension_(ext), captured_(false), last_(last), data_() {}
+
+ const DataBuffer& extension() const { return data_; }
+ bool captured() const { return captured_; }
+
+ protected:
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ bool captured_;
+ bool last_;
+ DataBuffer data_;
+};
+
+class TlsExtensionReplacer : public TlsExtensionFilter {
+ public:
+ TlsExtensionReplacer(uint16_t extension, const DataBuffer& data)
+ : extension_(extension), data_(data) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionDropper : public TlsExtensionFilter {
+ public:
+ TlsExtensionDropper(uint16_t extension) : extension_(extension) {}
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer&, DataBuffer*) override;
+
+ private:
+ uint16_t extension_;
+};
+
+class TlsAgent;
+typedef std::function<void(void)> VoidFunction;
+
+class AfterRecordN : public TlsRecordFilter {
+ public:
+ AfterRecordN(TlsAgent* src, TlsAgent* dest, unsigned int record,
+ VoidFunction func)
+ : src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
+
+ virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ const DataBuffer& body,
+ DataBuffer* out) override;
+
+ private:
+ TlsAgent* src_;
+ TlsAgent* dest_;
+ unsigned int record_;
+ VoidFunction func_;
+ unsigned int counter_;
+};
+
+// When we see the ClientKeyExchange from |client|, increment the
+// ClientHelloVersion on |server|.
+class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
+ public:
+ TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ TlsAgent* server_;
+};
+
+// This class selectively drops complete writes. This relies on the fact that
+// writes in libssl are on record boundaries.
+class SelectiveDropFilter : public PacketFilter {
+ public:
+ SelectiveDropFilter(uint32_t pattern) : pattern_(pattern), counter_(0) {}
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override;
+
+ private:
+ const uint32_t pattern_;
+ uint8_t counter_;
+};
+
+// Set the version number in the ClientHello.
+class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
+ public:
+ TlsInspectorClientHelloVersionSetter(uint16_t version) : version_(version) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output);
+
+ private:
+ uint16_t version_;
+};
+
+} // namespace nss_test
+
+#endif
diff --git a/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc b/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
new file mode 100644
index 0000000..51ff938
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_hkdf_unittest.cc
@@ -0,0 +1,262 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include <memory>
+#include "nss.h"
+#include "pk11pub.h"
+#include "tls13hkdf.h"
+
+#include "databuffer.h"
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+
+namespace nss_test {
+
+const uint8_t kKey1Data[] = {
+ 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
+ 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23,
+ 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f};
+const DataBuffer kKey1(kKey1Data, sizeof(kKey1Data));
+
+// The same as key1 but with the first byte
+// 0x01.
+const uint8_t kKey2Data[] = {
+ 0x01, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
+ 0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17,
+ 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23,
+ 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f};
+const DataBuffer kKey2(kKey2Data, sizeof(kKey2Data));
+
+const char kLabelMasterSecret[] = "master secret";
+
+const uint8_t kSessionHash[] = {
+ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb,
+ 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7,
+ 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef, 0xd0, 0xd1, 0xd2, 0xd3,
+ 0xd4, 0xd5, 0xd6, 0xd7, 0xd8, 0xd9, 0xda, 0xdb, 0xdc, 0xdd, 0xde, 0xdf,
+ 0xe0, 0xe1, 0xe2, 0xe3, 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb,
+ 0xec, 0xed, 0xee, 0xef, 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7,
+ 0xf8, 0xf9, 0xfa, 0xfb, 0xfc, 0xfd, 0xfe, 0xff, 0xe0, 0xe1, 0xe2, 0xe3,
+ 0xe4, 0xe5, 0xe6, 0xe7, 0xe8, 0xe9, 0xea, 0xeb, 0xec, 0xed, 0xee, 0xef,
+ 0xf0, 0xf1, 0xf2, 0xf3, 0xf4, 0xf5, 0xf6, 0xf7, 0xf8, 0xf9, 0xfa, 0xfb,
+ 0xfc, 0xfd, 0xfe, 0xff,
+};
+
+const size_t kHashLength[] = {
+ 0, /* ssl_hash_none */
+ 16, /* ssl_hash_md5 */
+ 20, /* ssl_hash_sha1 */
+ 28, /* ssl_hash_sha224 */
+ 32, /* ssl_hash_sha256 */
+ 48, /* ssl_hash_sha384 */
+ 64, /* ssl_hash_sha512 */
+};
+
+const std::string kHashName[] = {"None", "MD5", "SHA-1", "SHA-224",
+ "SHA-256", "SHA-384", "SHA-512"};
+
+static void ImportKey(ScopedPK11SymKey* to, const DataBuffer& key,
+ PK11SlotInfo* slot) {
+ SECItem key_item = {siBuffer, const_cast<uint8_t*>(key.data()),
+ static_cast<unsigned int>(key.len())};
+
+ PK11SymKey* inner =
+ PK11_ImportSymKey(slot, CKM_SSL3_MASTER_KEY_DERIVE, PK11_OriginUnwrap,
+ CKA_DERIVE, &key_item, NULL);
+ ASSERT_NE(nullptr, inner);
+ to->reset(inner);
+}
+
+static void DumpData(const std::string& label, const uint8_t* buf, size_t len) {
+ DataBuffer d(buf, len);
+
+ std::cerr << label << ": " << d << std::endl;
+}
+
+void DumpKey(const std::string& label, ScopedPK11SymKey& key) {
+ SECStatus rv = PK11_ExtractKeyValue(key.get());
+ ASSERT_EQ(SECSuccess, rv);
+
+ SECItem* key_data = PK11_GetKeyData(key.get());
+ ASSERT_NE(nullptr, key_data);
+
+ DumpData(label, key_data->data, key_data->len);
+}
+
+extern "C" {
+extern char ssl_trace;
+extern FILE* ssl_trace_iob;
+}
+
+class TlsHkdfTest : public ::testing::Test,
+ public ::testing::WithParamInterface<SSLHashType> {
+ public:
+ TlsHkdfTest()
+ : k1_(), k2_(), hash_type_(GetParam()), slot_(PK11_GetInternalSlot()) {
+ EXPECT_NE(nullptr, slot_);
+ char* ev = getenv("SSLTRACE");
+ if (ev && ev[0]) {
+ ssl_trace = atoi(ev);
+ ssl_trace_iob = stderr;
+ }
+ }
+
+ void SetUp() {
+ ImportKey(&k1_, kKey1, slot_.get());
+ ImportKey(&k2_, kKey2, slot_.get());
+ }
+
+ void VerifyKey(const ScopedPK11SymKey& key, const DataBuffer& expected) {
+ SECStatus rv = PK11_ExtractKeyValue(key.get());
+ ASSERT_EQ(SECSuccess, rv);
+
+ SECItem* key_data = PK11_GetKeyData(key.get());
+ ASSERT_NE(nullptr, key_data);
+
+ EXPECT_EQ(expected.len(), key_data->len);
+ EXPECT_EQ(0, memcmp(expected.data(), key_data->data, expected.len()));
+ }
+
+ void HkdfExtract(const ScopedPK11SymKey& ikmk1, const ScopedPK11SymKey& ikmk2,
+ SSLHashType base_hash, const DataBuffer& expected) {
+ std::cerr << "Hash = " << kHashName[base_hash] << std::endl;
+
+ PK11SymKey* prk = nullptr;
+ SECStatus rv = tls13_HkdfExtract(ikmk1.get(), ikmk2.get(), base_hash, &prk);
+ ASSERT_EQ(SECSuccess, rv);
+ ScopedPK11SymKey prkk(prk);
+
+ DumpKey("Output", prkk);
+ VerifyKey(prkk, expected);
+ }
+
+ void HkdfExpandLabel(ScopedPK11SymKey* prk, SSLHashType base_hash,
+ const uint8_t* session_hash, size_t session_hash_len,
+ const char* label, size_t label_len,
+ const DataBuffer& expected) {
+ std::cerr << "Hash = " << kHashName[base_hash] << std::endl;
+
+ std::vector<uint8_t> output(expected.len());
+
+ SECStatus rv = tls13_HkdfExpandLabelRaw(prk->get(), base_hash, session_hash,
+ session_hash_len, label, label_len,
+ &output[0], output.size());
+ ASSERT_EQ(SECSuccess, rv);
+ DumpData("Output", &output[0], output.size());
+ EXPECT_EQ(0, memcmp(expected.data(), &output[0], expected.len()));
+ }
+
+ protected:
+ ScopedPK11SymKey k1_;
+ ScopedPK11SymKey k2_;
+ SSLHashType hash_type_;
+
+ private:
+ ScopedPK11SlotInfo slot_;
+};
+
+TEST_P(TlsHkdfTest, HkdfNullNull) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x33, 0xad, 0x0a, 0x1c, 0x60, 0x7e, 0xc0, 0x3b, 0x09, 0xe6, 0xcd,
+ 0x98, 0x93, 0x68, 0x0c, 0xe2, 0x10, 0xad, 0xf3, 0x00, 0xaa, 0x1f,
+ 0x26, 0x60, 0xe1, 0xb2, 0x2e, 0x10, 0xf1, 0x70, 0xf9, 0x2a},
+ {0x7e, 0xe8, 0x20, 0x6f, 0x55, 0x70, 0x02, 0x3e, 0x6d, 0xc7, 0x51, 0x9e,
+ 0xb1, 0x07, 0x3b, 0xc4, 0xe7, 0x91, 0xad, 0x37, 0xb5, 0xc3, 0x82, 0xaa,
+ 0x10, 0xba, 0x18, 0xe2, 0x35, 0x7e, 0x71, 0x69, 0x71, 0xf9, 0x36, 0x2f,
+ 0x2c, 0x2f, 0xe2, 0xa7, 0x6b, 0xfd, 0x78, 0xdf, 0xec, 0x4e, 0xa9, 0xb5}};
+
+ const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
+ HkdfExtract(nullptr, nullptr, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfKey1Only) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x11, 0x87, 0x38, 0x28, 0xa9, 0x19, 0x78, 0x11, 0x33, 0x91, 0x24,
+ 0xb5, 0x8a, 0x1b, 0xb0, 0x9f, 0x7f, 0x0d, 0x8d, 0xbb, 0x10, 0xf4,
+ 0x9c, 0x54, 0xbd, 0x1f, 0xd8, 0x85, 0xcd, 0x15, 0x30, 0x33},
+ {0x51, 0xb1, 0xd5, 0xb4, 0x59, 0x79, 0x79, 0x08, 0x4a, 0x15, 0xb2, 0xdb,
+ 0x84, 0xd3, 0xd6, 0xbc, 0xfc, 0x93, 0x45, 0xd9, 0xdc, 0x74, 0xda, 0x1a,
+ 0x57, 0xc2, 0x76, 0x9f, 0x3f, 0x83, 0x45, 0x2f, 0xf6, 0xf3, 0x56, 0x1f,
+ 0x58, 0x63, 0xdb, 0x88, 0xda, 0x40, 0xce, 0x63, 0x7d, 0x24, 0x37, 0xf3}};
+
+ const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
+ HkdfExtract(k1_, nullptr, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfKey2Only) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {
+ 0x2f, 0x5f, 0x78, 0xd0, 0xa4, 0xc4, 0x36, 0xee, 0x6c, 0x8a, 0x4e,
+ 0xf9, 0xd0, 0x43, 0x81, 0x02, 0x13, 0xfd, 0x47, 0x83, 0x63, 0x3a,
+ 0xd2, 0xe1, 0x40, 0x6d, 0x2d, 0x98, 0x00, 0xfd, 0xc1, 0x87,
+ },
+ {0x7b, 0x40, 0xf9, 0xef, 0x91, 0xff, 0xc9, 0xd1, 0x29, 0x24, 0x5c, 0xbf,
+ 0xf8, 0x82, 0x76, 0x68, 0xae, 0x4b, 0x63, 0xe8, 0x03, 0xdd, 0x39, 0xa8,
+ 0xd4, 0x6a, 0xf6, 0xe5, 0xec, 0xea, 0xf8, 0x7d, 0x91, 0x71, 0x81, 0xf1,
+ 0xdb, 0x3b, 0xaf, 0xbf, 0xde, 0x71, 0x61, 0x15, 0xeb, 0xb5, 0x5f, 0x68}};
+
+ const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
+ HkdfExtract(nullptr, k2_, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfKey1Key2) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {
+ 0x79, 0x53, 0xb8, 0xdd, 0x6b, 0x98, 0xce, 0x00, 0xb7, 0xdc, 0xe8,
+ 0x03, 0x70, 0x8c, 0xe3, 0xac, 0x06, 0x8b, 0x22, 0xfd, 0x0e, 0x34,
+ 0x48, 0xe6, 0xe5, 0xe0, 0x8a, 0xd6, 0x16, 0x18, 0xe5, 0x48,
+ },
+ {0x01, 0x93, 0xc0, 0x07, 0x3f, 0x6a, 0x83, 0x0e, 0x2e, 0x4f, 0xb2, 0x58,
+ 0xe4, 0x00, 0x08, 0x5c, 0x68, 0x9c, 0x37, 0x32, 0x00, 0x37, 0xff, 0xc3,
+ 0x1c, 0x5b, 0x98, 0x0b, 0x02, 0x92, 0x3f, 0xfd, 0x73, 0x5a, 0x6f, 0x2a,
+ 0x95, 0xa3, 0xee, 0xf6, 0xd6, 0x8e, 0x6f, 0x86, 0xea, 0x63, 0xf8, 0x33}};
+
+ const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
+ HkdfExtract(k1_, k2_, hash_type_, expected_data);
+}
+
+TEST_P(TlsHkdfTest, HkdfExpandLabel) {
+ const uint8_t tv[][48] = {
+ {/* ssl_hash_none */},
+ {/* ssl_hash_md5 */},
+ {/* ssl_hash_sha1 */},
+ {/* ssl_hash_sha224 */},
+ {0x34, 0x7c, 0x67, 0x80, 0xff, 0x0b, 0xba, 0xd7, 0x1c, 0x28, 0x3b,
+ 0x16, 0xeb, 0x2f, 0x9c, 0xf6, 0x2d, 0x24, 0xe6, 0xcd, 0xb6, 0x13,
+ 0xd5, 0x17, 0x76, 0x54, 0x8c, 0xb0, 0x7d, 0xcd, 0xe7, 0x4c},
+ {0x4b, 0x1e, 0x5e, 0xc1, 0x49, 0x30, 0x78, 0xea, 0x35, 0xbd, 0x3f, 0x01,
+ 0x04, 0xe6, 0x1a, 0xea, 0x14, 0xcc, 0x18, 0x2a, 0xd1, 0xc4, 0x76, 0x21,
+ 0xc4, 0x64, 0xc0, 0x4e, 0x4b, 0x36, 0x16, 0x05, 0x6f, 0x04, 0xab, 0xe9,
+ 0x43, 0xb1, 0x2d, 0xa8, 0xa7, 0x17, 0x9a, 0x5f, 0x09, 0x91, 0x7d, 0x1f}};
+
+ const DataBuffer expected_data(tv[hash_type_], kHashLength[hash_type_]);
+ HkdfExpandLabel(&k1_, hash_type_, kSessionHash, kHashLength[hash_type_],
+ kLabelMasterSecret, strlen(kLabelMasterSecret),
+ expected_data);
+}
+
+static const SSLHashType kHashTypes[] = {ssl_hash_sha256, ssl_hash_sha384};
+INSTANTIATE_TEST_CASE_P(AllHashFuncs, TlsHkdfTest,
+ ::testing::ValuesIn(kHashTypes));
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_parser.cc b/nss/gtests/ssl_gtest/tls_parser.cc
new file mode 100644
index 0000000..e4c06aa
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_parser.cc
@@ -0,0 +1,73 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_parser.h"
+
+namespace nss_test {
+
+bool TlsParser::Read(uint8_t* val) {
+ if (remaining() < 1) {
+ return false;
+ }
+ *val = *ptr();
+ consume(1);
+ return true;
+}
+
+bool TlsParser::Read(uint32_t* val, size_t size) {
+ if (size > sizeof(uint32_t)) {
+ return false;
+ }
+
+ uint32_t v = 0;
+ for (size_t i = 0; i < size; ++i) {
+ uint8_t tmp;
+ if (!Read(&tmp)) {
+ return false;
+ }
+
+ v = (v << 8) | tmp;
+ }
+
+ *val = v;
+ return true;
+}
+
+bool TlsParser::Read(DataBuffer* val, size_t len) {
+ if (remaining() < len) {
+ return false;
+ }
+
+ val->Assign(ptr(), len);
+ consume(len);
+ return true;
+}
+
+bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) {
+ uint32_t len;
+ if (!Read(&len, len_size)) {
+ return false;
+ }
+ return Read(val, len);
+}
+
+bool TlsParser::Skip(size_t len) {
+ if (len > remaining()) {
+ return false;
+ }
+ consume(len);
+ return true;
+}
+
+bool TlsParser::SkipVariable(size_t len_size) {
+ uint32_t len;
+ if (!Read(&len, len_size)) {
+ return false;
+ }
+ return Skip(len);
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_parser.h b/nss/gtests/ssl_gtest/tls_parser.h
new file mode 100644
index 0000000..c79d45a
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_parser.h
@@ -0,0 +1,131 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_parser_h_
+#define tls_parser_h_
+
+#include <cstdint>
+#include <cstring>
+#include <memory>
+#if defined(WIN32) || defined(WIN64)
+#include <winsock2.h>
+#else
+#include <arpa/inet.h>
+#endif
+#include "databuffer.h"
+
+namespace nss_test {
+
+const uint8_t kTlsChangeCipherSpecType = 20;
+const uint8_t kTlsAlertType = 21;
+const uint8_t kTlsHandshakeType = 22;
+const uint8_t kTlsApplicationDataType = 23;
+
+const uint8_t kTlsHandshakeClientHello = 1;
+const uint8_t kTlsHandshakeServerHello = 2;
+const uint8_t kTlsHandshakeHelloRetryRequest = 6;
+const uint8_t kTlsHandshakeEncryptedExtensions = 8;
+const uint8_t kTlsHandshakeCertificate = 11;
+const uint8_t kTlsHandshakeServerKeyExchange = 12;
+const uint8_t kTlsHandshakeCertificateVerify = 15;
+const uint8_t kTlsHandshakeClientKeyExchange = 16;
+const uint8_t kTlsHandshakeFinished = 20;
+
+const uint8_t kTlsAlertWarning = 1;
+const uint8_t kTlsAlertFatal = 2;
+
+const uint8_t kTlsAlertUnexpectedMessage = 10;
+const uint8_t kTlsAlertBadRecordMac = 20;
+const uint8_t kTlsAlertHandshakeFailure = 40;
+const uint8_t kTlsAlertIllegalParameter = 47;
+const uint8_t kTlsAlertDecodeError = 50;
+const uint8_t kTlsAlertDecryptError = 51;
+const uint8_t kTlsAlertMissingExtension = 109;
+const uint8_t kTlsAlertUnsupportedExtension = 110;
+const uint8_t kTlsAlertUnrecognizedName = 112;
+const uint8_t kTlsAlertNoApplicationProtocol = 120;
+
+const uint8_t kTlsFakeChangeCipherSpec[] = {
+ kTlsChangeCipherSpecType, // Type
+ 0xfe,
+ 0xff, // Version
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x00,
+ 0x10, // Fictitious sequence #
+ 0x00,
+ 0x01, // Length
+ 0x01 // Value
+};
+
+static const uint8_t kTls13PskKe = 0;
+static const uint8_t kTls13PskDhKe = 1;
+static const uint8_t kTls13PskAuth = 0;
+static const uint8_t kTls13PskSignAuth = 1;
+
+inline bool IsDtls(uint16_t version) { return (version & 0x8000) == 0x8000; }
+
+inline uint16_t NormalizeTlsVersion(uint16_t version) {
+ if (version == 0xfeff) {
+ return 0x0302; // special: DTLS 1.0 == TLS 1.1
+ }
+ if (IsDtls(version)) {
+ return (version ^ 0xffff) + 0x0201;
+ }
+ return version;
+}
+
+inline uint16_t TlsVersionToDtlsVersion(uint16_t version) {
+ if (version == 0x0302) {
+ return 0xfeff;
+ }
+ if (version == 0x0304) {
+ return version;
+ }
+ return 0xffff - version + 0x0201;
+}
+
+inline size_t WriteVariable(DataBuffer* target, size_t index,
+ const DataBuffer& buf, size_t len_size) {
+ index = target->Write(index, static_cast<uint32_t>(buf.len()), len_size);
+ return target->Write(index, buf.data(), buf.len());
+}
+
+class TlsParser {
+ public:
+ TlsParser(const uint8_t* data, size_t len) : buffer_(data, len), offset_(0) {}
+ explicit TlsParser(const DataBuffer& buf) : buffer_(buf), offset_(0) {}
+
+ bool Read(uint8_t* val);
+ // Read an integral type of specified width.
+ bool Read(uint32_t* val, size_t size);
+ // Reads len bytes into dest buffer, overwriting it.
+ bool Read(DataBuffer* dest, size_t len);
+ // Reads bytes into dest buffer, overwriting it. The number of bytes is
+ // determined by reading from len_size bytes from the stream first.
+ bool ReadVariable(DataBuffer* dest, size_t len_size);
+
+ bool Skip(size_t len);
+ bool SkipVariable(size_t len_size);
+
+ size_t consumed() const { return offset_; }
+ size_t remaining() const { return buffer_.len() - offset_; }
+
+ private:
+ void consume(size_t len) { offset_ += len; }
+ const uint8_t* ptr() const { return buffer_.data() + offset_; }
+
+ DataBuffer buffer_;
+ size_t offset_;
+};
+
+} // namespace nss_test
+
+#endif