summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest')
-rw-r--r--nss/gtests/ssl_gtest/Makefile5
-rw-r--r--nss/gtests/ssl_gtest/databuffer.h191
-rw-r--r--nss/gtests/ssl_gtest/gtest_utils.h2
-rw-r--r--nss/gtests/ssl_gtest/libssl_internals.c99
-rw-r--r--nss/gtests/ssl_gtest/libssl_internals.h13
-rw-r--r--nss/gtests/ssl_gtest/manifest.mn14
-rw-r--r--nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc199
-rw-r--r--nss/gtests/ssl_gtest/ssl_agent_unittest.cc28
-rw-r--r--nss/gtests/ssl_gtest/ssl_auth_unittest.cc195
-rw-r--r--nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc83
-rw-r--r--nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc45
-rw-r--r--nss/gtests/ssl_gtest/ssl_damage_unittest.cc64
-rw-r--r--nss/gtests/ssl_gtest/ssl_dhe_unittest.cc100
-rw-r--r--nss/gtests/ssl_gtest/ssl_drop_unittest.cc16
-rw-r--r--nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc74
-rw-r--r--nss/gtests/ssl_gtest/ssl_ems_unittest.cc6
-rw-r--r--nss/gtests/ssl_gtest/ssl_exporter_unittest.cc34
-rw-r--r--nss/gtests/ssl_gtest/ssl_extension_unittest.cc568
-rw-r--r--nss/gtests/ssl_gtest/ssl_fragment_unittest.cc157
-rw-r--r--nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc192
-rw-r--r--nss/gtests/ssl_gtest/ssl_gather_unittest.cc143
-rw-r--r--nss/gtests/ssl_gtest/ssl_gtest.cc14
-rw-r--r--nss/gtests/ssl_gtest/ssl_gtest.gyp31
-rw-r--r--nss/gtests/ssl_gtest/ssl_hrr_unittest.cc117
-rw-r--r--nss/gtests/ssl_gtest/ssl_loopback_unittest.cc163
-rw-r--r--nss/gtests/ssl_gtest/ssl_resumption_unittest.cc161
-rw-r--r--nss/gtests/ssl_gtest/ssl_skip_unittest.cc143
-rw-r--r--nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc50
-rw-r--r--nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc58
-rw-r--r--nss/gtests/ssl_gtest/ssl_version_unittest.cc59
-rw-r--r--nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc394
-rw-r--r--nss/gtests/ssl_gtest/test_io.cc386
-rw-r--r--nss/gtests/ssl_gtest/test_io.h97
-rw-r--r--nss/gtests/ssl_gtest/tls_agent.cc371
-rw-r--r--nss/gtests/ssl_gtest/tls_agent.h135
-rw-r--r--nss/gtests/ssl_gtest/tls_connect.cc175
-rw-r--r--nss/gtests/ssl_gtest/tls_connect.h86
-rw-r--r--nss/gtests/ssl_gtest/tls_filter.cc326
-rw-r--r--nss/gtests/ssl_gtest/tls_filter.h231
-rw-r--r--nss/gtests/ssl_gtest/tls_parser.cc73
-rw-r--r--nss/gtests/ssl_gtest/tls_parser.h131
-rw-r--r--nss/gtests/ssl_gtest/tls_protect.cc145
-rw-r--r--nss/gtests/ssl_gtest/tls_protect.h76
43 files changed, 3815 insertions, 1835 deletions
diff --git a/nss/gtests/ssl_gtest/Makefile b/nss/gtests/ssl_gtest/Makefile
index dfb8df9..a9a9290 100644
--- a/nss/gtests/ssl_gtest/Makefile
+++ b/nss/gtests/ssl_gtest/Makefile
@@ -33,11 +33,8 @@ ifdef NSS_SSL_ENABLE_ZLIB
include $(CORE_DEPTH)/coreconf/zlib.mk
endif
-ifndef NSS_ENABLE_TLS_1_3
-NSS_DISABLE_TLS_1_3=1
-endif
-
ifdef NSS_DISABLE_TLS_1_3
+NSS_DISABLE_TLS_1_3=1
# Run parameterized tests only, for which we can easily exclude TLS 1.3
CPPSRCS := $(filter-out $(shell grep -l '^TEST_F' $(CPPSRCS)), $(CPPSRCS))
CFLAGS += -DNSS_DISABLE_TLS_1_3
diff --git a/nss/gtests/ssl_gtest/databuffer.h b/nss/gtests/ssl_gtest/databuffer.h
deleted file mode 100644
index e7236d4..0000000
--- a/nss/gtests/ssl_gtest/databuffer.h
+++ /dev/null
@@ -1,191 +0,0 @@
-/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
-/* vim: set ts=2 et sw=2 tw=80: */
-/* This Source Code Form is subject to the terms of the Mozilla Public
- * License, v. 2.0. If a copy of the MPL was not distributed with this file,
- * You can obtain one at http://mozilla.org/MPL/2.0/. */
-
-#ifndef databuffer_h__
-#define databuffer_h__
-
-#include <algorithm>
-#include <cassert>
-#include <cstring>
-#include <iomanip>
-#include <iostream>
-#if defined(WIN32) || defined(WIN64)
-#include <winsock2.h>
-#else
-#include <arpa/inet.h>
-#endif
-
-extern bool g_ssl_gtest_verbose;
-
-namespace nss_test {
-
-class DataBuffer {
- public:
- DataBuffer() : data_(nullptr), len_(0) {}
- DataBuffer(const uint8_t* data, size_t len) : data_(nullptr), len_(0) {
- Assign(data, len);
- }
- DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) {
- Assign(other);
- }
- ~DataBuffer() { delete[] data_; }
-
- DataBuffer& operator=(const DataBuffer& other) {
- if (&other != this) {
- Assign(other);
- }
- return *this;
- }
-
- void Allocate(size_t len) {
- delete[] data_;
- data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0].
- len_ = len;
- }
-
- void Truncate(size_t len) { len_ = std::min(len_, len); }
-
- void Assign(const DataBuffer& other) { Assign(other.data(), other.len()); }
-
- void Assign(const uint8_t* data, size_t len) {
- if (data) {
- Allocate(len);
- memcpy(static_cast<void*>(data_), static_cast<const void*>(data), len);
- } else {
- assert(len == 0);
- data_ = nullptr;
- len_ = 0;
- }
- }
-
- // Write will do a new allocation and expand the size of the buffer if needed.
- // Returns the offset of the end of the write.
- size_t Write(size_t index, const uint8_t* val, size_t count) {
- assert(val);
- if (index + count > len_) {
- size_t newlen = index + count;
- uint8_t* tmp = new uint8_t[newlen]; // Always > 0.
- if (data_) {
- memcpy(static_cast<void*>(tmp), static_cast<const void*>(data_), len_);
- }
- if (index > len_) {
- memset(static_cast<void*>(tmp + len_), 0, index - len_);
- }
- delete[] data_;
- data_ = tmp;
- len_ = newlen;
- }
- if (data_) {
- memcpy(static_cast<void*>(data_ + index), static_cast<const void*>(val),
- count);
- }
- return index + count;
- }
-
- size_t Write(size_t index, const DataBuffer& buf) {
- return Write(index, buf.data(), buf.len());
- }
-
- // Write an integer, also performing host-to-network order conversion.
- // Returns the offset of the end of the write.
- size_t Write(size_t index, uint32_t val, size_t count) {
- assert(count <= sizeof(uint32_t));
- uint32_t nvalue = htonl(val);
- auto* addr = reinterpret_cast<const uint8_t*>(&nvalue);
- return Write(index, addr + sizeof(uint32_t) - count, count);
- }
-
- // This can't use the same trick as Write(), since we might be reading from a
- // smaller data source.
- bool Read(size_t index, size_t count, uint32_t* val) const {
- assert(count < sizeof(uint32_t));
- assert(val);
- if ((index > len()) || (count > (len() - index))) {
- return false;
- }
- *val = 0;
- for (size_t i = 0; i < count; ++i) {
- *val = (*val << 8) | data()[index + i];
- }
- return true;
- }
-
- // Starting at |index|, remove |remove| bytes and replace them with the
- // contents of |buf|.
- void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) {
- Splice(buf.data(), buf.len(), index, remove);
- }
-
- void Splice(const uint8_t* ins, size_t ins_len, size_t index,
- size_t remove = 0) {
- assert(ins);
- uint8_t* old_value = data_;
- size_t old_len = len_;
-
- // The amount of stuff remaining from the tail of the old.
- size_t tail_len = old_len - std::min(old_len, index + remove);
- // The new length: the head of the old, the new, and the tail of the old.
- len_ = index + ins_len + tail_len;
- data_ = new uint8_t[len_ ? len_ : 1];
-
- // The head of the old.
- if (old_value) {
- Write(0, old_value, std::min(old_len, index));
- }
- // Maybe a gap.
- if (old_value && index > old_len) {
- memset(old_value + index, 0, index - old_len);
- }
- // The new.
- Write(index, ins, ins_len);
- // The tail of the old.
- if (tail_len > 0) {
- Write(index + ins_len, old_value + index + remove, tail_len);
- }
-
- delete[] old_value;
- }
-
- void Append(const DataBuffer& buf) { Splice(buf, len_); }
-
- const uint8_t* data() const { return data_; }
- uint8_t* data() { return data_; }
- size_t len() const { return len_; }
- bool empty() const { return len_ == 0; }
-
- private:
- uint8_t* data_;
- size_t len_;
-};
-
-static const size_t kMaxBufferPrint = 32;
-
-inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) {
- stream << "[" << buf.len() << "] ";
- for (size_t i = 0; i < buf.len(); ++i) {
- if (!g_ssl_gtest_verbose && i >= kMaxBufferPrint) {
- stream << "...";
- break;
- }
- stream << std::hex << std::setfill('0') << std::setw(2)
- << static_cast<unsigned>(buf.data()[i]);
- }
- stream << std::dec;
- return stream;
-}
-
-inline bool operator==(const DataBuffer& a, const DataBuffer& b) {
- return (a.empty() && b.empty()) ||
- (a.len() == b.len() && 0 == memcmp(a.data(), b.data(), a.len()));
-}
-
-inline bool operator!=(const DataBuffer& a, const DataBuffer& b) {
- return !(a == b);
-}
-
-} // namespace nss_test
-
-#endif
diff --git a/nss/gtests/ssl_gtest/gtest_utils.h b/nss/gtests/ssl_gtest/gtest_utils.h
index 3ecd96c..2344c3c 100644
--- a/nss/gtests/ssl_gtest/gtest_utils.h
+++ b/nss/gtests/ssl_gtest/gtest_utils.h
@@ -34,7 +34,7 @@ class Timeout : public PollTarget {
bool timed_out() const { return !handle_; }
private:
- Poller::Timer* handle_;
+ std::shared_ptr<Poller::Timer> handle_;
};
} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/libssl_internals.c b/nss/gtests/ssl_gtest/libssl_internals.c
index 5136ee8..32ffcb6 100644
--- a/nss/gtests/ssl_gtest/libssl_internals.c
+++ b/nss/gtests/ssl_gtest/libssl_internals.c
@@ -10,8 +10,6 @@
#include "nss.h"
#include "pk11pub.h"
#include "seccomon.h"
-#include "ssl.h"
-#include "sslimpl.h"
SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) {
sslSocket *ss = ssl_FindSocket(fd);
@@ -35,15 +33,8 @@ SECStatus SSLInt_UpdateSSLv2ClientRandom(PRFileDesc *fd, uint8_t *rnd,
return SECFailure;
}
- SECStatus rv = ssl3_InitState(ss);
- if (rv != SECSuccess) {
- return rv;
- }
-
- rv = ssl3_RestartHandshakeHashes(ss);
- if (rv != SECSuccess) {
- return rv;
- }
+ ssl3_InitState(ss);
+ ssl3_RestartHandshakeHashes(ss);
// Ensure we don't overrun hs.client_random.
rnd_len = PR_MIN(SSL3_RANDOM_LENGTH, rnd_len);
@@ -64,18 +55,15 @@ PRBool SSLInt_ExtensionNegotiated(PRFileDesc *fd, PRUint16 ext) {
return (PRBool)(ss && ssl3_ExtensionNegotiated(ss, ext));
}
-void SSLInt_ClearSessionTicketKey() {
- ssl3_SessionTicketShutdown(NULL, NULL);
- NSS_UnregisterShutdown(ssl3_SessionTicketShutdown, NULL);
-}
+void SSLInt_ClearSessionTicketKey() { ssl_ResetSessionTicketKeys(); }
SECStatus SSLInt_SetMTU(PRFileDesc *fd, PRUint16 mtu) {
sslSocket *ss = ssl_FindSocket(fd);
- if (ss) {
- ss->ssl3.mtu = mtu;
- return SECSuccess;
+ if (!ss) {
+ return SECFailure;
}
- return SECFailure;
+ ss->ssl3.mtu = mtu;
+ return SECSuccess;
}
PRInt32 SSLInt_CountTls13CipherSpecs(PRFileDesc *fd) {
@@ -199,7 +187,9 @@ SECStatus SSLInt_Set0RttAlpn(PRFileDesc *fd, PRUint8 *data, unsigned int len) {
if (ss->xtnData.nextProto.data) {
SECITEM_FreeItem(&ss->xtnData.nextProto, PR_FALSE);
}
- if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) return SECFailure;
+ if (!SECITEM_AllocItem(NULL, &ss->xtnData.nextProto, len)) {
+ return SECFailure;
+ }
PORT_Memcpy(ss->xtnData.nextProto.data, data, len);
return SECSuccess;
@@ -211,7 +201,7 @@ PRBool SSLInt_HasCertWithAuthType(PRFileDesc *fd, SSLAuthType authType) {
return PR_FALSE;
}
- return (PRBool)(!!ssl_FindServerCertByAuthType(ss, authType));
+ return (PRBool)(!!ssl_FindServerCert(ss, authType, NULL));
}
PRBool SSLInt_SendAlert(PRFileDesc *fd, uint8_t level, uint8_t type) {
@@ -256,6 +246,7 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
return SECFailure;
}
if (to >= (1ULL << 48)) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
@@ -267,6 +258,7 @@ SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to) {
* scrub the entire structure on the assumption that the new sequence number
* is far enough past the last received sequence number. */
if (to <= spec->recvdRecords.right + DTLS_RECVD_RECORDS_WINDOW) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
dtls_RecordSetRecvd(&spec->recvdRecords, to);
@@ -284,6 +276,7 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to) {
return SECFailure;
}
if (to >= (1ULL << 48)) {
+ PORT_SetError(SEC_ERROR_INVALID_ARGS);
return SECFailure;
}
ssl_GetSpecWriteLock(ss);
@@ -314,6 +307,40 @@ SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group) {
return groupDef->keaType;
}
+SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
+ sslCipherSpecChangedFunc func,
+ void *arg) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ ss->ssl3.changedCipherSpecFunc = func;
+ ss->ssl3.changedCipherSpecArg = arg;
+
+ return SECSuccess;
+}
+
+static ssl3KeyMaterial *GetKeyingMaterial(PRBool isServer,
+ ssl3CipherSpec *spec) {
+ return isServer ? &spec->server : &spec->client;
+}
+
+PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec) {
+ return GetKeyingMaterial(isServer, spec)->write_key;
+}
+
+SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
+ ssl3CipherSpec *spec) {
+ return spec->cipher_def->calg;
+}
+
+unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec) {
+ return GetKeyingMaterial(isServer, spec)->write_iv;
+}
+
SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd) {
sslSocket *ss;
@@ -335,6 +362,36 @@ SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result) {
}
*result = ss->ssl3.hs.shortHeaders;
+ return SECSuccess;
+}
+
+void SSLInt_SetTicketLifetime(uint32_t lifetime) {
+ ssl_ticket_lifetime = lifetime;
+}
+
+void SSLInt_SetMaxEarlyDataSize(uint32_t size) {
+ ssl_max_early_data_size = size;
+}
+
+SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size) {
+ sslSocket *ss;
+
+ ss = ssl_FindSocket(fd);
+ if (!ss) {
+ return SECFailure;
+ }
+
+ /* This only works when resuming. */
+ if (!ss->statelessResume) {
+ PORT_SetError(SEC_INTERNAL_ONLY);
+ return SECFailure;
+ }
+
+ /* Modifying both specs allows this to be used on either peer. */
+ ssl_GetSpecWriteLock(ss);
+ ss->ssl3.crSpec->earlyDataRemaining = size;
+ ss->ssl3.cwSpec->earlyDataRemaining = size;
+ ssl_ReleaseSpecWriteLock(ss);
return SECSuccess;
}
diff --git a/nss/gtests/ssl_gtest/libssl_internals.h b/nss/gtests/ssl_gtest/libssl_internals.h
index 6ea66db..531c31f 100644
--- a/nss/gtests/ssl_gtest/libssl_internals.h
+++ b/nss/gtests/ssl_gtest/libssl_internals.h
@@ -11,6 +11,8 @@
#include "prio.h"
#include "seccomon.h"
+#include "ssl.h"
+#include "sslimpl.h"
#include "sslt.h"
SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd);
@@ -37,7 +39,18 @@ SECStatus SSLInt_AdvanceWriteSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceReadSeqNum(PRFileDesc *fd, PRUint64 to);
SECStatus SSLInt_AdvanceWriteSeqByAWindow(PRFileDesc *fd, PRInt32 extra);
SSLKEAType SSLInt_GetKEAType(SSLNamedGroup group);
+
+SECStatus SSLInt_SetCipherSpecChangeFunc(PRFileDesc *fd,
+ sslCipherSpecChangedFunc func,
+ void *arg);
+PK11SymKey *SSLInt_CipherSpecToKey(PRBool isServer, ssl3CipherSpec *spec);
+SSLCipherAlgorithm SSLInt_CipherSpecToAlgorithm(PRBool isServer,
+ ssl3CipherSpec *spec);
+unsigned char *SSLInt_CipherSpecToIv(PRBool isServer, ssl3CipherSpec *spec);
SECStatus SSLInt_EnableShortHeaders(PRFileDesc *fd);
SECStatus SSLInt_UsingShortHeaders(PRFileDesc *fd, PRBool *result);
+void SSLInt_SetTicketLifetime(uint32_t lifetime);
+void SSLInt_SetMaxEarlyDataSize(uint32_t size);
+SECStatus SSLInt_SetSocketMaxEarlyDataSize(PRFileDesc *fd, uint32_t size);
#endif // ndef libssl_internals_h_
diff --git a/nss/gtests/ssl_gtest/manifest.mn b/nss/gtests/ssl_gtest/manifest.mn
index 391db81..e3775cc 100644
--- a/nss/gtests/ssl_gtest/manifest.mn
+++ b/nss/gtests/ssl_gtest/manifest.mn
@@ -12,6 +12,9 @@ CSRCS = \
$(NULL)
CPPSRCS = \
+ $(CORE_DEPTH)/cpputil/dummy_io.cc \
+ $(CORE_DEPTH)/cpputil/dummy_io_fwd.cc \
+ $(CORE_DEPTH)/cpputil/tls_parser.cc \
ssl_0rtt_unittest.cc \
ssl_agent_unittest.cc \
ssl_auth_unittest.cc \
@@ -24,7 +27,9 @@ CPPSRCS = \
ssl_ems_unittest.cc \
ssl_exporter_unittest.cc \
ssl_extension_unittest.cc \
+ ssl_fragment_unittest.cc \
ssl_fuzz_unittest.cc \
+ ssl_gather_unittest.cc \
ssl_gtest.cc \
ssl_hrr_unittest.cc \
ssl_loopback_unittest.cc \
@@ -34,21 +39,22 @@ CPPSRCS = \
ssl_staticrsa_unittest.cc \
ssl_v2_client_hello_unittest.cc \
ssl_version_unittest.cc \
+ ssl_versionpolicy_unittest.cc \
test_io.cc \
tls_agent.cc \
tls_connect.cc \
tls_hkdf_unittest.cc \
tls_filter.cc \
- tls_parser.cc \
+ tls_protect.cc \
$(NULL)
INCLUDES += -I$(CORE_DEPTH)/gtests/google_test/gtest/include \
- -I$(CORE_DEPTH)/gtests/common
+ -I$(CORE_DEPTH)/gtests/common \
+ -I$(CORE_DEPTH)/cpputil
REQUIRES = nspr nss libdbm gtest
PROGRAM = ssl_gtest
-EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) \
- $(DIST)/lib/$(LIB_PREFIX)softokn.$(LIB_SUFFIX)
+EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX)
USE_STATIC_LIBS = 1
diff --git a/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc b/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
index cf5a27f..85b7011 100644
--- a/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_0rtt_unittest.cc
@@ -155,6 +155,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnServer) {
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "a");
EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, sizeof(b)));
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
return true;
});
Handshake();
@@ -174,6 +175,7 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttNoAlpnClient) {
PRUint8 b[] = {'b'};
EXPECT_EQ(SECSuccess, SSLInt_Set0RttAlpn(client_->ssl_fd(), b, 1));
client_->CheckAlpn(SSL_NEXT_PROTO_EARLY_VALUE, "b");
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
return true;
});
Handshake();
@@ -200,4 +202,201 @@ TEST_P(TlsConnectTls13, TestTls13ZeroRttAlpnChangeBoth) {
CheckAlpn("b");
}
+// The client should abort the connection when sending a 0-rtt handshake but
+// the servers responds with a TLS 1.2 ServerHello. (no app data sent)
+TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngrade) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->Set0RttEnabled(true); // set ticket_allow_early_data
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->StartConnect();
+ server_->StartConnect();
+
+ // We will send the early data xtn without sending actual early data. Thus
+ // a 1.2 server shouldn't fail until the client sends an alert because the
+ // client sends end_of_early_data only after reading the server's flight.
+ client_->Set0RttEnabled(true);
+
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ client_->Handshake();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT(
+ (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000);
+
+ // DTLS will timeout as we bump the epoch when installing the early app data
+ // cipher suite. Thus the encrypted alert will be ignored.
+ if (variant_ == ssl_variant_stream) {
+ // The client sends an encrypted alert message.
+ ASSERT_TRUE_WAIT(
+ (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA),
+ 2000);
+ }
+}
+
+// The client should abort the connection when sending a 0-rtt handshake but
+// the servers responds with a TLS 1.2 ServerHello. (with app data)
+TEST_P(TlsConnectTls13, TestTls13ZeroRttDowngradeEarlyData) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ server_->Set0RttEnabled(true); // set ticket_allow_early_data
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->StartConnect();
+ server_->StartConnect();
+
+ // Send the early data xtn in the CH, followed by early app data. The server
+ // will fail right after sending its flight, when receiving the early data.
+ client_->Set0RttEnabled(true);
+ ZeroRttSendReceive(true, false, [this]() {
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ }
+ return true;
+ });
+
+ client_->Handshake();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT(
+ (client_->error_code() == SSL_ERROR_DOWNGRADE_WITH_EARLY_DATA), 2000);
+
+ // DTLS will timeout as we bump the epoch when installing the early app data
+ // cipher suite. Thus the encrypted alert will be ignored.
+ if (variant_ == ssl_variant_stream) {
+ // The server sends an alert when receiving the early app data record.
+ ASSERT_TRUE_WAIT(
+ (server_->error_code() == SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA),
+ 2000);
+ }
+}
+
+static void CheckEarlyDataLimit(const std::shared_ptr<TlsAgent>& agent,
+ size_t expected_size) {
+ SSLPreliminaryChannelInfo preinfo;
+ SECStatus rv =
+ SSL_GetPreliminaryChannelInfo(agent->ssl_fd(), &preinfo, sizeof(preinfo));
+ EXPECT_EQ(SECSuccess, rv);
+ EXPECT_EQ(expected_size, static_cast<size_t>(preinfo.maxEarlyDataSize));
+}
+
+TEST_P(TlsConnectTls13, SendTooMuchEarlyData) {
+ const char* big_message = "0123456789abcdef";
+ const size_t short_size = strlen(big_message) - 1;
+ const PRInt32 short_length = static_cast<PRInt32>(short_size);
+ SSLInt_SetMaxEarlyDataSize(static_cast<PRUint32>(short_size));
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ ExpectAlert(client_, kTlsAlertEndOfEarlyData);
+ client_->Handshake();
+ CheckEarlyDataLimit(client_, short_size);
+
+ PRInt32 sent;
+ // Writing more than the limit will succeed in TLS, but fail in DTLS.
+ if (variant_ == ssl_variant_stream) {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ } else {
+ sent = PR_Write(client_->ssl_fd(), big_message,
+ static_cast<PRInt32>(strlen(big_message)));
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Try an exact-sized write now.
+ sent = PR_Write(client_->ssl_fd(), big_message, short_length);
+ }
+ EXPECT_EQ(short_length, sent);
+
+ // Even a single octet write should now fail.
+ sent = PR_Write(client_->ssl_fd(), big_message, 1);
+ EXPECT_GE(0, sent);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ // Process the ClientHello and read 0-RTT.
+ server_->Handshake();
+ CheckEarlyDataLimit(server_, short_size);
+
+ std::vector<uint8_t> buf(short_size + 1);
+ PRInt32 read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(short_length, read);
+ EXPECT_EQ(0, memcmp(big_message, buf.data(), short_size));
+
+ // Second read fails.
+ read = PR_Read(server_->ssl_fd(), buf.data(), buf.capacity());
+ EXPECT_EQ(SECFailure, read);
+ EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
+
+ Handshake();
+ ExpectEarlyDataAccepted(true);
+ CheckConnected();
+ SendReceive();
+}
+
+TEST_P(TlsConnectTls13, ReceiveTooMuchEarlyData) {
+ const size_t limit = 5;
+ SSLInt_SetMaxEarlyDataSize(limit);
+ SetupForZeroRtt();
+
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->ExpectSendAlert(kTlsAlertEndOfEarlyData);
+ client_->Handshake(); // Send ClientHello
+ CheckEarlyDataLimit(client_, limit);
+
+ // Lift the limit on the client.
+ EXPECT_EQ(SECSuccess,
+ SSLInt_SetSocketMaxEarlyDataSize(client_->ssl_fd(), 1000));
+
+ // Send message
+ const char* message = "0123456789abcdef";
+ const PRInt32 message_len = static_cast<PRInt32>(strlen(message));
+ EXPECT_EQ(message_len, PR_Write(client_->ssl_fd(), message, message_len));
+
+ if (variant_ == ssl_variant_stream) {
+ // This error isn't fatal for DTLS.
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ }
+ server_->Handshake(); // Process ClientHello, send server flight.
+ server_->Handshake(); // Just to make sure that we don't read ahead.
+ CheckEarlyDataLimit(server_, limit);
+
+ // Attempt to read early data.
+ std::vector<uint8_t> buf(strlen(message) + 1);
+ EXPECT_GT(0, PR_Read(server_->ssl_fd(), buf.data(), buf.capacity()));
+ if (variant_ == ssl_variant_stream) {
+ server_->CheckErrorCode(SSL_ERROR_TOO_MUCH_EARLY_DATA);
+ }
+
+ client_->Handshake(); // Process the handshake.
+ client_->Handshake(); // Process the alert.
+ if (variant_ == ssl_variant_stream) {
+ client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
+ }
+}
+
} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_agent_unittest.cc b/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
index 0e6ddae..5035a33 100644
--- a/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_agent_unittest.cc
@@ -56,6 +56,7 @@ static const char *k0RttData = "ABCDEF";
TEST_P(TlsAgentTest, EarlyFinished) {
DataBuffer buffer;
MakeTrivialHandshakeRecord(kTlsHandshakeFinished, 0, &buffer);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(buffer, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_FINISHED);
}
@@ -63,15 +64,14 @@ TEST_P(TlsAgentTest, EarlyFinished) {
TEST_P(TlsAgentTest, EarlyCertificateVerify) {
DataBuffer buffer;
MakeTrivialHandshakeRecord(kTlsHandshakeCertificateVerify, 0, &buffer);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(buffer, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
}
-TEST_P(TlsAgentTestClient, CannedHello) {
+TEST_P(TlsAgentTestClient13, CannedHello) {
DataBuffer buffer;
EnsureInit();
- agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
- SSL_LIBRARY_VERSION_TLS_1_3);
DataBuffer server_hello;
MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello,
sizeof(kCannedTls13ServerHello), &server_hello);
@@ -80,7 +80,7 @@ TEST_P(TlsAgentTestClient, CannedHello) {
ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
}
-TEST_P(TlsAgentTestClient, EncryptedExtensionsInClear) {
+TEST_P(TlsAgentTestClient13, EncryptedExtensionsInClear) {
DataBuffer server_hello;
MakeHandshakeMessage(kTlsHandshakeServerHello, kCannedTls13ServerHello,
sizeof(kCannedTls13ServerHello), &server_hello);
@@ -92,8 +92,7 @@ TEST_P(TlsAgentTestClient, EncryptedExtensionsInClear) {
MakeRecord(kTlsHandshakeType, SSL_LIBRARY_VERSION_TLS_1_3,
server_hello.data(), server_hello.len(), &buffer);
EnsureInit();
- agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
- SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(buffer, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
}
@@ -118,6 +117,7 @@ TEST_F(TlsAgentStreamTestClient, EncryptedExtensionsInClearTwoPieces) {
agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
SSL_LIBRARY_VERSION_TLS_1_3);
ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(buffer2, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
}
@@ -148,6 +148,7 @@ TEST_F(TlsAgentDgramTestClient, EncryptedExtensionsInClearTwoPieces) {
agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
SSL_LIBRARY_VERSION_TLS_1_3);
ProcessMessage(buffer, TlsAgent::STATE_CONNECTING);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(buffer2, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_HANDSHAKE);
}
@@ -158,8 +159,8 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenWrite) {
SSL_LIBRARY_VERSION_TLS_1_3);
agent_->StartConnect();
agent_->Set0RttEnabled(true);
- auto filter =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeClientHello);
+ auto filter = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeClientHello);
agent_->SetPacketFilter(filter);
PRInt32 rv = PR_Write(agent_->ssl_fd(), k0RttData, strlen(k0RttData));
EXPECT_EQ(-1, rv);
@@ -178,6 +179,7 @@ TEST_F(TlsAgentStreamTestClient, Set0RttOptionThenRead) {
MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3,
reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData),
&buffer);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(buffer, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_APPLICATION_DATA);
}
@@ -198,13 +200,19 @@ TEST_F(TlsAgentStreamTestServer, Set0RttOptionClientHelloThenRead) {
MakeRecord(kTlsApplicationDataType, SSL_LIBRARY_VERSION_TLS_1_3,
reinterpret_cast<const uint8_t *>(k0RttData), strlen(k0RttData),
&buffer);
+ ExpectAlert(kTlsAlertBadRecordMac);
ProcessMessage(buffer, TlsAgent::STATE_ERROR, SSL_ERROR_BAD_MAC_READ);
}
INSTANTIATE_TEST_CASE_P(
AgentTests, TlsAgentTest,
::testing::Combine(TlsAgentTestBase::kTlsRolesAll,
- TlsConnectTestBase::kTlsModesStream));
+ TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
INSTANTIATE_TEST_CASE_P(ClientTests, TlsAgentTestClient,
- TlsConnectTestBase::kTlsModesAll);
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(ClientTests13, TlsAgentTestClient13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_auth_unittest.cc b/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
index e407d55..dbcbc9a 100644
--- a/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_auth_unittest.cc
@@ -77,9 +77,10 @@ TEST_P(TlsConnectGeneric, ClientAuthBigRsa) {
}
// Offset is the position in the captured buffer where the signature sits.
-static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture,
- size_t offset, TlsAgent* peer,
- uint16_t expected_scheme, size_t expected_size) {
+static void CheckSigScheme(
+ std::shared_ptr<TlsInspectorRecordHandshakeMessage>& capture, size_t offset,
+ std::shared_ptr<TlsAgent>& peer, uint16_t expected_scheme,
+ size_t expected_size) {
EXPECT_LT(offset + 2U, capture->buffer().len());
uint32_t scheme = 0;
@@ -95,8 +96,8 @@ static void CheckSigScheme(TlsInspectorRecordHandshakeMessage* capture,
// in the default certificate.
TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
EnsureTlsSetup();
- auto capture_ske =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ auto capture_ske = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(capture_ske);
Connect();
CheckKeys();
@@ -114,7 +115,8 @@ TEST_P(TlsConnectTls12, ServerAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
EnsureTlsSetup();
auto capture_cert_verify =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
+ std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeCertificateVerify);
client_->SetPacketFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
@@ -127,7 +129,8 @@ TEST_P(TlsConnectTls12, ClientAuthCheckSigAlg) {
TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
Reset(TlsAgent::kServerRsa, TlsAgent::kRsa2048);
auto capture_cert_verify =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeCertificateVerify);
+ std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeCertificateVerify);
client_->SetPacketFilter(capture_cert_verify);
client_->SetupClientAuth();
server_->RequestClientAuth(true);
@@ -136,6 +139,76 @@ TEST_P(TlsConnectTls12, ClientAuthBigRsaCheckSigAlg) {
CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pss_sha256, 2048);
}
+class TlsZeroCertificateRequestSigAlgsFilter : public TlsHandshakeFilter {
+ public:
+ virtual PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) {
+ if (header.handshake_type() != kTlsHandshakeCertificateRequest) {
+ return KEEP;
+ }
+
+ TlsParser parser(input);
+ std::cerr << "Zeroing CertReq.supported_signature_algorithms" << std::endl;
+
+ DataBuffer cert_types;
+ if (!parser.ReadVariable(&cert_types, 1)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ if (!parser.SkipVariable(2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ DataBuffer cas;
+ if (!parser.ReadVariable(&cas, 2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ size_t idx = 0;
+
+ // Write certificate types.
+ idx = output->Write(idx, cert_types.len(), 1);
+ idx = output->Write(idx, cert_types);
+
+ // Write zero signature algorithms.
+ idx = output->Write(idx, 0U, 2);
+
+ // Write certificate authorities.
+ idx = output->Write(idx, cas.len(), 2);
+ idx = output->Write(idx, cas);
+
+ return CHANGE;
+ }
+};
+
+// Check that we fall back to SHA-1 when the server doesn't provide any
+// supported_signature_algorithms in the CertificateRequest message.
+TEST_P(TlsConnectTls12, ClientAuthNoSigAlgsFallback) {
+ EnsureTlsSetup();
+ auto filter = std::make_shared<TlsZeroCertificateRequestSigAlgsFilter>();
+ server_->SetPacketFilter(filter);
+ auto capture_cert_verify =
+ std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeCertificateVerify);
+ client_->SetPacketFilter(capture_cert_verify);
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+
+ // We're expecting a bad signature here because we tampered with a handshake
+ // message (CertReq). Previously, without the SHA-1 fallback, we would've
+ // seen a malformed record alert.
+ server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+
+ CheckSigScheme(capture_cert_verify, 0, server_, ssl_sig_rsa_pkcs1_sha1, 1024);
+}
+
static const SSLSignatureScheme SignatureSchemeEcdsaSha384[] = {
ssl_sig_ecdsa_secp384r1_sha384};
static const SSLSignatureScheme SignatureSchemeEcdsaSha256[] = {
@@ -198,20 +271,38 @@ TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) {
ssl_sig_ecdsa_secp384r1_sha384);
}
-TEST_P(TlsConnectTls12Plus, SignatureSchemeCurveMismatch) {
+// In TLS 1.2, curve and hash aren't bound together.
+TEST_P(TlsConnectTls12, SignatureSchemeCurveMismatch) {
Reset(TlsAgent::kServerEcdsa256);
client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
- ConnectExpectFail();
+ Connect();
+}
+
+// In TLS 1.3, curve and hash are coupled.
+TEST_P(TlsConnectTls13, SignatureSchemeCurveMismatch) {
+ Reset(TlsAgent::kServerEcdsa256);
+ client_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
-TEST_P(TlsConnectTls12Plus, SignatureSchemeBadConfig) {
+// Configuring a P-256 cert with only SHA-384 signatures is OK in TLS 1.2.
+TEST_P(TlsConnectTls12, SignatureSchemeBadConfig) {
Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used.
server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
- ConnectExpectFail();
+ Connect();
+}
+
+// A P-256 certificate in TLS 1.3 needs a SHA-256 signature scheme.
+TEST_P(TlsConnectTls13, SignatureSchemeBadConfig) {
+ Reset(TlsAgent::kServerEcdsa256); // P-256 cert can't be used.
+ server_->SetSignatureSchemes(SignatureSchemeEcdsaSha384,
+ PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
@@ -234,7 +325,7 @@ TEST_P(TlsConnectTls12Plus, SignatureAlgorithmNoOverlapEcdsa) {
PR_ARRAY_SIZE(SignatureSchemeEcdsaSha384));
server_->SetSignatureSchemes(SignatureSchemeEcdsaSha256,
PR_ARRAY_SIZE(SignatureSchemeEcdsaSha256));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_SIGNATURE_ALGORITHM);
}
@@ -252,8 +343,8 @@ TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) {
// The signature_algorithms extension is mandatory in TLS 1.3.
TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
client_->SetPacketFilter(
- new TlsExtensionDropper(ssl_signature_algorithms_xtn));
- ConnectExpectFail();
+ std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
}
@@ -262,8 +353,8 @@ TEST_P(TlsConnectTls13, SignatureAlgorithmDrop) {
// only fails when the Finished is checked.
TEST_P(TlsConnectTls12, SignatureAlgorithmDrop) {
client_->SetPacketFilter(
- new TlsExtensionDropper(ssl_signature_algorithms_xtn));
- ConnectExpectFail();
+ std::make_shared<TlsExtensionDropper>(ssl_signature_algorithms_xtn));
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
}
@@ -280,7 +371,8 @@ class BeforeFinished : public TlsRecordFilter {
enum HandshakeState { BEFORE_CCS, AFTER_CCS, DONE };
public:
- BeforeFinished(TlsAgent* client, TlsAgent* server, VoidFunction before_ccs,
+ BeforeFinished(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server, VoidFunction before_ccs,
VoidFunction before_finished)
: client_(client),
server_(server),
@@ -289,7 +381,7 @@ class BeforeFinished : public TlsRecordFilter {
state_(BEFORE_CCS) {}
protected:
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
switch (state_) {
@@ -303,8 +395,8 @@ class BeforeFinished : public TlsRecordFilter {
// but that means that they both get processed together.
DataBuffer ccs;
header.Write(&ccs, 0, body);
- server_->SendDirect(ccs);
- client_->Handshake();
+ server_.lock()->SendDirect(ccs);
+ client_.lock()->Handshake();
state_ = AFTER_CCS;
// Request that the original record be dropped by the filter.
return DROP;
@@ -327,8 +419,8 @@ class BeforeFinished : public TlsRecordFilter {
}
private:
- TlsAgent* client_;
- TlsAgent* server_;
+ std::weak_ptr<TlsAgent> client_;
+ std::weak_ptr<TlsAgent> server_;
VoidFunction before_ccs_;
VoidFunction before_finished_;
HandshakeState state_;
@@ -353,7 +445,8 @@ class BeforeFinished13 : public PacketFilter {
};
public:
- BeforeFinished13(TlsAgent* client, TlsAgent* server,
+ BeforeFinished13(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server,
VoidFunction before_finished)
: client_(client),
server_(server),
@@ -367,7 +460,7 @@ class BeforeFinished13 : public PacketFilter {
case 1:
// Packet 1 is the server's entire first flight. Drop it.
EXPECT_EQ(SECSuccess,
- SSLInt_SetMTU(server_->ssl_fd(), input.len() - 1));
+ SSLInt_SetMTU(server_.lock()->ssl_fd(), input.len() - 1));
return DROP;
// Packet 2 is the first part of the server's retransmitted first
@@ -377,7 +470,7 @@ class BeforeFinished13 : public PacketFilter {
// Packet 3 is the second part of the server's retransmitted first
// flight. Before passing that on, make sure that the client processes
// packet 2, then call the before_finished_() callback.
- client_->Handshake();
+ client_.lock()->Handshake();
before_finished_();
break;
@@ -388,8 +481,8 @@ class BeforeFinished13 : public PacketFilter {
}
private:
- TlsAgent* client_;
- TlsAgent* server_;
+ std::weak_ptr<TlsAgent> client_;
+ std::weak_ptr<TlsAgent> server_;
VoidFunction before_finished_;
size_t records_;
};
@@ -403,9 +496,11 @@ static SECStatus AuthCompleteBlock(TlsAgent*, PRBool, PRBool) {
// processed by the client, SSL_AuthCertificateComplete() is called.
TEST_F(TlsConnectDatagram13, AuthCompleteBeforeFinished) {
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(new BeforeFinished13(client_, server_, [this]() {
- EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
- }));
+ server_->SetPacketFilter(
+ std::make_shared<BeforeFinished13>(client_, server_, [this]() {
+ EXPECT_EQ(SECSuccess,
+ SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
+ }));
Connect();
}
@@ -422,9 +517,9 @@ static void TriggerAuthComplete(PollTarget* target, Event event) {
TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {
client_->SetAuthCertificateCallback(
[this](TlsAgent*, PRBool, PRBool) -> SECStatus {
- Poller::Timer* timer_handle;
+ std::shared_ptr<Poller::Timer> timer_handle;
// This is really just to unroll the stack.
- Poller::Instance()->SetTimer(1U, client_, TriggerAuthComplete,
+ Poller::Instance()->SetTimer(1U, client_.get(), TriggerAuthComplete,
&timer_handle);
return SECWouldBlock;
});
@@ -433,7 +528,7 @@ TEST_F(TlsConnectDatagram13, AuthCompleteAfterFinished) {
TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
client_->EnableFalseStart();
- server_->SetPacketFilter(new BeforeFinished(
+ server_->SetPacketFilter(std::make_shared<BeforeFinished>(
client_, server_,
[this]() { EXPECT_TRUE(client_->can_falsestart_hook_called()); },
[this]() {
@@ -449,7 +544,7 @@ TEST_P(TlsConnectGenericPre13, ClientWriteBetweenCCSAndFinishedWithFalseStart) {
TEST_P(TlsConnectGenericPre13, AuthCompleteBeforeFinishedWithFalseStart) {
client_->EnableFalseStart();
client_->SetAuthCertificateCallback(AuthCompleteBlock);
- server_->SetPacketFilter(new BeforeFinished(
+ server_->SetPacketFilter(std::make_shared<BeforeFinished>(
client_, server_,
[]() {
// Do nothing before CCS
@@ -496,7 +591,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// The client should send nothing from here on.
- client_->SetPacketFilter(new EnforceNoActivity());
+ client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
@@ -507,7 +602,7 @@ TEST_P(TlsConnectGenericPre13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
// Remove this before closing or the close_notify alert will trigger it.
- client_->SetPacketFilter(nullptr);
+ client_->DeletePacketFilter();
}
// TLS 1.3 handles a delayed AuthComplete callback differently since the
@@ -523,12 +618,12 @@ TEST_P(TlsConnectTls13, AuthCompleteDelayed) {
EXPECT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
// The client will send nothing until AuthCertificateComplete is called.
- client_->SetPacketFilter(new EnforceNoActivity());
+ client_->SetPacketFilter(std::make_shared<EnforceNoActivity>());
client_->Handshake();
EXPECT_EQ(TlsAgent::STATE_CONNECTING, client_->state());
// This should allow the handshake to complete now.
- client_->SetPacketFilter(nullptr);
+ client_->DeletePacketFilter();
EXPECT_EQ(SECSuccess, SSL_AuthCertificateComplete(client_->ssl_fd(), 0));
client_->Handshake(); // Send Finished
server_->Handshake(); // Transition to connected and send NewSessionTicket
@@ -621,8 +716,8 @@ TEST_F(TlsAgentStreamTestServer, ConfigureCertRsaPss) {
&ServerCertDataRsaPss));
}
-// mode, version, certificate, auth type, signature scheme
-typedef std::tuple<std::string, uint16_t, std::string, SSLAuthType,
+// variant, version, certificate, auth type, signature scheme
+typedef std::tuple<SSLProtocolVariant, uint16_t, std::string, SSLAuthType,
SSLSignatureScheme>
SignatureSchemeProfile;
@@ -637,7 +732,7 @@ class TlsSignatureSchemeConfiguration
signature_scheme_(std::get<4>(GetParam())) {}
protected:
- void TestSignatureSchemeConfig(TlsAgent* configPeer) {
+ void TestSignatureSchemeConfig(std::shared_ptr<TlsAgent>& configPeer) {
EnsureTlsSetup();
configPeer->SetSignatureSchemes(&signature_scheme_, 1);
Connect();
@@ -657,8 +752,8 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigServer) {
TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigClient) {
Reset(certificate_);
- TlsExtensionCapture* capture =
- new TlsExtensionCapture(ssl_signature_algorithms_xtn);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
client_->SetPacketFilter(capture);
TestSignatureSchemeConfig(client_);
@@ -683,7 +778,7 @@ TEST_P(TlsSignatureSchemeConfiguration, SignatureSchemeConfigBoth) {
INSTANTIATE_TEST_CASE_P(
SignatureSchemeRsa, TlsSignatureSchemeConfiguration,
::testing::Combine(
- TlsConnectTestBase::kTlsModesAll, TlsConnectTestBase::kTlsV12Plus,
+ TlsConnectTestBase::kTlsVariantsAll, TlsConnectTestBase::kTlsV12Plus,
::testing::Values(TlsAgent::kServerRsaSign),
::testing::Values(ssl_auth_rsa_sign),
::testing::Values(ssl_sig_rsa_pkcs1_sha256, ssl_sig_rsa_pkcs1_sha384,
@@ -692,42 +787,42 @@ INSTANTIATE_TEST_CASE_P(
// PSS with SHA-512 needs a bigger key to work.
INSTANTIATE_TEST_CASE_P(
SignatureSchemeBigRsa, TlsSignatureSchemeConfiguration,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus,
::testing::Values(TlsAgent::kRsa2048),
::testing::Values(ssl_auth_rsa_sign),
::testing::Values(ssl_sig_rsa_pss_sha512)));
INSTANTIATE_TEST_CASE_P(
SignatureSchemeRsaSha1, TlsSignatureSchemeConfiguration,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12,
::testing::Values(TlsAgent::kServerRsa),
::testing::Values(ssl_auth_rsa_sign),
::testing::Values(ssl_sig_rsa_pkcs1_sha1)));
INSTANTIATE_TEST_CASE_P(
SignatureSchemeEcdsaP256, TlsSignatureSchemeConfiguration,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus,
::testing::Values(TlsAgent::kServerEcdsa256),
::testing::Values(ssl_auth_ecdsa),
::testing::Values(ssl_sig_ecdsa_secp256r1_sha256)));
INSTANTIATE_TEST_CASE_P(
SignatureSchemeEcdsaP384, TlsSignatureSchemeConfiguration,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus,
::testing::Values(TlsAgent::kServerEcdsa384),
::testing::Values(ssl_auth_ecdsa),
::testing::Values(ssl_sig_ecdsa_secp384r1_sha384)));
INSTANTIATE_TEST_CASE_P(
SignatureSchemeEcdsaP521, TlsSignatureSchemeConfiguration,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus,
::testing::Values(TlsAgent::kServerEcdsa521),
::testing::Values(ssl_auth_ecdsa),
::testing::Values(ssl_sig_ecdsa_secp521r1_sha512)));
INSTANTIATE_TEST_CASE_P(
SignatureSchemeEcdsaSha1, TlsSignatureSchemeConfiguration,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12,
::testing::Values(TlsAgent::kServerEcdsa256,
TlsAgent::kServerEcdsa384),
diff --git a/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc b/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
index 876c368..3463782 100644
--- a/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_cert_ext_unittest.cc
@@ -23,9 +23,10 @@ namespace nss_test {
// by the relevant callbacks on the client.
class SignedCertificateTimestampsExtractor {
public:
- SignedCertificateTimestampsExtractor(TlsAgent* client) : client_(client) {
- client_->SetAuthCertificateCallback(
- [&](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus {
+ SignedCertificateTimestampsExtractor(std::shared_ptr<TlsAgent>& client)
+ : client_(client) {
+ client->SetAuthCertificateCallback(
+ [this](TlsAgent* agent, bool checksig, bool isServer) -> SECStatus {
const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd());
EXPECT_TRUE(scts);
if (!scts) {
@@ -34,7 +35,7 @@ class SignedCertificateTimestampsExtractor {
auth_timestamps_.reset(new DataBuffer(scts->data, scts->len));
return SECSuccess;
});
- client_->SetHandshakeCallback([&](TlsAgent* agent) {
+ client->SetHandshakeCallback([this](TlsAgent* agent) {
const SECItem* scts = SSL_PeerSignedCertTimestamps(agent->ssl_fd());
ASSERT_TRUE(scts);
handshake_timestamps_.reset(new DataBuffer(scts->data, scts->len));
@@ -48,12 +49,13 @@ class SignedCertificateTimestampsExtractor {
EXPECT_TRUE(handshake_timestamps_);
EXPECT_EQ(timestamps, *handshake_timestamps_);
- const SECItem* current = SSL_PeerSignedCertTimestamps(client_->ssl_fd());
+ const SECItem* current =
+ SSL_PeerSignedCertTimestamps(client_.lock()->ssl_fd());
EXPECT_EQ(timestamps, DataBuffer(current->data, current->len));
}
private:
- TlsAgent* client_;
+ std::weak_ptr<TlsAgent> client_;
std::unique_ptr<DataBuffer> auth_timestamps_;
std::unique_ptr<DataBuffer> handshake_timestamps_;
};
@@ -62,10 +64,22 @@ static const uint8_t kSctValue[] = {0x01, 0x23, 0x45, 0x67, 0x89};
static const SECItem kSctItem = {siBuffer, const_cast<uint8_t*>(kSctValue),
sizeof(kSctValue)};
static const DataBuffer kSctBuffer(kSctValue, sizeof(kSctValue));
+static const SSLExtraServerCertData kExtraSctData = {ssl_auth_null, nullptr,
+ nullptr, &kSctItem};
// Test timestamps extraction during a successful handshake.
-TEST_P(TlsConnectGeneric, SignedCertificateTimestampsHandshake) {
+TEST_P(TlsConnectGenericPre13, SignedCertificateTimestampsLegacy) {
EnsureTlsSetup();
+
+ // We have to use the legacy API consistently here for configuring certs.
+ // Also, this doesn't work in TLS 1.3 because this only configures the SCT for
+ // RSA decrypt and PKCS#1 signing, not PSS.
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsa, &cert, &priv));
+ EXPECT_EQ(SECSuccess, SSL_ConfigSecureServerWithCertChain(
+ server_->ssl_fd(), cert.get(), nullptr, priv.get(),
+ ssl_kea_rsa));
EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
&kSctItem, ssl_kea_rsa));
EXPECT_EQ(SECSuccess,
@@ -78,13 +92,10 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsHandshake) {
timestamps_extractor.assertTimestamps(kSctBuffer);
}
-TEST_P(TlsConnectGeneric, SignedCertificateTimestampsConfig) {
- static const SSLExtraServerCertData kExtraData = {ssl_auth_rsa_sign, nullptr,
- nullptr, &kSctItem};
-
+TEST_P(TlsConnectGeneric, SignedCertificateTimestampsSuccess) {
EnsureTlsSetup();
EXPECT_TRUE(
- server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraData));
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData));
EXPECT_EQ(SECSuccess,
SSL_OptionSet(client_->ssl_fd(), SSL_ENABLE_SIGNED_CERT_TIMESTAMPS,
PR_TRUE));
@@ -99,8 +110,8 @@ TEST_P(TlsConnectGeneric, SignedCertificateTimestampsConfig) {
// when the client / the server / both have not enabled the feature.
TEST_P(TlsConnectGeneric, SignedCertificateTimestampsInactiveClient) {
EnsureTlsSetup();
- EXPECT_EQ(SECSuccess, SSL_SetSignedCertTimestamps(server_->ssl_fd(),
- &kSctItem, ssl_kea_rsa));
+ EXPECT_TRUE(
+ server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kExtraSctData));
SignedCertificateTimestampsExtractor timestamps_extractor(client_);
Connect();
@@ -141,8 +152,8 @@ static const SECItem kOcspItems[] = {
{siBuffer, const_cast<uint8_t*>(kOcspValue2), sizeof(kOcspValue2)}};
static const SECItemArray kOcspResponses = {const_cast<SECItem*>(kOcspItems),
PR_ARRAY_SIZE(kOcspItems)};
-const static SSLExtraServerCertData kOcspExtraData = {
- ssl_auth_rsa_sign, nullptr, &kOcspResponses, nullptr};
+const static SSLExtraServerCertData kOcspExtraData = {ssl_auth_null, nullptr,
+ &kOcspResponses, nullptr};
TEST_P(TlsConnectGeneric, NoOcsp) {
EnsureTlsSetup();
@@ -176,10 +187,10 @@ TEST_P(TlsConnectGenericPre13, OcspMangled) {
server_->ConfigServerCert(TlsAgent::kServerRsa, true, &kOcspExtraData));
static const uint8_t val[] = {1};
- auto replacer = new TlsExtensionReplacer(ssl_cert_status_xtn,
- DataBuffer(val, sizeof(val)));
+ auto replacer = std::make_shared<TlsExtensionReplacer>(
+ ssl_cert_status_xtn, DataBuffer(val, sizeof(val)));
server_->SetPacketFilter(replacer);
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
}
@@ -188,7 +199,8 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
EnsureTlsSetup();
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
- auto capture_ocsp = new TlsExtensionCapture(ssl_cert_status_xtn);
+ auto capture_ocsp =
+ std::make_shared<TlsExtensionCapture>(ssl_cert_status_xtn);
server_->SetPacketFilter(capture_ocsp);
// The value should be available during the AuthCertificateCallback
@@ -211,4 +223,35 @@ TEST_P(TlsConnectGeneric, OcspSuccess) {
EXPECT_EQ(0U, capture_ocsp->extension().len());
}
+TEST_P(TlsConnectGeneric, OcspHugeSuccess) {
+ EnsureTlsSetup();
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
+ SSL_ENABLE_OCSP_STAPLING, PR_TRUE));
+
+ uint8_t hugeOcspValue[16385];
+ memset(hugeOcspValue, 0xa1, sizeof(hugeOcspValue));
+ const SECItem hugeOcspItems[] = {
+ {siBuffer, const_cast<uint8_t*>(hugeOcspValue), sizeof(hugeOcspValue)}};
+ const SECItemArray hugeOcspResponses = {const_cast<SECItem*>(hugeOcspItems),
+ PR_ARRAY_SIZE(hugeOcspItems)};
+ const SSLExtraServerCertData hugeOcspExtraData = {
+ ssl_auth_null, nullptr, &hugeOcspResponses, nullptr};
+
+ // The value should be available during the AuthCertificateCallback
+ client_->SetAuthCertificateCallback([&](TlsAgent* agent, bool checksig,
+ bool isServer) -> SECStatus {
+ const SECItemArray* ocsp = SSL_PeerStapledOCSPResponses(agent->ssl_fd());
+ if (!ocsp) {
+ return SECFailure;
+ }
+ EXPECT_EQ(1U, ocsp->len) << "We only provide the first item";
+ EXPECT_EQ(0, SECITEM_CompareItem(&hugeOcspItems[0], &ocsp->items[0]));
+ return SECSuccess;
+ });
+ EXPECT_TRUE(server_->ConfigServerCert(TlsAgent::kServerRsa, true,
+ &hugeOcspExtraData));
+
+ Connect();
+}
+
} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc b/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
index ab10a84..85c30b2 100644
--- a/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_ciphersuite_unittest.cc
@@ -22,17 +22,17 @@ extern "C" {
namespace nss_test {
-// mode, version, cipher suite
-typedef std::tuple<std::string, uint16_t, uint16_t, SSLNamedGroup,
+// variant, version, cipher suite
+typedef std::tuple<SSLProtocolVariant, uint16_t, uint16_t, SSLNamedGroup,
SSLSignatureScheme>
CipherSuiteProfile;
class TlsCipherSuiteTestBase : public TlsConnectTestBase {
public:
- TlsCipherSuiteTestBase(const std::string &mode, uint16_t version,
+ TlsCipherSuiteTestBase(SSLProtocolVariant variant, uint16_t version,
uint16_t cipher_suite, SSLNamedGroup group,
SSLSignatureScheme signature_scheme)
- : TlsConnectTestBase(mode, version),
+ : TlsConnectTestBase(variant, version),
cipher_suite_(cipher_suite),
group_(group),
signature_scheme_(signature_scheme),
@@ -128,16 +128,22 @@ class TlsCipherSuiteTestBase : public TlsConnectTestBase {
Connect();
SendReceive();
- // Check that we used the right cipher suite.
+ // Check that we used the right cipher suite, auth type and kea type.
uint16_t actual;
- EXPECT_TRUE(client_->cipher_suite(&actual) && actual == cipher_suite_);
- EXPECT_TRUE(server_->cipher_suite(&actual) && actual == cipher_suite_);
+ EXPECT_TRUE(client_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite_, actual);
+ EXPECT_TRUE(server_->cipher_suite(&actual));
+ EXPECT_EQ(cipher_suite_, actual);
SSLAuthType auth;
- EXPECT_TRUE(client_->auth_type(&auth) && auth == auth_type_);
- EXPECT_TRUE(server_->auth_type(&auth) && auth == auth_type_);
+ EXPECT_TRUE(client_->auth_type(&auth));
+ EXPECT_EQ(auth_type_, auth);
+ EXPECT_TRUE(server_->auth_type(&auth));
+ EXPECT_EQ(auth_type_, auth);
SSLKEAType kea;
- EXPECT_TRUE(client_->kea_type(&kea) && kea == kea_type_);
- EXPECT_TRUE(server_->kea_type(&kea) && kea == kea_type_);
+ EXPECT_TRUE(client_->kea_type(&kea));
+ EXPECT_EQ(kea_type_, kea);
+ EXPECT_TRUE(server_->kea_type(&kea));
+ EXPECT_EQ(kea_type_, kea);
}
// Get the expected limit on the number of records that can be sent for the
@@ -252,14 +258,17 @@ TEST_P(TlsCipherSuiteTest, ReadLimit) {
// authentication tag.
static const uint8_t payload[18] = {6};
DataBuffer record;
- uint64_t epoch = 0;
- if (mode_ == DGRAM) {
- epoch++;
+ uint64_t epoch;
+ if (variant_ == ssl_variant_datagram) {
if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
- epoch++;
+ epoch = 3; // Application traffic keys.
+ } else {
+ epoch = 1;
}
+ } else {
+ epoch = 0;
}
- TlsAgentTestBase::MakeRecord(mode_, kTlsApplicationDataType, version_,
+ TlsAgentTestBase::MakeRecord(variant_, kTlsApplicationDataType, version_,
payload, sizeof(payload), &record,
(epoch << 48) | record_limit());
server_->adapter()->PacketReceived(record);
@@ -287,7 +296,7 @@ TEST_P(TlsCipherSuiteTest, WriteLimit) {
k##name##Ciphers = ::testing::ValuesIn(k##name##CiphersArr); \
INSTANTIATE_TEST_CASE_P( \
CipherSuite##name, TlsCipherSuiteTest, \
- ::testing::Combine(TlsConnectTestBase::kTlsModes##modes, \
+ ::testing::Combine(TlsConnectTestBase::kTlsVariants##modes, \
TlsConnectTestBase::kTls##versions, k##name##Ciphers, \
groups, sigalgs));
@@ -396,7 +405,7 @@ class SecurityStatusTest
public ::testing::WithParamInterface<SecStatusParams> {
public:
SecurityStatusTest()
- : TlsCipherSuiteTestBase("TLS", GetParam().version,
+ : TlsCipherSuiteTestBase(ssl_variant_stream, GetParam().version,
GetParam().cipher_suite, ssl_grp_none,
ssl_sig_none) {}
};
diff --git a/nss/gtests/ssl_gtest/ssl_damage_unittest.cc b/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
index 9dadcbd..69fd003 100644
--- a/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_damage_unittest.cc
@@ -33,12 +33,14 @@ TEST_F(TlsConnectTest, DamageSecretHandleClientFinished) {
client_->StartConnect();
client_->Handshake();
server_->Handshake();
- std::cerr << "Damaging HS secret\n";
+ std::cerr << "Damaging HS secret" << std::endl;
SSLInt_DamageClientHsTrafficSecret(server_->ssl_fd());
client_->Handshake();
- server_->Handshake();
// The client thinks it has connected.
EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ ExpectAlert(server_, kTlsAlertDecryptError);
+ server_->Handshake();
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
client_->Handshake();
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
@@ -49,7 +51,10 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
SSL_LIBRARY_VERSION_TLS_1_3);
- server_->SetPacketFilter(new AfterRecordN(
+ client_->ExpectSendAlert(kTlsAlertDecryptError);
+ // The server can't read the client's alert, so it also sends an alert.
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->SetPacketFilter(std::make_shared<AfterRecordN>(
server_, client_,
0, // ServerHello.
[this]() { SSLInt_DamageServerHsTrafficSecret(client_->ssl_fd()); }));
@@ -58,4 +63,57 @@ TEST_F(TlsConnectTest, DamageSecretHandleServerFinished) {
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
}
+TEST_P(TlsConnectGenericPre13, DamageServerSignature) {
+ EnsureTlsSetup();
+ auto filter =
+ std::make_shared<TlsLastByteDamager>(kTlsHandshakeServerKeyExchange);
+ server_->SetTlsRecordFilter(filter);
+ ExpectAlert(client_, kTlsAlertDecryptError);
+ ConnectExpectFail();
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+ server_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+}
+
+TEST_P(TlsConnectTls13, DamageServerSignature) {
+ EnsureTlsSetup();
+ auto filter =
+ std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
+ server_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ client_->ExpectSendAlert(kTlsAlertDecryptError);
+ // The server can't read the client's alert, so it also sends an alert.
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+}
+
+TEST_P(TlsConnectGeneric, DamageClientSignature) {
+ EnsureTlsSetup();
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ auto filter =
+ std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify);
+ client_->SetTlsRecordFilter(filter);
+ server_->ExpectSendAlert(kTlsAlertDecryptError);
+ filter->EnableDecryption();
+ // Do these handshakes by hand to avoid race condition on
+ // the client processing the server's alert.
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ server_->Handshake();
+ client_->Handshake();
+ server_->Handshake();
+ EXPECT_EQ(version_ >= SSL_LIBRARY_VERSION_TLS_1_3
+ ? TlsAgent::STATE_CONNECTED
+ : TlsAgent::STATE_CONNECTING,
+ client_->state());
+ server_->CheckErrorCode(SEC_ERROR_BAD_SIGNATURE);
+}
+
} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc b/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
index 82d5558..9794330 100644
--- a/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_dhe_unittest.cc
@@ -31,12 +31,13 @@ TEST_P(TlsConnectTls13, SharesForBothEcdheAndDhe) {
EnsureTlsSetup();
client_->ConfigNamedGroups(kAllDHEGroups);
- auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn);
- auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
- std::vector<PacketFilter*> captures;
- captures.push_back(groups_capture);
- captures.push_back(shares_capture);
- client_->SetPacketFilter(new ChainedPacketFilter(captures));
+ auto groups_capture =
+ std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ auto shares_capture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
+ shares_capture};
+ client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -60,12 +61,13 @@ TEST_P(TlsConnectGeneric, ConnectFfdheClient) {
EnableOnlyDheCiphers();
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
- auto groups_capture = new TlsExtensionCapture(ssl_supported_groups_xtn);
- auto shares_capture = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
- std::vector<PacketFilter*> captures;
- captures.push_back(groups_capture);
- captures.push_back(shares_capture);
- client_->SetPacketFilter(new ChainedPacketFilter(captures));
+ auto groups_capture =
+ std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ auto shares_capture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {groups_capture,
+ shares_capture};
+ client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
Connect();
@@ -95,7 +97,7 @@ TEST_P(TlsConnectGenericPre13, ConnectFfdheServer) {
Connect();
CheckKeys(ssl_kea_dh, ssl_auth_rsa_sign);
} else {
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
@@ -126,9 +128,9 @@ TEST_P(TlsConnectGenericPre13, DamageServerKeyShare) {
EnableOnlyDheCiphers();
EXPECT_EQ(SECSuccess, SSL_OptionSet(client_->ssl_fd(),
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
- server_->SetPacketFilter(new TlsDheServerKeyExchangeDamager());
+ server_->SetPacketFilter(std::make_shared<TlsDheServerKeyExchangeDamager>());
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_WEAK_SERVER_EPHEMERAL_DH_KEY);
server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -249,8 +251,9 @@ class TlsDheSkeChangeYServer : public TlsDheSkeChangeY {
class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
public:
- TlsDheSkeChangeYClient(ChangeYTo change,
- const TlsDheSkeChangeYServer* server_filter)
+ TlsDheSkeChangeYClient(
+ ChangeYTo change,
+ std::shared_ptr<const TlsDheSkeChangeYServer> server_filter)
: TlsDheSkeChangeY(change), server_filter_(server_filter) {}
protected:
@@ -266,13 +269,14 @@ class TlsDheSkeChangeYClient : public TlsDheSkeChangeY {
}
private:
- const TlsDheSkeChangeYServer* server_filter_;
+ std::shared_ptr<const TlsDheSkeChangeYServer> server_filter_;
};
-/* This matrix includes: mode (stream/datagram), TLS version, what change to
+/* This matrix includes: variant (stream/datagram), TLS version, what change to
* make to dh_Ys, whether the client will be configured to require DH named
* groups. Test all combinations. */
-typedef std::tuple<std::string, uint16_t, TlsDheSkeChangeY::ChangeYTo, bool>
+typedef std::tuple<SSLProtocolVariant, uint16_t, TlsDheSkeChangeY::ChangeYTo,
+ bool>
DamageDHYProfile;
class TlsDamageDHYTest
: public TlsConnectTestBase,
@@ -289,8 +293,14 @@ TEST_P(TlsDamageDHYTest, DamageServerY) {
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
}
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- server_->SetPacketFilter(new TlsDheSkeChangeYServer(change, true));
+ server_->SetPacketFilter(
+ std::make_shared<TlsDheSkeChangeYServer>(change, true));
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ ExpectAlert(client_, kTlsAlertDecryptError);
+ } else {
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ }
ConnectExpectFail();
if (change == TlsDheSkeChangeY::kYZeroPad) {
// Zero padding Y only manifests in a signature failure.
@@ -314,14 +324,20 @@ TEST_P(TlsDamageDHYTest, DamageClientY) {
SSL_REQUIRE_DH_NAMED_GROUPS, PR_TRUE));
}
// The filter on the server is required to capture the prime.
- TlsDheSkeChangeYServer* server_filter =
- new TlsDheSkeChangeYServer(TlsDheSkeChangeY::kYZero, false);
+ auto server_filter =
+ std::make_shared<TlsDheSkeChangeYServer>(TlsDheSkeChangeY::kYZero, false);
server_->SetPacketFilter(server_filter);
// The client filter does the damage.
TlsDheSkeChangeY::ChangeYTo change = std::get<2>(GetParam());
- client_->SetPacketFilter(new TlsDheSkeChangeYClient(change, server_filter));
+ client_->SetPacketFilter(
+ std::make_shared<TlsDheSkeChangeYClient>(change, server_filter));
+ if (change == TlsDheSkeChangeY::kYZeroPad) {
+ ExpectAlert(server_, kTlsAlertDecryptError);
+ } else {
+ ExpectAlert(server_, kTlsAlertHandshakeFailure);
+ }
ConnectExpectFail();
if (change == TlsDheSkeChangeY::kYZeroPad) {
// Zero padding Y only manifests in a finished error.
@@ -343,13 +359,13 @@ static const bool kTrueFalseArr[] = {true, false};
static ::testing::internal::ParamGenerator<bool> kTrueFalse =
::testing::ValuesIn(kTrueFalseArr);
-INSTANTIATE_TEST_CASE_P(DamageYStream, TlsDamageDHYTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsV10ToV12,
- kAllY, kTrueFalse));
+INSTANTIATE_TEST_CASE_P(
+ DamageYStream, TlsDamageDHYTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12, kAllY, kTrueFalse));
INSTANTIATE_TEST_CASE_P(
DamageYDatagram, TlsDamageDHYTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
TlsConnectTestBase::kTlsV11V12, kAllY, kTrueFalse));
class TlsDheSkeMakePEven : public TlsHandshakeFilter {
@@ -378,9 +394,9 @@ class TlsDheSkeMakePEven : public TlsHandshakeFilter {
// Even without requiring named groups, an even value for p is bad news.
TEST_P(TlsConnectGenericPre13, MakeDhePEven) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(new TlsDheSkeMakePEven());
+ server_->SetPacketFilter(std::make_shared<TlsDheSkeMakePEven>());
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_DHE_KEY_SHARE);
server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
@@ -409,9 +425,9 @@ class TlsDheSkeZeroPadP : public TlsHandshakeFilter {
// Zero padding only causes signature failure.
TEST_P(TlsConnectGenericPre13, PadDheP) {
EnableOnlyDheCiphers();
- server_->SetPacketFilter(new TlsDheSkeZeroPadP());
+ server_->SetPacketFilter(std::make_shared<TlsDheSkeZeroPadP>());
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
// In TLS 1.0 and 1.1, the client reports a device error.
if (version_ < SSL_LIBRARY_VERSION_TLS_1_2) {
@@ -470,7 +486,7 @@ TEST_P(TlsConnectTls13, NamedGroupMismatch13) {
server_->ConfigNamedGroups(server_groups);
client_->ConfigNamedGroups(client_groups);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
@@ -488,7 +504,7 @@ TEST_P(TlsConnectGenericPre13, RequireNamedGroupsMismatchPre13) {
server_->ConfigNamedGroups(server_groups);
client_->ConfigNamedGroups(client_groups);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
@@ -518,7 +534,7 @@ TEST_P(TlsConnectGenericPre13, MismatchDHE) {
EXPECT_EQ(SECSuccess, SSL_DHEGroupPrefSet(client_->ssl_fd(), clientGroups,
PR_ARRAY_SIZE(clientGroups)));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
@@ -533,11 +549,11 @@ TEST_P(TlsConnectTls13, ResumeFfdhe) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
EnableOnlyDheCiphers();
- TlsExtensionCapture* clientCapture =
- new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ auto clientCapture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
client_->SetPacketFilter(clientCapture);
- TlsExtensionCapture* serverCapture =
- new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ auto serverCapture =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
server_->SetPacketFilter(serverCapture);
ExpectResumption(RESUME_TICKET);
Connect();
@@ -599,10 +615,10 @@ TEST_P(TlsConnectGenericPre13, InvalidDERSignatureFfdhe) {
const std::vector<SSLNamedGroup> client_groups = {ssl_grp_ffdhe_2048};
client_->ConfigNamedGroups(client_groups);
- server_->SetPacketFilter(new TlsDheSkeChangeSignature(
+ server_->SetPacketFilter(std::make_shared<TlsDheSkeChangeSignature>(
version_, kBogusDheSignature, sizeof(kBogusDheSignature)));
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
}
diff --git a/nss/gtests/ssl_gtest/ssl_drop_unittest.cc b/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
index 89ca28e..3cc3b0e 100644
--- a/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_drop_unittest.cc
@@ -21,13 +21,13 @@ extern "C" {
namespace nss_test {
TEST_P(TlsConnectDatagram, DropClientFirstFlightOnce) {
- client_->SetPacketFilter(new SelectiveDropFilter(0x1));
+ client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
- server_->SetPacketFilter(new SelectiveDropFilter(0x1));
+ server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x1));
Connect();
SendReceive();
}
@@ -36,32 +36,32 @@ TEST_P(TlsConnectDatagram, DropServerFirstFlightOnce) {
// flights that they send. Note: In DTLS 1.3, the shorter handshake means that
// this will also drop some application data, so we can't call SendReceive().
TEST_P(TlsConnectDatagram, DropAllFirstTransmissions) {
- client_->SetPacketFilter(new SelectiveDropFilter(0x15));
- server_->SetPacketFilter(new SelectiveDropFilter(0x5));
+ client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x15));
+ server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x5));
Connect();
}
// This drops the server's first flight three times.
TEST_P(TlsConnectDatagram, DropServerFirstFlightThrice) {
- server_->SetPacketFilter(new SelectiveDropFilter(0x7));
+ server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x7));
Connect();
}
// This drops the client's second flight once
TEST_P(TlsConnectDatagram, DropClientSecondFlightOnce) {
- client_->SetPacketFilter(new SelectiveDropFilter(0x2));
+ client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
// This drops the client's second flight three times.
TEST_P(TlsConnectDatagram, DropClientSecondFlightThrice) {
- client_->SetPacketFilter(new SelectiveDropFilter(0xe));
+ client_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
// This drops the server's second flight three times.
TEST_P(TlsConnectDatagram, DropServerSecondFlightThrice) {
- server_->SetPacketFilter(new SelectiveDropFilter(0xe));
+ server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0xe));
Connect();
}
diff --git a/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc b/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
index 43dfcba..1e406b6 100644
--- a/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_ecdh_unittest.cc
@@ -58,7 +58,7 @@ TEST_P(TlsConnectTls12, ConnectEcdheP384) {
Reset(TlsAgent::kServerEcdsa384);
ConnectWithCipherSuite(TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256);
CheckKeys(ssl_kea_ecdh, ssl_grp_ec_secp384r1, ssl_auth_ecdsa,
- ssl_sig_ecdsa_secp384r1_sha384);
+ ssl_sig_ecdsa_secp256r1_sha256);
}
TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
@@ -75,8 +75,8 @@ TEST_P(TlsConnectGeneric, ConnectEcdheP384Client) {
// This causes a HelloRetryRequest in TLS 1.3. Earlier versions don't care.
TEST_P(TlsConnectGeneric, ConnectEcdheP384Server) {
EnsureTlsSetup();
- auto hrr_capture =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeHelloRetryRequest);
server_->SetPacketFilter(hrr_capture);
const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
server_->ConfigNamedGroups(groups);
@@ -191,6 +191,60 @@ TEST_P(TlsConnectGenericPre13, P384PriorityFromModelSocket) {
ssl_sig_rsa_pss_sha256);
}
+class TlsKeyExchangeGroupCapture : public TlsHandshakeFilter {
+ public:
+ TlsKeyExchangeGroupCapture() : group_(ssl_grp_none) {}
+
+ SSLNamedGroup group() const { return group_; }
+
+ protected:
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader &header,
+ const DataBuffer &input,
+ DataBuffer *output) {
+ if (header.handshake_type() != kTlsHandshakeServerKeyExchange) {
+ return KEEP;
+ }
+
+ uint32_t value = 0;
+ EXPECT_TRUE(input.Read(0, 1, &value));
+ EXPECT_EQ(3U, value) << "curve type has to be 3";
+
+ EXPECT_TRUE(input.Read(1, 2, &value));
+ group_ = static_cast<SSLNamedGroup>(value);
+
+ return KEEP;
+ }
+
+ private:
+ SSLNamedGroup group_;
+};
+
+// If we strip the client's supported groups extension, the server should assume
+// P-256 is supported by the client (<= 1.2 only).
+TEST_P(TlsConnectGenericPre13, DropSupportedGroupExtensionP256) {
+ EnsureTlsSetup();
+ client_->SetPacketFilter(
+ std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
+ auto group_capture = std::make_shared<TlsKeyExchangeGroupCapture>();
+ server_->SetPacketFilter(group_capture);
+
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+
+ EXPECT_EQ(ssl_grp_ec_secp256r1, group_capture->group());
+}
+
+// Supported groups is mandatory in TLS 1.3.
+TEST_P(TlsConnectTls13, DropSupportedGroupExtension) {
+ EnsureTlsSetup();
+ client_->SetPacketFilter(
+ std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn));
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
+}
+
// If we only have a lame group, we fall back to static RSA.
TEST_P(TlsConnectGenericPre13, UseLameGroup) {
const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp192r1};
@@ -431,7 +485,7 @@ TEST_P(TlsConnectGeneric, P256ClientAndCurve25519Server) {
client_->ConfigNamedGroups(client_groups);
server_->ConfigNamedGroups(server_groups);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
@@ -507,25 +561,25 @@ class ECCServerKEXFilter : public TlsHandshakeFilter {
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyServerPoint) {
// add packet filter
- server_->SetPacketFilter(new ECCServerKEXFilter());
- ConnectExpectFail();
+ server_->SetPacketFilter(std::make_shared<ECCServerKEXFilter>());
+ ConnectExpectAlert(client_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_KEY_EXCH);
}
TEST_P(TlsConnectGenericPre13, ConnectECDHEmptyClientPoint) {
// add packet filter
- client_->SetPacketFilter(new ECCClientKEXFilter());
- ConnectExpectFail();
+ client_->SetPacketFilter(std::make_shared<ECCClientKEXFilter>());
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_KEY_EXCH);
}
INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV11Plus));
#ifndef NSS_DISABLE_TLS_1_3
INSTANTIATE_TEST_CASE_P(KeyExchangeTest, TlsKeyExchangeTest13,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV13));
#endif
diff --git a/nss/gtests/ssl_gtest/ssl_ems_unittest.cc b/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
index b9c725b..dad6ca0 100644
--- a/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_ems_unittest.cc
@@ -79,11 +79,7 @@ TEST_P(TlsConnectGenericPre13, ConnectExtendedMasterSecretResumeWithout) {
Reset();
server_->EnableExtendedMasterSecret();
- auto alert_recorder = new TlsAlertRecorder();
- server_->SetPacketFilter(alert_recorder);
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertHandshakeFailure, alert_recorder->description());
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
}
TEST_P(TlsConnectGenericPre13, ConnectNormalResumeWithExtendedMasterSecret) {
diff --git a/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc b/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
index 0a0d9f2..be407b4 100644
--- a/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_exporter_unittest.cc
@@ -14,7 +14,8 @@ namespace nss_test {
static const char* kExporterLabel = "EXPORTER-duck";
static const uint8_t kExporterContext[] = {0x12, 0x34, 0x56};
-static void ExportAndCompare(TlsAgent* client, TlsAgent* server, bool context) {
+static void ExportAndCompare(std::shared_ptr<TlsAgent>& client,
+ std::shared_ptr<TlsAgent>& server, bool context) {
static const size_t exporter_len = 10;
uint8_t client_value[exporter_len] = {0};
EXPECT_EQ(SECSuccess,
@@ -76,6 +77,33 @@ TEST_P(TlsConnectTls13, ExporterContextEmptyIsSameAsNone) {
ExportAndCompare(client_, server_, false);
}
+TEST_P(TlsConnectGenericPre13, ExporterContextLengthTooLong) {
+ static const uint8_t kExporterContextTooLong[PR_UINT16_MAX] = {
+ 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC, 0xDE, 0xFF};
+
+ EnsureTlsSetup();
+ Connect();
+ CheckKeys();
+
+ static const size_t exporter_len = 10;
+ uint8_t client_value[exporter_len] = {0};
+ EXPECT_EQ(SECFailure,
+ SSL_ExportKeyingMaterial(client_->ssl_fd(), kExporterLabel,
+ strlen(kExporterLabel), PR_TRUE,
+ kExporterContextTooLong,
+ sizeof(kExporterContextTooLong),
+ client_value, sizeof(client_value)));
+ EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS);
+ uint8_t server_value[exporter_len] = {0xff};
+ EXPECT_EQ(SECFailure,
+ SSL_ExportKeyingMaterial(server_->ssl_fd(), kExporterLabel,
+ strlen(kExporterLabel), PR_TRUE,
+ kExporterContextTooLong,
+ sizeof(kExporterContextTooLong),
+ server_value, sizeof(server_value)));
+ EXPECT_EQ(PORT_GetError(), SEC_ERROR_INVALID_ARGS);
+}
+
// This has a weird signature so that it can be passed to the SNI callback.
int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr,
PRUint32 srvNameArrSize) {
@@ -90,13 +118,15 @@ int32_t RegularExporterShouldFail(TlsAgent* agent, const SECItem* srvNameArr,
TEST_P(TlsConnectTls13, EarlyExporter) {
SetupForZeroRtt();
+ ExpectAlert(client_, kTlsAlertEndOfEarlyData);
client_->Set0RttEnabled(true);
server_->Set0RttEnabled(true);
ExpectResumption(RESUME_TICKET);
client_->Handshake(); // Send ClientHello.
uint8_t client_value[10] = {0};
- RegularExporterShouldFail(client_, nullptr, 0);
+ RegularExporterShouldFail(client_.get(), nullptr, 0);
+
EXPECT_EQ(SECSuccess,
SSL_ExportEarlyKeyingMaterial(
client_->ssl_fd(), kExporterLabel, strlen(kExporterLabel),
diff --git a/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
index 9200e72..d151394 100644
--- a/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -69,22 +69,11 @@ class TlsExtensionInjector : public TlsHandshakeFilter {
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- size_t offset;
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- TlsParser parser(input);
- if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
- return KEEP;
- }
- offset = parser.consumed();
- } else if (header.handshake_type() == kTlsHandshakeServerHello) {
- TlsParser parser(input);
- if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
- return KEEP;
- }
- offset = parser.consumed();
- } else {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
return KEEP;
}
+ size_t offset = parser.consumed();
*output = input;
@@ -116,38 +105,41 @@ class TlsExtensionInjector : public TlsHandshakeFilter {
class TlsExtensionAppender : public TlsHandshakeFilter {
public:
- TlsExtensionAppender(uint16_t ext, DataBuffer& data)
- : extension_(ext), data_(data) {}
+ TlsExtensionAppender(uint8_t handshake_type, uint16_t ext, DataBuffer& data)
+ : handshake_type_(handshake_type), extension_(ext), data_(data) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output) {
- size_t offset;
+ if (header.handshake_type() != handshake_type_) {
+ return KEEP;
+ }
+
TlsParser parser(input);
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
- return KEEP;
- }
- } else if (header.handshake_type() == kTlsHandshakeServerHello) {
- if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
- return KEEP;
- }
- } else {
+ if (!TlsExtensionFilter::FindExtensions(&parser, header)) {
return KEEP;
}
- offset = parser.consumed();
*output = input;
- uint32_t ext_len;
- if (!parser.Read(&ext_len, 2)) {
- ADD_FAILURE();
+ // Increase the length of the extensions block.
+ if (!UpdateLength(output, parser.consumed(), 2)) {
return KEEP;
}
- ext_len += 4 + data_.len();
- output->Write(offset, ext_len, 2);
+ // Extensions in Certificate are nested twice. Increase the size of the
+ // certificate list.
+ if (header.handshake_type() == kTlsHandshakeCertificate) {
+ TlsParser p2(input);
+ if (!p2.SkipVariable(1)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+ if (!UpdateLength(output, p2.consumed(), 3)) {
+ return KEEP;
+ }
+ }
- offset = output->len();
+ size_t offset = output->len();
offset = output->Write(offset, extension_, 2);
WriteVariable(output, offset, data_, 2);
@@ -155,39 +147,38 @@ class TlsExtensionAppender : public TlsHandshakeFilter {
}
private:
+ bool UpdateLength(DataBuffer* output, size_t offset, size_t size) {
+ uint32_t len;
+ if (!output->Read(offset, size, &len)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ len += 4 + data_.len();
+ output->Write(offset, len, size);
+ return true;
+ }
+
+ const uint8_t handshake_type_;
const uint16_t extension_;
const DataBuffer data_;
};
class TlsExtensionTestBase : public TlsConnectTestBase {
protected:
- TlsExtensionTestBase(Mode mode, uint16_t version)
- : TlsConnectTestBase(mode, version) {}
- TlsExtensionTestBase(const std::string& mode, uint16_t version)
- : TlsConnectTestBase(mode, version) {}
-
- void ClientHelloErrorTest(PacketFilter* filter,
- uint8_t alert = kTlsAlertDecodeError) {
- auto alert_recorder = new TlsAlertRecorder();
- server_->SetPacketFilter(alert_recorder);
- if (filter) {
- client_->SetPacketFilter(filter);
- }
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(alert, alert_recorder->description());
+ TlsExtensionTestBase(SSLProtocolVariant variant, uint16_t version)
+ : TlsConnectTestBase(variant, version) {}
+
+ void ClientHelloErrorTest(std::shared_ptr<PacketFilter> filter,
+ uint8_t desc = kTlsAlertDecodeError) {
+ client_->SetPacketFilter(filter);
+ ConnectExpectAlert(server_, desc);
}
- void ServerHelloErrorTest(PacketFilter* filter,
- uint8_t alert = kTlsAlertDecodeError) {
- auto alert_recorder = new TlsAlertRecorder();
- client_->SetPacketFilter(alert_recorder);
- if (filter) {
- server_->SetPacketFilter(filter);
- }
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(alert, alert_recorder->description());
+ void ServerHelloErrorTest(std::shared_ptr<PacketFilter> filter,
+ uint8_t desc = kTlsAlertDecodeError) {
+ server_->SetPacketFilter(filter);
+ ConnectExpectAlert(client_, desc);
}
static void InitSimpleSni(DataBuffer* extension) {
@@ -213,7 +204,7 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
server_->StartConnect();
client_->Handshake(); // Send ClientHello
server_->Handshake(); // Send HRR.
- client_->SetPacketFilter(new TlsExtensionDropper(type));
+ client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(type));
Handshake();
client_->CheckErrorCode(client_error);
server_->CheckErrorCode(server_error);
@@ -223,38 +214,40 @@ class TlsExtensionTestBase : public TlsConnectTestBase {
class TlsExtensionTestDtls : public TlsExtensionTestBase,
public ::testing::WithParamInterface<uint16_t> {
public:
- TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {}
+ TlsExtensionTestDtls()
+ : TlsExtensionTestBase(ssl_variant_datagram, GetParam()) {}
};
-class TlsExtensionTest12Plus
- : public TlsExtensionTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsExtensionTest12Plus : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsExtensionTest12Plus()
: TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
}
};
-class TlsExtensionTest12
- : public TlsExtensionTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsExtensionTest12 : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsExtensionTest12()
: TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
}
};
-class TlsExtensionTest13 : public TlsExtensionTestBase,
- public ::testing::WithParamInterface<std::string> {
+class TlsExtensionTest13
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
public:
TlsExtensionTest13()
: TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
DataBuffer versions_buf(buf, len);
- client_->SetPacketFilter(new TlsExtensionReplacer(
+ client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
ssl_tls13_supported_versions_xtn, versions_buf));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -264,7 +257,7 @@ class TlsExtensionTest13 : public TlsExtensionTestBase,
size_t index = versions_buf.Write(0, 2, 1);
versions_buf.Write(index, version, 2);
- client_->SetPacketFilter(new TlsExtensionReplacer(
+ client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
ssl_tls13_supported_versions_xtn, versions_buf));
ConnectExpectFail();
}
@@ -273,21 +266,21 @@ class TlsExtensionTest13 : public TlsExtensionTestBase,
class TlsExtensionTest13Stream : public TlsExtensionTestBase {
public:
TlsExtensionTest13Stream()
- : TlsExtensionTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+ : TlsExtensionTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
-class TlsExtensionTestGeneric
- : public TlsExtensionTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsExtensionTestGeneric : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsExtensionTestGeneric()
: TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
}
};
-class TlsExtensionTestPre13
- : public TlsExtensionTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsExtensionTestPre13 : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsExtensionTestPre13()
: TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
@@ -295,23 +288,27 @@ class TlsExtensionTestPre13
};
TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
- ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1));
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
- ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4));
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionDamager>(ssl_server_name_xtn, 4));
}
TEST_P(TlsExtensionTestGeneric, TruncateSni) {
- ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7));
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionTruncator>(ssl_server_name_xtn, 7));
}
// A valid extension that appears twice will be reported as unsupported.
TEST_P(TlsExtensionTestGeneric, RepeatSni) {
DataBuffer extension;
InitSimpleSni(&extension);
- ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension),
- kTlsAlertIllegalParameter);
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionInjector>(ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
}
// An SNI entry with zero length is considered invalid (strangely, not if it is
@@ -324,7 +321,7 @@ TEST_P(TlsExtensionTestGeneric, BadSni) {
extension.Write(0, static_cast<uint32_t>(0), 3);
extension.Write(3, simple);
ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+ std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptySni) {
@@ -332,15 +329,15 @@ TEST_P(TlsExtensionTestGeneric, EmptySni) {
extension.Allocate(2);
extension.Write(0, static_cast<uint32_t>(0), 2);
ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+ std::make_shared<TlsExtensionReplacer>(ssl_server_name_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
EnableAlpn();
DataBuffer extension;
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
- kTlsAlertIllegalParameter);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
}
// An empty ALPN isn't considered bad, though it does lead to there being no
@@ -349,30 +346,30 @@ TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
- kTlsAlertNoApplicationProtocol);
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertNoApplicationProtocol);
}
TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
EnableAlpn();
ClientHelloErrorTest(
- new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1));
+ std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 1));
}
TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
EnableAlpn();
// This will leave the length of the second entry, but no value.
ClientHelloErrorTest(
- new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5));
+ std::make_shared<TlsExtensionTruncator>(ssl_app_layer_protocol_xtn, 5));
}
TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
EnableAlpn();
const uint8_t val[] = {0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
@@ -390,158 +387,169 @@ TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ServerHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
- ServerHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
- ServerHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
DataBuffer extension(val, sizeof(val));
- ServerHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
DataBuffer extension(val, sizeof(val));
- ServerHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
EnableAlpn();
const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
DataBuffer extension(val, sizeof(val));
- ServerHelloErrorTest(
- new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedUnknownName) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x01, 0x67};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
}
TEST_P(TlsExtensionTestDtls, SrtpShort) {
EnableSrtp();
- ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3));
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionTruncator>(ssl_use_srtp_xtn, 3));
}
TEST_P(TlsExtensionTestDtls, SrtpOdd) {
EnableSrtp();
const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension));
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionReplacer>(ssl_use_srtp_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
const uint8_t val[] = {0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
const uint8_t val[] = {0x00, 0x01, 0x04};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_signature_algorithms_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
- ClientHelloErrorTest(new TlsExtensionDropper(ssl_supported_groups_xtn),
- version_ < SSL_LIBRARY_VERSION_TLS_1_3
- ? kTlsAlertDecryptError
- : kTlsAlertMissingExtension);
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionDropper>(ssl_supported_groups_xtn),
+ version_ < SSL_LIBRARY_VERSION_TLS_1_3 ? kTlsAlertDecryptError
+ : kTlsAlertMissingExtension);
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
const uint8_t val[] = {0x00, 0x01, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_elliptic_curves_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
const uint8_t val[] = {0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
const uint8_t val[] = {0x99, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
const uint8_t val[] = {0x01, 0x00, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_ec_point_formats_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
const uint8_t val[] = {0x99};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_renegotiation_info_xtn, extension));
}
TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
const uint8_t val[] = {0x01, 0x00};
DataBuffer extension(val, sizeof(val));
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_renegotiation_info_xtn, extension));
}
// The extension has to contain a length.
TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
DataBuffer extension;
- ClientHelloErrorTest(
- new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+ ClientHelloErrorTest(std::make_shared<TlsExtensionReplacer>(
+ ssl_renegotiation_info_xtn, extension));
}
// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
@@ -550,8 +558,8 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512,
ssl_sig_rsa_pss_sha384};
- TlsExtensionCapture* capture =
- new TlsExtensionCapture(ssl_signature_algorithms_xtn);
+ auto capture =
+ std::make_shared<TlsExtensionCapture>(ssl_signature_algorithms_xtn);
client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
client_->SetPacketFilter(capture);
EnableOnlyStaticRsaCiphers();
@@ -571,8 +579,9 @@ TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
// Temporary test to verify that we choke on an empty ClientKeyShare.
// This test will fail when we implement HelloRetryRequest.
TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
- ClientHelloErrorTest(new TlsExtensionTruncator(ssl_tls13_key_share_xtn, 2),
- kTlsAlertHandshakeFailure);
+ ClientHelloErrorTest(
+ std::make_shared<TlsExtensionTruncator>(ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
}
// These tests only work in stream mode because the client sends a
@@ -581,7 +590,10 @@ TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
// packet gets dropped.
TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
EnsureTlsSetup();
- server_->SetPacketFilter(new TlsExtensionDropper(ssl_tls13_key_share_xtn));
+ server_->SetPacketFilter(
+ std::make_shared<TlsExtensionDropper>(ssl_tls13_key_share_xtn));
+ client_->ExpectSendAlert(kTlsAlertMissingExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
@@ -600,7 +612,9 @@ TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
server_->SetPacketFilter(
- new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf));
+ std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code());
EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
@@ -620,7 +634,9 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
DataBuffer buf(key_share, sizeof(key_share));
EnsureTlsSetup();
server_->SetPacketFilter(
- new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf));
+ std::make_shared<TlsExtensionReplacer>(ssl_tls13_key_share_xtn, buf));
+ client_->ExpectSendAlert(kTlsAlertMissingExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
@@ -629,8 +645,10 @@ TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
SetupForResume();
DataBuffer empty;
- server_->SetPacketFilter(
- new TlsExtensionInjector(ssl_signature_algorithms_xtn, empty));
+ server_->SetPacketFilter(std::make_shared<TlsExtensionInjector>(
+ ssl_signature_algorithms_xtn, empty));
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code());
EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
@@ -763,9 +781,9 @@ class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
SetupForResume();
- client_->SetPacketFilter(new TlsPreSharedKeyReplacer([](
+ client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -775,10 +793,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
SetupForResume();
client_->SetPacketFilter(
- new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
}));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
}
@@ -788,10 +806,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
SetupForResume();
client_->SetPacketFilter(
- new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
}));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -800,9 +818,9 @@ TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
SetupForResume();
- client_->SetPacketFilter(new TlsPreSharedKeyReplacer(
+ client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>(
[](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -813,11 +831,11 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
SetupForResume();
client_->SetPacketFilter(
- new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
r->binders_.push_back(r->binders_[0]);
}));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
}
@@ -828,10 +846,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
SetupForResume();
client_->SetPacketFilter(
- new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ std::make_shared<TlsPreSharedKeyReplacer>([](TlsPreSharedKeyReplacer* r) {
r->identities_.push_back(r->identities_[0]);
}));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -839,9 +857,9 @@ TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
SetupForResume();
- client_->SetPacketFilter(new TlsPreSharedKeyReplacer([](
+ client_->SetPacketFilter(std::make_shared<TlsPreSharedKeyReplacer>([](
TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -851,10 +869,10 @@ TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
const uint8_t empty_buf[] = {0};
DataBuffer empty(empty_buf, 0);
- client_->SetPacketFilter(
- // Inject an unused extension.
- new TlsExtensionAppender(0xffff, empty));
- ConnectExpectFail();
+ // Inject an unused extension after the PSK extension.
+ client_->SetPacketFilter(std::make_shared<TlsExtensionAppender>(
+ kTlsHandshakeClientHello, 0xffff, empty));
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
}
@@ -863,9 +881,9 @@ TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
SetupForResume();
DataBuffer empty;
- client_->SetPacketFilter(
- new TlsExtensionDropper(ssl_tls13_psk_key_exchange_modes_xtn));
- ConnectExpectFail();
+ client_->SetPacketFilter(std::make_shared<TlsExtensionDropper>(
+ ssl_tls13_psk_key_exchange_modes_xtn));
+ ConnectExpectAlert(server_, kTlsAlertMissingExtension);
client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
}
@@ -879,8 +897,10 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
kTls13PskKe};
DataBuffer modes(ke_modes, sizeof(ke_modes));
- client_->SetPacketFilter(
- new TlsExtensionReplacer(ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ client_->SetPacketFilter(std::make_shared<TlsExtensionReplacer>(
+ ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
@@ -888,7 +908,8 @@ TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
- auto capture = new TlsExtensionCapture(ssl_tls13_psk_key_exchange_modes_xtn);
+ auto capture = std::make_shared<TlsExtensionCapture>(
+ ssl_tls13_psk_key_exchange_modes_xtn);
client_->SetPacketFilter(capture);
Connect();
EXPECT_FALSE(capture->captured());
@@ -899,6 +920,7 @@ TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
// 1. Both sides only support TLS 1.3, so we get a cipher version
// error.
TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) {
+ ExpectAlert(server_, kTlsAlertProtocolVersion);
ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
@@ -909,6 +931,7 @@ TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) {
TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) {
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_3);
+ ExpectAlert(server_, kTlsAlertHandshakeFailure);
ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
@@ -921,6 +944,11 @@ TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) {
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_3);
+#ifndef TLS_1_3_DRAFT_VERSION
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
+#else
+ ExpectAlert(server_, kTlsAlertDecryptError);
+#endif
ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
#ifndef TLS_1_3_DRAFT_VERSION
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
@@ -932,18 +960,21 @@ TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) {
}
TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) {
+ ExpectAlert(server_, kTlsAlertMissingExtension);
HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn,
SSL_ERROR_MISSING_EXTENSION_ALERT,
SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
}
TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) {
+ ExpectAlert(server_, kTlsAlertIllegalParameter);
HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn,
SSL_ERROR_ILLEGAL_PARAMETER_ALERT,
SSL_ERROR_BAD_2ND_CLIENT_HELLO);
}
TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) {
+ ExpectAlert(server_, kTlsAlertMissingExtension);
HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn,
SSL_ERROR_MISSING_EXTENSION_ALERT,
SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
@@ -959,27 +990,192 @@ TEST_P(TlsExtensionTest13, OddVersionList) {
ConnectWithBogusVersionList(ext, sizeof(ext));
}
-INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsVAll));
-INSTANTIATE_TEST_CASE_P(ExtensionDatagram, TlsExtensionTestGeneric,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
- TlsConnectTestBase::kTlsV11Plus));
+// TODO: this only tests extensions in server messages. The client can extend
+// Certificate messages, which is not checked here.
+class TlsBogusExtensionTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
+ public:
+ TlsBogusExtensionTest()
+ : TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
+
+ protected:
+ virtual void ConnectAndFail(uint8_t message) = 0;
+
+ void AddFilter(uint8_t message, uint16_t extension) {
+ static uint8_t empty_buf[1] = {0};
+ DataBuffer empty(empty_buf, 0);
+ auto filter =
+ std::make_shared<TlsExtensionAppender>(message, extension, empty);
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ server_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ } else {
+ server_->SetPacketFilter(filter);
+ }
+ }
+
+ void Run(uint8_t message, uint16_t extension = 0xff) {
+ EnsureTlsSetup();
+ AddFilter(message, extension);
+ ConnectAndFail(message);
+ }
+};
+
+class TlsBogusExtensionTestPre13 : public TlsBogusExtensionTest {
+ protected:
+ void ConnectAndFail(uint8_t) override {
+ ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+ }
+};
+
+class TlsBogusExtensionTest13 : public TlsBogusExtensionTest {
+ protected:
+ void ConnectAndFail(uint8_t message) override {
+ if (message == kTlsHandshakeHelloRetryRequest) {
+ ConnectExpectAlert(client_, kTlsAlertUnsupportedExtension);
+ return;
+ }
+
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake(); // ClientHello
+ server_->Handshake(); // ServerHello
+
+ client_->ExpectSendAlert(kTlsAlertUnsupportedExtension);
+ client_->Handshake();
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ }
+ server_->Handshake();
+ }
+};
+
+TEST_P(TlsBogusExtensionTestPre13, AddBogusExtensionServerHello) {
+ Run(kTlsHandshakeServerHello);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionServerHello) {
+ Run(kTlsHandshakeServerHello);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionEncryptedExtensions) {
+ Run(kTlsHandshakeEncryptedExtensions);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificate) {
+ Run(kTlsHandshakeCertificate);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionCertificateRequest) {
+ server_->RequestClientAuth(false);
+ Run(kTlsHandshakeCertificateRequest);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionHelloRetryRequest) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ Run(kTlsHandshakeHelloRetryRequest);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionServerHello) {
+ Run(kTlsHandshakeServerHello, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionEncryptedExtensions) {
+ Run(kTlsHandshakeEncryptedExtensions, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificate) {
+ Run(kTlsHandshakeCertificate, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionCertificateRequest) {
+ server_->RequestClientAuth(false);
+ Run(kTlsHandshakeCertificateRequest, ssl_tls13_supported_versions_xtn);
+}
+
+TEST_P(TlsBogusExtensionTest13, AddVersionExtensionHelloRetryRequest) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1};
+ server_->ConfigNamedGroups(groups);
+
+ Run(kTlsHandshakeHelloRetryRequest, ssl_tls13_supported_versions_xtn);
+}
+
+// NewSessionTicket allows unknown extensions AND it isn't protected by the
+// Finished. So adding an unknown extension doesn't cause an error.
+TEST_P(TlsBogusExtensionTest13, AddBogusExtensionNewSessionTicket) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+
+ AddFilter(kTlsHandshakeNewSessionTicket, 0xff);
+ Connect();
+ SendReceive();
+ CheckKeys();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectStream, IncludePadding) {
+ EnsureTlsSetup();
+
+ // This needs to be long enough to push a TLS 1.0 ClientHello over 255, but
+ // short enough not to push a TLS 1.3 ClientHello over 511.
+ static const char* long_name =
+ "chickenchickenchickenchickenchickenchickenchickenchicken."
+ "chickenchickenchickenchickenchickenchickenchickenchicken."
+ "chickenchickenchickenchickenchicken.";
+ SECStatus rv = SSL_SetURL(client_->ssl_fd(), long_name);
+ EXPECT_EQ(SECSuccess, rv);
+
+ auto capture = std::make_shared<TlsExtensionCapture>(ssl_padding_xtn);
+ client_->SetPacketFilter(capture);
+ client_->StartConnect();
+ client_->Handshake();
+ EXPECT_TRUE(capture->captured());
+}
+
+INSTANTIATE_TEST_CASE_P(
+ ExtensionStream, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(
+ ExtensionDatagram, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11Plus));
INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls,
TlsConnectTestBase::kTlsV11Plus);
INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus));
-INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(
+ ExtensionPre13Stream, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV11V12));
INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13,
- TlsConnectTestBase::kTlsModesAll);
-
-} // namespace nspr_test
+ TlsConnectTestBase::kTlsVariantsAll);
+
+INSTANTIATE_TEST_CASE_P(
+ BogusExtensionStream, TlsBogusExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(
+ BogusExtensionDatagram, TlsBogusExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
+ TlsConnectTestBase::kTlsV11V12));
+
+INSTANTIATE_TEST_CASE_P(BogusExtension13, TlsBogusExtensionTest13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc b/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
new file mode 100644
index 0000000..44cacce
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_fragment_unittest.cc
@@ -0,0 +1,157 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "secerr.h"
+#include "ssl.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+// This class cuts every unencrypted handshake record into two parts.
+class RecordFragmenter : public PacketFilter {
+ public:
+ RecordFragmenter() : sequence_number_(0), splitting_(true) {}
+
+ private:
+ class HandshakeSplitter {
+ public:
+ HandshakeSplitter(const DataBuffer& input, DataBuffer* output,
+ uint64_t* sequence_number)
+ : input_(input),
+ output_(output),
+ cursor_(0),
+ sequence_number_(sequence_number) {}
+
+ private:
+ void WriteRecord(TlsRecordHeader& record_header,
+ DataBuffer& record_fragment) {
+ TlsRecordHeader fragment_header(record_header.version(),
+ record_header.content_type(),
+ *sequence_number_);
+ ++*sequence_number_;
+ if (::g_ssl_gtest_verbose) {
+ std::cerr << "Fragment: " << fragment_header << ' ' << record_fragment
+ << std::endl;
+ }
+ cursor_ = fragment_header.Write(output_, cursor_, record_fragment);
+ }
+
+ bool SplitRecord(TlsRecordHeader& record_header, DataBuffer& record) {
+ TlsParser parser(record);
+ while (parser.remaining()) {
+ TlsHandshakeFilter::HandshakeHeader handshake_header;
+ DataBuffer handshake_body;
+ if (!handshake_header.Parse(&parser, record_header, &handshake_body)) {
+ ADD_FAILURE() << "couldn't parse handshake header";
+ return false;
+ }
+
+ DataBuffer record_fragment;
+ // We can't fragment handshake records that are too small.
+ if (handshake_body.len() < 2) {
+ handshake_header.Write(&record_fragment, 0U, handshake_body);
+ WriteRecord(record_header, record_fragment);
+ continue;
+ }
+
+ size_t cut = handshake_body.len() / 2;
+ handshake_header.WriteFragment(&record_fragment, 0U, handshake_body, 0U,
+ cut);
+ WriteRecord(record_header, record_fragment);
+
+ handshake_header.WriteFragment(&record_fragment, 0U, handshake_body,
+ cut, handshake_body.len() - cut);
+ WriteRecord(record_header, record_fragment);
+ }
+ return true;
+ }
+
+ public:
+ bool Split() {
+ TlsParser parser(input_);
+ while (parser.remaining()) {
+ TlsRecordHeader header;
+ DataBuffer record;
+ if (!header.Parse(&parser, &record)) {
+ ADD_FAILURE() << "bad record header";
+ return false;
+ }
+
+ if (::g_ssl_gtest_verbose) {
+ std::cerr << "Record: " << header << ' ' << record << std::endl;
+ }
+
+ // Don't touch packets from a non-zero epoch. Leave these unmodified.
+ if ((header.sequence_number() >> 48) != 0ULL) {
+ cursor_ = header.Write(output_, cursor_, record);
+ continue;
+ }
+
+ // Just rewrite the sequence number (CCS only).
+ if (header.content_type() != kTlsHandshakeType) {
+ EXPECT_EQ(kTlsChangeCipherSpecType, header.content_type());
+ WriteRecord(header, record);
+ continue;
+ }
+
+ if (!SplitRecord(header, record)) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ private:
+ const DataBuffer& input_;
+ DataBuffer* output_;
+ size_t cursor_;
+ uint64_t* sequence_number_;
+ };
+
+ protected:
+ virtual PacketFilter::Action Filter(const DataBuffer& input,
+ DataBuffer* output) override {
+ if (!splitting_) {
+ return KEEP;
+ }
+
+ output->Allocate(input.len());
+ HandshakeSplitter splitter(input, output, &sequence_number_);
+ if (!splitter.Split()) {
+ // If splitting fails, we obviously reached encrypted packets.
+ // Stop splitting from that point onward.
+ splitting_ = false;
+ return KEEP;
+ }
+
+ return CHANGE;
+ }
+
+ private:
+ uint64_t sequence_number_;
+ bool splitting_;
+};
+
+TEST_P(TlsConnectDatagram, FragmentClientPackets) {
+ client_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ Connect();
+ SendReceive();
+}
+
+TEST_P(TlsConnectDatagram, FragmentServerPackets) {
+ server_->SetPacketFilter(std::make_shared<RecordFragmenter>());
+ Connect();
+ SendReceive();
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc b/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
index d144cd7..d08a0b6 100644
--- a/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_fuzz_unittest.cc
@@ -12,6 +12,12 @@
namespace nss_test {
#ifdef UNSAFE_FUZZER_MODE
+#define FUZZ_F(c, f) TEST_F(c, Fuzz_##f)
+#define FUZZ_P(c, f) TEST_P(c, Fuzz_##f)
+#else
+#define FUZZ_F(c, f) TEST_F(c, DISABLED_Fuzz_##f)
+#define FUZZ_P(c, f) TEST_P(c, DISABLED_Fuzz_##f)
+#endif
const uint8_t kShortEmptyFinished[8] = {0};
const uint8_t kLongEmptyFinished[128] = {0};
@@ -23,7 +29,7 @@ class TlsApplicationDataRecorder : public TlsRecordFilter {
public:
TlsApplicationDataRecorder() : buffer_() {}
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output) {
if (header.content_type() == kTlsApplicationDataType) {
@@ -39,56 +45,28 @@ class TlsApplicationDataRecorder : public TlsRecordFilter {
DataBuffer buffer_;
};
-// Damages an SKE or CV signature.
-class TlsSignatureDamager : public TlsHandshakeFilter {
- public:
- TlsSignatureDamager(uint8_t type) : type_(type) {}
- virtual PacketFilter::Action FilterHandshake(
- const TlsHandshakeFilter::HandshakeHeader& header,
- const DataBuffer& input, DataBuffer* output) {
- if (header.handshake_type() != type_) {
- return KEEP;
- }
-
- *output = input;
-
- // Modify the last byte of the signature.
- output->data()[output->len() - 1]++;
- return CHANGE;
- }
-
- private:
- uint8_t type_;
-};
-
-void ResetState() {
- // Clear the list of RSA blinding params.
- BL_Cleanup();
-
- // Reinit the list of RSA blinding params.
- EXPECT_EQ(SECSuccess, BL_Init());
-
- // Reset the RNG state.
- EXPECT_EQ(SECSuccess, RNG_ResetForFuzzing());
-}
-
// Ensure that ssl_Time() returns a constant value.
-TEST_F(TlsFuzzTest, Fuzz_SSL_Time_Constant) {
- PRInt32 now = ssl_Time();
+FUZZ_F(TlsFuzzTest, SSL_Time_Constant) {
+ PRUint32 now = ssl_Time();
PR_Sleep(PR_SecondsToInterval(2));
EXPECT_EQ(ssl_Time(), now);
}
// Check that due to the deterministic PRNG we derive
// the same master secret in two consecutive TLS sessions.
-TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) {
+FUZZ_P(TlsConnectGeneric, DeterministicExporter) {
const char kLabel[] = "label";
std::vector<unsigned char> out1(32), out2(32);
+ // Make sure we have RSA blinding params.
+ Connect();
+
+ Reset();
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
DisableECDHEServerKeyReuse();
- ResetState();
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
Connect();
// Export a key derived from the MS and nonces.
@@ -101,7 +79,8 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) {
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
DisableECDHEServerKeyReuse();
- ResetState();
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
Connect();
// Export another key derived from the MS and nonces.
@@ -115,7 +94,10 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicExporter) {
// Check that due to the deterministic RNG two consecutive
// TLS sessions will have the exact same transcript.
-TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) {
+FUZZ_P(TlsConnectGeneric, DeterministicTranscript) {
+ // Make sure we have RSA blinding params.
+ Connect();
+
// Connect a few times and compare the transcripts byte-by-byte.
DataBuffer last;
for (size_t i = 0; i < 5; i++) {
@@ -124,15 +106,16 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) {
DisableECDHEServerKeyReuse();
DataBuffer buffer;
- client_->SetPacketFilter(new TlsConversationRecorder(buffer));
- server_->SetPacketFilter(new TlsConversationRecorder(buffer));
+ client_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
+ server_->SetPacketFilter(std::make_shared<TlsConversationRecorder>(buffer));
- ResetState();
+ // Reset the RNG state.
+ EXPECT_EQ(SECSuccess, RNG_RandomUpdate(NULL, 0));
Connect();
// Ensure the filters go away before |buffer| does.
- client_->SetPacketFilter(nullptr);
- server_->SetPacketFilter(nullptr);
+ client_->DeletePacketFilter();
+ server_->DeletePacketFilter();
if (last.len() > 0) {
EXPECT_EQ(last, buffer);
@@ -146,13 +129,13 @@ TEST_P(TlsConnectGeneric, Fuzz_DeterministicTranscript) {
// with all supported TLS versions, STREAM and DGRAM.
// Check that records are NOT encrypted.
// Check that records don't have a MAC.
-TEST_P(TlsConnectGeneric, Fuzz_ConnectSendReceive_NullCipher) {
+FUZZ_P(TlsConnectGeneric, ConnectSendReceive_NullCipher) {
EnsureTlsSetup();
// Set up app data filters.
- auto client_recorder = new TlsApplicationDataRecorder();
+ auto client_recorder = std::make_shared<TlsApplicationDataRecorder>();
client_->SetPacketFilter(client_recorder);
- auto server_recorder = new TlsApplicationDataRecorder();
+ auto server_recorder = std::make_shared<TlsApplicationDataRecorder>();
server_->SetPacketFilter(server_recorder);
Connect();
@@ -175,10 +158,10 @@ TEST_P(TlsConnectGeneric, Fuzz_ConnectSendReceive_NullCipher) {
}
// Check that an invalid Finished message doesn't abort the connection.
-TEST_P(TlsConnectGeneric, Fuzz_BogusClientFinished) {
+FUZZ_P(TlsConnectGeneric, BogusClientFinished) {
EnsureTlsSetup();
- auto i1 = new TlsInspectorReplaceHandshakeMessage(
+ auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
kTlsHandshakeFinished,
DataBuffer(kShortEmptyFinished, sizeof(kShortEmptyFinished)));
client_->SetPacketFilter(i1);
@@ -187,10 +170,10 @@ TEST_P(TlsConnectGeneric, Fuzz_BogusClientFinished) {
}
// Check that an invalid Finished message doesn't abort the connection.
-TEST_P(TlsConnectGeneric, Fuzz_BogusServerFinished) {
+FUZZ_P(TlsConnectGeneric, BogusServerFinished) {
EnsureTlsSetup();
- auto i1 = new TlsInspectorReplaceHandshakeMessage(
+ auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
kTlsHandshakeFinished,
DataBuffer(kLongEmptyFinished, sizeof(kLongEmptyFinished)));
server_->SetPacketFilter(i1);
@@ -199,25 +182,120 @@ TEST_P(TlsConnectGeneric, Fuzz_BogusServerFinished) {
}
// Check that an invalid server auth signature doesn't abort the connection.
-TEST_P(TlsConnectGeneric, Fuzz_BogusServerAuthSignature) {
+FUZZ_P(TlsConnectGeneric, BogusServerAuthSignature) {
EnsureTlsSetup();
uint8_t msg_type = version_ == SSL_LIBRARY_VERSION_TLS_1_3
? kTlsHandshakeCertificateVerify
: kTlsHandshakeServerKeyExchange;
- server_->SetPacketFilter(new TlsSignatureDamager(msg_type));
+ server_->SetPacketFilter(std::make_shared<TlsLastByteDamager>(msg_type));
Connect();
SendReceive();
}
// Check that an invalid client auth signature doesn't abort the connection.
-TEST_P(TlsConnectGeneric, Fuzz_BogusClientAuthSignature) {
+FUZZ_P(TlsConnectGeneric, BogusClientAuthSignature) {
EnsureTlsSetup();
client_->SetupClientAuth();
server_->RequestClientAuth(true);
client_->SetPacketFilter(
- new TlsSignatureDamager(kTlsHandshakeCertificateVerify));
+ std::make_shared<TlsLastByteDamager>(kTlsHandshakeCertificateVerify));
Connect();
}
-#endif
+// Check that session ticket resumption works.
+FUZZ_P(TlsConnectGeneric, SessionTicketResumption) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+ Connect();
+ SendReceive();
+}
+
+class TlsSessionTicketMacDamager : public TlsExtensionFilter {
+ public:
+ TlsSessionTicketMacDamager() {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_session_ticket_xtn &&
+ extension_type != ssl_tls13_pre_shared_key_xtn) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ // Handle everything before TLS 1.3.
+ if (extension_type == ssl_session_ticket_xtn) {
+ // Modify the last byte of the MAC.
+ output->data()[output->len() - 1] ^= 0xff;
+ }
+
+ // Handle TLS 1.3.
+ if (extension_type == ssl_tls13_pre_shared_key_xtn) {
+ TlsParser parser(input);
+
+ uint32_t ids_len;
+ EXPECT_TRUE(parser.Read(&ids_len, 2) && ids_len > 0);
+
+ uint32_t ticket_len;
+ EXPECT_TRUE(parser.Read(&ticket_len, 2) && ticket_len > 0);
+
+ // Modify the last byte of the MAC.
+ output->data()[2 + 2 + ticket_len - 1] ^= 0xff;
+ }
+
+ return CHANGE;
+ }
+};
+
+// Check that session ticket resumption works with a bad MAC.
+FUZZ_P(TlsConnectGeneric, SessionTicketResumptionBadMac) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_TICKET);
+
+ client_->SetPacketFilter(std::make_shared<TlsSessionTicketMacDamager>());
+ Connect();
+ SendReceive();
+}
+
+// Check that session tickets are not encrypted.
+FUZZ_P(TlsConnectGeneric, UnencryptedSessionTickets) {
+ ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET);
+
+ auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeNewSessionTicket);
+ server_->SetPacketFilter(i1);
+ Connect();
+
+ size_t offset = 4; /* lifetime */
+ if (version_ == SSL_LIBRARY_VERSION_TLS_1_3) {
+ offset += 1 + 1 + /* ke_modes */
+ 1 + 1; /* auth_modes */
+ }
+
+ offset += 2 + /* ticket length */
+ 16 + /* SESS_TICKET_KEY_NAME_LEN */
+ 16 + /* AES-128 IV */
+ 2 + /* ciphertext length */
+ 2; /* TLS_EX_SESS_TICKET_VERSION */
+
+ // Check the protocol version number.
+ uint32_t tls_version = 0;
+ EXPECT_TRUE(i1->buffer().Read(offset, sizeof(version_), &tls_version));
+ EXPECT_EQ(version_, static_cast<decltype(version_)>(tls_version));
+
+ // Check the cipher suite.
+ uint32_t suite = 0;
+ EXPECT_TRUE(i1->buffer().Read(offset + sizeof(version_), 2, &suite));
+ client_->CheckCipherSuite(static_cast<uint16_t>(suite));
+}
}
diff --git a/nss/gtests/ssl_gtest/ssl_gather_unittest.cc b/nss/gtests/ssl_gtest/ssl_gather_unittest.cc
new file mode 100644
index 0000000..f47b2f4
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_gather_unittest.cc
@@ -0,0 +1,143 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "gtest_utils.h"
+#include "tls_connect.h"
+
+namespace nss_test {
+
+class GatherV2ClientHelloTest : public TlsConnectTestBase {
+ public:
+ GatherV2ClientHelloTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
+
+ void ConnectExpectMalformedClientHello(const DataBuffer &data) {
+ EnsureTlsSetup();
+ server_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ client_->SendDirect(data);
+ server_->StartConnect();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT(
+ (server_->error_code() == SSL_ERROR_RX_MALFORMED_CLIENT_HELLO), 2000);
+ }
+};
+
+// Gather a 5-byte v3 record, with a zero fragment length. The empty handshake
+// message should be ignored, and the connection will succeed afterwards.
+TEST_F(TlsConnectTest, GatherEmptyV3Record) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x16, 1); // handshake
+ idx = buffer.Write(idx, 0x0301, 2); // record_version
+ (void)buffer.Write(idx, 0U, 2); // length=0
+
+ EnsureTlsSetup();
+ client_->SendDirect(buffer);
+ Connect();
+}
+
+// Gather a 5-byte v3 record, with a fragment length exceeding the maximum.
+TEST_F(TlsConnectTest, GatherExcessiveV3Record) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x16, 1); // handshake
+ idx = buffer.Write(idx, 0x0301, 2); // record_version
+ (void)buffer.Write(idx, MAX_FRAGMENT_LENGTH + 2048 + 1, 2); // length=max+1
+
+ EnsureTlsSetup();
+ server_->ExpectSendAlert(kTlsAlertRecordOverflow);
+ client_->SendDirect(buffer);
+ server_->StartConnect();
+ server_->Handshake();
+ ASSERT_TRUE_WAIT((server_->error_code() == SSL_ERROR_RX_RECORD_TOO_LONG),
+ 2000);
+}
+
+// Gather a 3-byte v2 header, with a fragment length of 2.
+TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x0002, 2); // length=2 (long header)
+ idx = buffer.Write(idx, 0U, 1); // padding=0
+ (void)buffer.Write(idx, 0U, 2); // data
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 3-byte v2 header, with a fragment length of 1.
+TEST_F(GatherV2ClientHelloTest, GatherV2RecordLongHeader2) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x0001, 2); // length=1 (long header)
+ idx = buffer.Write(idx, 0U, 1); // padding=0
+ idx = buffer.Write(idx, 0U, 1); // data
+ (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 3-byte v2 header, with a zero fragment length.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordLongHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0U, 2); // length=0 (long header)
+ idx = buffer.Write(idx, 0U, 1); // padding=0
+ (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a fragment length of 3.
+TEST_F(GatherV2ClientHelloTest, GatherV2RecordShortHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8003, 2); // length=3 (short header)
+ (void)buffer.Write(idx, 0U, 3); // data
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a fragment length of 2.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader2) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8002, 2); // length=2 (short header)
+ idx = buffer.Write(idx, 0U, 2); // data
+ (void)buffer.Write(idx, 0U, 1); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a fragment length of 1.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader3) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8001, 2); // length=1 (short header)
+ idx = buffer.Write(idx, 0U, 1); // data
+ (void)buffer.Write(idx, 0U, 2); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+// Gather a 2-byte v2 header, with a zero fragment length.
+TEST_F(GatherV2ClientHelloTest, GatherEmptyV2RecordShortHeader) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x8000, 2); // length=0 (short header)
+ (void)buffer.Write(idx, 0U, 3); // surplus (need 5 bytes total)
+
+ ConnectExpectMalformedClientHello(buffer);
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_gtest.cc b/nss/gtests/ssl_gtest/ssl_gtest.cc
index 2d08dd8..cd10076 100644
--- a/nss/gtests/ssl_gtest/ssl_gtest.cc
+++ b/nss/gtests/ssl_gtest/ssl_gtest.cc
@@ -31,12 +31,18 @@ int main(int argc, char** argv) {
}
}
- NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB,
- NSS_INIT_READONLY);
- NSS_SetDomesticPolicy();
+ if (NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB,
+ NSS_INIT_READONLY) != SECSuccess) {
+ return 1;
+ }
+ if (NSS_SetDomesticPolicy() != SECSuccess) {
+ return 1;
+ }
int rv = RUN_ALL_TESTS();
- NSS_Shutdown();
+ if (NSS_Shutdown() != SECSuccess) {
+ return 1;
+ }
nss_test::Poller::Shutdown();
diff --git a/nss/gtests/ssl_gtest/ssl_gtest.gyp b/nss/gtests/ssl_gtest/ssl_gtest.gyp
index e232a8b..f0b96d6 100644
--- a/nss/gtests/ssl_gtest/ssl_gtest.gyp
+++ b/nss/gtests/ssl_gtest/ssl_gtest.gyp
@@ -25,6 +25,8 @@
'ssl_exporter_unittest.cc',
'ssl_extension_unittest.cc',
'ssl_fuzz_unittest.cc',
+ 'ssl_fragment_unittest.cc',
+ 'ssl_gather_unittest.cc',
'ssl_gtest.cc',
'ssl_hrr_unittest.cc',
'ssl_loopback_unittest.cc',
@@ -34,37 +36,45 @@
'ssl_staticrsa_unittest.cc',
'ssl_v2_client_hello_unittest.cc',
'ssl_version_unittest.cc',
+ 'ssl_versionpolicy_unittest.cc',
'test_io.cc',
'tls_agent.cc',
'tls_connect.cc',
'tls_filter.cc',
'tls_hkdf_unittest.cc',
- 'tls_parser.cc'
+ 'tls_protect.cc'
],
'dependencies': [
'<(DEPTH)/exports.gyp:nss_exports',
'<(DEPTH)/lib/util/util.gyp:nssutil3',
- '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3',
'<(DEPTH)/gtests/google_test/google_test.gyp:gtest',
- '<(DEPTH)/lib/softoken/softoken.gyp:softokn',
'<(DEPTH)/lib/smime/smime.gyp:smime',
'<(DEPTH)/lib/ssl/ssl.gyp:ssl',
'<(DEPTH)/lib/nss/nss.gyp:nss_static',
- '<(DEPTH)/cmd/lib/lib.gyp:sectool',
'<(DEPTH)/lib/pkcs12/pkcs12.gyp:pkcs12',
'<(DEPTH)/lib/pkcs7/pkcs7.gyp:pkcs7',
'<(DEPTH)/lib/certhigh/certhigh.gyp:certhi',
'<(DEPTH)/lib/cryptohi/cryptohi.gyp:cryptohi',
- '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap',
- '<(DEPTH)/lib/softoken/softoken.gyp:softokn',
'<(DEPTH)/lib/certdb/certdb.gyp:certdb',
'<(DEPTH)/lib/pki/pki.gyp:nsspki',
'<(DEPTH)/lib/dev/dev.gyp:nssdev',
'<(DEPTH)/lib/base/base.gyp:nssb',
- '<(DEPTH)/lib/freebl/freebl.gyp:<(freebl_name)',
- '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib'
+ '<(DEPTH)/lib/zlib/zlib.gyp:nss_zlib',
+ '<(DEPTH)/cpputil/cpputil.gyp:cpputil',
],
'conditions': [
+ [ 'test_build==1', {
+ 'dependencies': [
+ '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap_static',
+ ],
+ }, {
+ 'dependencies': [
+ '<(DEPTH)/lib/sqlite/sqlite.gyp:sqlite3',
+ '<(DEPTH)/lib/pk11wrap/pk11wrap.gyp:pk11wrap',
+ '<(DEPTH)/lib/softoken/softoken.gyp:softokn',
+ '<(DEPTH)/lib/freebl/freebl.gyp:freebl',
+ ],
+ }],
[ 'disable_dbm==0', {
'dependencies': [
'<(DEPTH)/lib/dbm/src/src.gyp:dbm',
@@ -90,10 +100,11 @@
],
'target_defaults': {
'include_dirs': [
- '../../gtests/google_test/gtest/include',
- '../../gtests/common',
'../../lib/ssl'
],
+ 'defines': [
+ 'NSS_USE_STATIC_LIBS'
+ ],
},
'variables': {
'module': 'nss',
diff --git a/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc b/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
index 5d670fa..39055f6 100644
--- a/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_hrr_unittest.cc
@@ -34,7 +34,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
ExpectResumption(RESUME_TICKET);
// Send first ClientHello and send 0-RTT data
- auto capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn);
+ auto capture_early_data =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
client_->SetPacketFilter(capture_early_data);
client_->Handshake();
EXPECT_EQ(k0RttDataLen, PR_Write(client_->ssl_fd(), k0RttData,
@@ -42,8 +43,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
EXPECT_TRUE(capture_early_data->captured());
// Send the HelloRetryRequest
- auto hrr_capture =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ auto hrr_capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeHelloRetryRequest);
server_->SetPacketFilter(hrr_capture);
server_->Handshake();
EXPECT_LT(0U, hrr_capture->buffer().len());
@@ -54,7 +55,8 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
EXPECT_EQ(PR_WOULD_BLOCK_ERROR, PORT_GetError());
// Make a new capture for the early data.
- capture_early_data = new TlsExtensionCapture(ssl_tls13_early_data_xtn);
+ capture_early_data =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_early_data_xtn);
client_->SetPacketFilter(capture_early_data);
// Complete the handshake successfully
@@ -65,6 +67,88 @@ TEST_P(TlsConnectTls13, HelloRetryRequestAbortsZeroRtt) {
EXPECT_FALSE(capture_early_data->captured());
}
+// This filter only works for DTLS 1.3 where there is exactly one handshake
+// packet. If the record is split into two packets, or there are multiple
+// handshake packets, this will break.
+class CorrectMessageSeqAfterHrrFilter : public TlsRecordFilter {
+ protected:
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& record, size_t* offset,
+ DataBuffer* output) {
+ if (filtered_packets() > 0 || header.content_type() != content_handshake) {
+ return KEEP;
+ }
+
+ DataBuffer buffer(record);
+ TlsRecordHeader new_header = {header.version(), header.content_type(),
+ header.sequence_number() + 1};
+
+ // Correct message_seq.
+ buffer.Write(4, 1U, 2);
+
+ *offset = new_header.Write(output, *offset, buffer);
+ return CHANGE;
+ }
+};
+
+TEST_P(TlsConnectTls13, SecondClientHelloRejectEarlyDataXtn) {
+ static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
+ ssl_grp_ec_secp521r1};
+
+ SetupForZeroRtt();
+ ExpectResumption(RESUME_TICKET);
+
+ client_->ConfigNamedGroups(groups);
+ server_->ConfigNamedGroups(groups);
+ client_->Set0RttEnabled(true);
+ server_->Set0RttEnabled(true);
+
+ // A new client that tries to resume with 0-RTT but doesn't send the
+ // correct key share(s). The server will respond with an HRR.
+ auto orig_client =
+ std::make_shared<TlsAgent>(client_->name(), TlsAgent::CLIENT, variant_);
+ client_.swap(orig_client);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ client_->ConfigureSessionCache(RESUME_BOTH);
+ client_->Set0RttEnabled(true);
+ client_->StartConnect();
+
+ // Swap in the new client.
+ client_->SetPeer(server_);
+ server_->SetPeer(client_);
+
+ // Send the ClientHello.
+ client_->Handshake();
+ // Process the CH, send an HRR.
+ server_->Handshake();
+
+ // Swap the client we created manually with the one that successfully
+ // received a PSK, and try to resume with 0-RTT. The client doesn't know
+ // about the HRR so it will send the early_data xtn as well as 0-RTT data.
+ client_.swap(orig_client);
+ orig_client.reset();
+
+ // Correct the DTLS message sequence number after an HRR.
+ if (variant_ == ssl_variant_datagram) {
+ client_->SetPacketFilter(
+ std::make_shared<CorrectMessageSeqAfterHrrFilter>());
+ }
+
+ server_->SetPeer(client_);
+ client_->Handshake();
+
+ // Send 0-RTT data.
+ const char* k0RttData = "ABCDEF";
+ const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ PRInt32 rv = PR_Write(client_->ssl_fd(), k0RttData, k0RttDataLen);
+ EXPECT_EQ(k0RttDataLen, rv);
+
+ ExpectAlert(server_, kTlsAlertUnsupportedExtension);
+ Handshake();
+ client_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_EXTENSION_ALERT);
+}
+
class KeyShareReplayer : public TlsExtensionFilter {
public:
KeyShareReplayer() {}
@@ -94,11 +178,11 @@ class KeyShareReplayer : public TlsExtensionFilter {
// server should reject this.
TEST_P(TlsConnectTls13, RetryWithSameKeyShare) {
EnsureTlsSetup();
- client_->SetPacketFilter(new KeyShareReplayer());
+ client_->SetPacketFilter(std::make_shared<KeyShareReplayer>());
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(SSL_ERROR_BAD_2ND_CLIENT_HELLO, server_->error_code());
EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, client_->error_code());
}
@@ -109,7 +193,7 @@ TEST_F(TlsConnectDatagram13, DropClientSecondFlightWithHelloRetry) {
static const std::vector<SSLNamedGroup> groups = {ssl_grp_ec_secp384r1,
ssl_grp_ec_secp521r1};
server_->ConfigNamedGroups(groups);
- server_->SetPacketFilter(new SelectiveDropFilter(0x2));
+ server_->SetPacketFilter(std::make_shared<SelectiveDropFilter>(0x2));
Connect();
}
@@ -169,16 +253,13 @@ TEST_F(TlsConnectTest, Select12AfterHelloRetryRequest) {
// Here we replace the TLS server with one that does TLS 1.2 only.
// This will happily send the client a TLS 1.2 ServerHello.
- TlsAgent* replacement_server =
- new TlsAgent(server_->name(), TlsAgent::SERVER, mode_);
- delete server_;
- server_ = replacement_server;
- server_->Init();
+ server_.reset(new TlsAgent(server_->name(), TlsAgent::SERVER, variant_));
client_->SetPeer(server_);
server_->SetPeer(client_);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_2);
server_->StartConnect();
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
Handshake();
EXPECT_EQ(SSL_ERROR_ILLEGAL_PARAMETER_ALERT, server_->error_code());
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
@@ -189,8 +270,6 @@ class HelloRetryRequestAgentTest : public TlsAgentTestClient {
void SetUp() override {
TlsAgentTestClient::SetUp();
EnsureInit();
- agent_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_3,
- SSL_LIBRARY_VERSION_TLS_1_3);
agent_->StartConnect();
}
@@ -232,6 +311,7 @@ TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) {
MakeGroupHrr(ssl_grp_ec_secp384r1, &hrr, 0);
ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
MakeGroupHrr(ssl_grp_ec_secp521r1, &hrr, 1);
+ ExpectAlert(kTlsAlertUnexpectedMessage);
ProcessMessage(hrr, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_UNEXPECTED_HELLO_RETRY_REQUEST);
}
@@ -241,6 +321,7 @@ TEST_P(HelloRetryRequestAgentTest, SendSecondHelloRetryRequest) {
TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) {
DataBuffer hrr;
MakeGroupHrr(ssl_grp_ec_curve25519, &hrr);
+ ExpectAlert(kTlsAlertIllegalParameter);
ProcessMessage(hrr, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
}
@@ -248,6 +329,7 @@ TEST_P(HelloRetryRequestAgentTest, HandleBogusHelloRetryRequest) {
TEST_P(HelloRetryRequestAgentTest, HandleNoopHelloRetryRequest) {
DataBuffer hrr;
MakeCannedHrr(nullptr, 0U, &hrr);
+ ExpectAlert(kTlsAlertDecodeError);
ProcessMessage(hrr, TlsAgent::STATE_ERROR,
SSL_ERROR_RX_MALFORMED_HELLO_RETRY_REQUEST);
}
@@ -265,7 +347,7 @@ TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) {
0x13};
DataBuffer hrr;
MakeCannedHrr(canned_cookie_hrr, sizeof(canned_cookie_hrr), &hrr);
- TlsExtensionCapture* capture = new TlsExtensionCapture(ssl_tls13_cookie_xtn);
+ auto capture = std::make_shared<TlsExtensionCapture>(ssl_tls13_cookie_xtn);
agent_->SetPacketFilter(capture);
ProcessMessage(hrr, TlsAgent::STATE_CONNECTING);
const size_t cookie_pos = 2 + 2; // cookie_xtn, extension len
@@ -275,10 +357,11 @@ TEST_P(HelloRetryRequestAgentTest, HandleHelloRetryRequestCookie) {
}
INSTANTIATE_TEST_CASE_P(HelloRetryRequestAgentTests, HelloRetryRequestAgentTest,
- TlsConnectTestBase::kTlsModesAll);
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ TlsConnectTestBase::kTlsV13));
#ifndef NSS_DISABLE_TLS_1_3
INSTANTIATE_TEST_CASE_P(HelloRetryRequestKeyExchangeTests, TlsKeyExchange13,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV13));
#endif
diff --git a/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc b/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
index 65c0ca1..fd05754 100644
--- a/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_loopback_unittest.cc
@@ -39,7 +39,7 @@ TEST_P(TlsConnectGeneric, ConnectEcdsa) {
CheckKeys(ssl_kea_ecdh, ssl_auth_ecdsa);
}
-TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
+TEST_P(TlsConnectGeneric, CipherSuiteMismatch) {
EnsureTlsSetup();
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
client_->EnableSingleCipher(TLS_AES_128_GCM_SHA256);
@@ -48,11 +48,97 @@ TEST_P(TlsConnectGenericPre13, CipherSuiteMismatch) {
client_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA);
}
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
}
+class TlsAlertRecorder : public TlsRecordFilter {
+ public:
+ TlsAlertRecorder() : level_(255), description_(255) {}
+
+ PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (level_ != 255) { // Already captured.
+ return KEEP;
+ }
+ if (header.content_type() != kTlsAlertType) {
+ return KEEP;
+ }
+
+ std::cerr << "Alert: " << input << std::endl;
+
+ TlsParser parser(input);
+ EXPECT_TRUE(parser.Read(&level_));
+ EXPECT_TRUE(parser.Read(&description_));
+ return KEEP;
+ }
+
+ uint8_t level() const { return level_; }
+ uint8_t description() const { return description_; }
+
+ private:
+ uint8_t level_;
+ uint8_t description_;
+};
+
+class HelloTruncator : public TlsHandshakeFilter {
+ PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) override {
+ if (header.handshake_type() != kTlsHandshakeClientHello &&
+ header.handshake_type() != kTlsHandshakeServerHello) {
+ return KEEP;
+ }
+ output->Assign(input.data(), input.len() - 1);
+ return CHANGE;
+ }
+};
+
+// Verify that when NSS reports that an alert is sent, it is actually sent.
+TEST_P(TlsConnectGeneric, CaptureAlertServer) {
+ client_->SetPacketFilter(std::make_shared<HelloTruncator>());
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+ server_->SetPacketFilter(alert_recorder);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertIllegalParameter, alert_recorder->description());
+}
+
+TEST_P(TlsConnectGenericPre13, CaptureAlertClient) {
+ server_->SetPacketFilter(std::make_shared<HelloTruncator>());
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+ client_->SetPacketFilter(alert_recorder);
+
+ ConnectExpectAlert(client_, kTlsAlertDecodeError);
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
+// In TLS 1.3, the server can't read the client alert.
+TEST_P(TlsConnectTls13, CaptureAlertClient) {
+ server_->SetPacketFilter(std::make_shared<HelloTruncator>());
+ auto alert_recorder = std::make_shared<TlsAlertRecorder>();
+ client_->SetPacketFilter(alert_recorder);
+
+ server_->StartConnect();
+ client_->StartConnect();
+
+ client_->Handshake();
+ client_->ExpectSendAlert(kTlsAlertDecodeError);
+ server_->Handshake();
+ client_->Handshake();
+ if (variant_ == ssl_variant_stream) {
+ // DTLS just drops the alert it can't decrypt.
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ }
+ server_->Handshake();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(kTlsAlertDecodeError, alert_recorder->description());
+}
+
TEST_P(TlsConnectGenericPre13, ConnectFalseStart) {
client_->EnableFalseStart();
Connect();
@@ -141,7 +227,8 @@ TEST_P(TlsConnectGeneric, ConnectWithCompressionMaybe) {
client_->EnableCompression();
server_->EnableCompression();
Connect();
- EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 && mode_ != DGRAM,
+ EXPECT_EQ(client_->version() < SSL_LIBRARY_VERSION_TLS_1_3 &&
+ variant_ != ssl_variant_datagram,
client_->is_compressed());
SendReceive();
}
@@ -161,16 +248,15 @@ TEST_P(TlsConnectDatagram, TestDtlsHolddownExpiry) {
class TlsPreCCSHeaderInjector : public TlsRecordFilter {
public:
TlsPreCCSHeaderInjector() {}
- virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
- const DataBuffer& input,
- size_t* offset,
- DataBuffer* output) override {
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ size_t* offset, DataBuffer* output) override {
if (record_header.content_type() != kTlsChangeCipherSpecType) return KEEP;
std::cerr << "Injecting Finished header before CCS\n";
const uint8_t hhdr[] = {kTlsHandshakeFinished, 0x00, 0x00, 0x0c};
DataBuffer hhdr_buf(hhdr, sizeof(hhdr));
- RecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
+ TlsRecordHeader nhdr(record_header.version(), kTlsHandshakeType, 0);
*offset = nhdr.Write(output, *offset, hhdr_buf);
*offset = record_header.Write(output, *offset, input);
return CHANGE;
@@ -178,24 +264,28 @@ class TlsPreCCSHeaderInjector : public TlsRecordFilter {
};
TEST_P(TlsConnectStreamPre13, ClientFinishedHeaderBeforeCCS) {
- client_->SetPacketFilter(new TlsPreCCSHeaderInjector());
- ConnectExpectFail();
+ client_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
+ ConnectExpectAlert(server_, kTlsAlertUnexpectedMessage);
client_->CheckErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT);
server_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
}
TEST_P(TlsConnectStreamPre13, ServerFinishedHeaderBeforeCCS) {
- server_->SetPacketFilter(new TlsPreCCSHeaderInjector());
+ server_->SetPacketFilter(std::make_shared<TlsPreCCSHeaderInjector>());
client_->StartConnect();
server_->StartConnect();
+ ExpectAlert(client_, kTlsAlertUnexpectedMessage);
Handshake();
EXPECT_EQ(TlsAgent::STATE_ERROR, client_->state());
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_CHANGE_CIPHER);
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
+ server_->Handshake(); // Make sure alert is consumed.
}
TEST_P(TlsConnectTls13, UnknownAlert) {
Connect();
+ server_->ExpectSendAlert(0xff, kTlsAlertWarning);
+ client_->ExpectReceiveAlert(0xff, kTlsAlertWarning);
SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
0xff); // Unknown value.
client_->ExpectReadWriteError();
@@ -204,20 +294,14 @@ TEST_P(TlsConnectTls13, UnknownAlert) {
TEST_P(TlsConnectTls13, AlertWrongLevel) {
Connect();
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage, kTlsAlertWarning);
SSLInt_SendAlert(server_->ssl_fd(), kTlsAlertWarning,
kTlsAlertUnexpectedMessage);
client_->ExpectReadWriteError();
client_->WaitForErrorCode(SSL_ERROR_HANDSHAKE_UNEXPECTED_ALERT, 2000);
}
-TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) {
- client_->SetShortHeadersEnabled();
- server_->SetShortHeadersEnabled();
- client_->ExpectShortHeaders();
- server_->ExpectShortHeaders();
- Connect();
-}
-
TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
EnsureTlsSetup();
client_->StartConnect();
@@ -229,12 +313,21 @@ TEST_F(TlsConnectStreamTls13, Tls13FailedWriteSecondFlight) {
client_->CheckErrorCode(SSL_ERROR_SOCKET_WRITE_FAILURE);
}
-INSTANTIATE_TEST_CASE_P(GenericStream, TlsConnectGeneric,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsVAll));
+TEST_F(TlsConnectStreamTls13, NegotiateShortHeaders) {
+ client_->SetShortHeadersEnabled();
+ server_->SetShortHeadersEnabled();
+ client_->ExpectShortHeaders();
+ server_->ExpectShortHeaders();
+ Connect();
+}
+
+INSTANTIATE_TEST_CASE_P(
+ GenericStream, TlsConnectGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsVAll));
INSTANTIATE_TEST_CASE_P(
GenericDatagram, TlsConnectGeneric,
- ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
TlsConnectTestBase::kTlsV11Plus));
INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream,
@@ -242,33 +335,35 @@ INSTANTIATE_TEST_CASE_P(StreamOnly, TlsConnectStream,
INSTANTIATE_TEST_CASE_P(DatagramOnly, TlsConnectDatagram,
TlsConnectTestBase::kTlsV11Plus);
-INSTANTIATE_TEST_CASE_P(Pre12Stream, TlsConnectPre12,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsV10V11));
+INSTANTIATE_TEST_CASE_P(
+ Pre12Stream, TlsConnectPre12,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10V11));
INSTANTIATE_TEST_CASE_P(
Pre12Datagram, TlsConnectPre12,
- ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
TlsConnectTestBase::kTlsV11));
INSTANTIATE_TEST_CASE_P(Version12Only, TlsConnectTls12,
- TlsConnectTestBase::kTlsModesAll);
+ TlsConnectTestBase::kTlsVariantsAll);
#ifndef NSS_DISABLE_TLS_1_3
INSTANTIATE_TEST_CASE_P(Version13Only, TlsConnectTls13,
- TlsConnectTestBase::kTlsModesAll);
+ TlsConnectTestBase::kTlsVariantsAll);
#endif
-INSTANTIATE_TEST_CASE_P(Pre13Stream, TlsConnectGenericPre13,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(
+ Pre13Stream, TlsConnectGenericPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10ToV12));
INSTANTIATE_TEST_CASE_P(
Pre13Datagram, TlsConnectGenericPre13,
- ::testing::Combine(TlsConnectTestBase::kTlsModesDatagram,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsDatagram,
TlsConnectTestBase::kTlsV11V12));
INSTANTIATE_TEST_CASE_P(Pre13StreamOnly, TlsConnectStreamPre13,
TlsConnectTestBase::kTlsV10ToV12);
INSTANTIATE_TEST_CASE_P(Version12Plus, TlsConnectTls12Plus,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV12Plus));
} // namespace nspr_test
diff --git a/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc b/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
index cfe42cb..7b43870 100644
--- a/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_resumption_unittest.cc
@@ -21,6 +21,7 @@ extern "C" {
#include "tls_connect.h"
#include "tls_filter.h"
#include "tls_parser.h"
+#include "tls_protect.h"
namespace nss_test {
@@ -200,6 +201,63 @@ TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicketForget) {
SendReceive();
}
+TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtClient) {
+ SSLInt_SetTicketLifetime(1); // one second
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ WAIT_(false, 1000);
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+
+ // TLS 1.3 uses the pre-shared key extension instead.
+ SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
+ ? ssl_tls13_pre_shared_key_xtn
+ : ssl_session_ticket_xtn;
+ auto capture = std::make_shared<TlsExtensionCapture>(xtn);
+ client_->SetPacketFilter(capture);
+ Connect();
+
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ EXPECT_FALSE(capture->captured());
+ } else {
+ EXPECT_TRUE(capture->captured());
+ EXPECT_EQ(0U, capture->extension().len());
+ }
+}
+
+TEST_P(TlsConnectGeneric, ConnectWithExpiredTicketAtServer) {
+ SSLInt_SetTicketLifetime(1); // one second
+ // This causes a ticket resumption.
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ Connect();
+ SendReceive();
+
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ExpectResumption(RESUME_NONE);
+
+ SSLExtensionType xtn = (version_ >= SSL_LIBRARY_VERSION_TLS_1_3)
+ ? ssl_tls13_pre_shared_key_xtn
+ : ssl_session_ticket_xtn;
+ auto capture = std::make_shared<TlsExtensionCapture>(xtn);
+ client_->SetPacketFilter(capture);
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake();
+ EXPECT_TRUE(capture->captured());
+ EXPECT_LT(0U, capture->extension().len());
+
+ WAIT_(false, 1000); // Let the ticket expire on the server.
+
+ Handshake();
+ CheckConnected();
+}
+
// This callback switches out the "server" cert used on the server with
// the "client" certificate, which should be the same type.
static int32_t SwitchCertificates(TlsAgent* agent, const SECItem* srvNameArr,
@@ -245,8 +303,8 @@ TEST_P(TlsConnectGeneric, ServerSNICertTypeSwitch) {
// Prior to TLS 1.3, we were not fully ephemeral; though 1.3 fixes that
TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
- TlsInspectorRecordHandshakeMessage* i1 =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i1);
Connect();
CheckKeys();
@@ -255,8 +313,8 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceReuseKey) {
// Restart
Reset();
- TlsInspectorRecordHandshakeMessage* i2 =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
@@ -277,8 +335,8 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
SECStatus rv =
SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
- TlsInspectorRecordHandshakeMessage* i1 =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ auto i1 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i1);
Connect();
CheckKeys();
@@ -290,8 +348,8 @@ TEST_P(TlsConnectGenericPre13, ConnectEcdheTwiceNewKey) {
server_->EnsureTlsSetup();
rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
- TlsInspectorRecordHandshakeMessage* i2 =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange);
+ auto i2 = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerKeyExchange);
server_->SetPacketFilter(i2);
ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
Connect();
@@ -356,7 +414,7 @@ TEST_P(TlsConnectGeneric, TestResumeClientDifferentCipher) {
} else {
ticket_extension = ssl_session_ticket_xtn;
}
- auto ticket_capture = new TlsExtensionCapture(ticket_extension);
+ auto ticket_capture = std::make_shared<TlsExtensionCapture>(ticket_extension);
client_->SetPacketFilter(ticket_capture);
Connect();
CheckKeys(ssl_kea_ecdh, ssl_auth_rsa_sign);
@@ -420,9 +478,15 @@ TEST_P(TlsConnectStream, TestResumptionOverrideCipher) {
Reset();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
- server_->SetPacketFilter(
- new SelectedCipherSuiteReplacer(ChooseAnotherCipher(version_)));
+ server_->SetPacketFilter(std::make_shared<SelectedCipherSuiteReplacer>(
+ ChooseAnotherCipher(version_)));
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ client_->ExpectSendAlert(kTlsAlertIllegalParameter);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ } else {
+ ExpectAlert(client_, kTlsAlertHandshakeFailure);
+ }
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
if (version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
@@ -459,7 +523,7 @@ class SelectedVersionReplacer : public TlsHandshakeFilter {
// lower version number on resumption.
TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
uint16_t override_version = 0;
- if (mode_ == STREAM) {
+ if (variant_ == ssl_variant_stream) {
switch (version_) {
case SSL_LIBRARY_VERSION_TLS_1_0:
return; // Skip the test.
@@ -492,9 +556,10 @@ TEST_P(TlsConnectGenericPre13, TestResumptionOverrideVersion) {
// Enable the lower version on the client.
client_->SetVersionRange(version_ - 1, version_);
server_->EnableSingleCipher(TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA);
- server_->SetPacketFilter(new SelectedVersionReplacer(override_version));
+ server_->SetPacketFilter(
+ std::make_shared<SelectedVersionReplacer>(override_version));
- ConnectExpectFail();
+ ConnectExpectAlert(client_, kTlsAlertHandshakeFailure);
client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
server_->CheckErrorCode(SSL_ERROR_HANDSHAKE_FAILURE_ALERT);
}
@@ -515,8 +580,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
ExpectResumption(RESUME_TICKET);
- TlsExtensionCapture* c1 =
- new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ auto c1 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
client_->SetPacketFilter(c1);
Connect();
SendReceive();
@@ -533,8 +597,7 @@ TEST_F(TlsConnectTest, TestTls13ResumptionTwice) {
ClearStats();
ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
- TlsExtensionCapture* c2 =
- new TlsExtensionCapture(ssl_tls13_pre_shared_key_xtn);
+ auto c2 = std::make_shared<TlsExtensionCapture>(ssl_tls13_pre_shared_key_xtn);
client_->SetPacketFilter(c2);
ExpectResumption(RESUME_TICKET);
Connect();
@@ -579,4 +642,66 @@ TEST_F(TlsConnectTest, TestTls13ResumptionDuplicateNST) {
SendReceive();
}
+TEST_F(TlsConnectTest, TestTls13ResumptionDowngrade) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Try resuming the connection. This will fail resuming the 1.3 session
+ // from before, but will successfully establish a 1.2 connection.
+ Reset();
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_2);
+ Connect();
+
+ // Renegotiate to ensure we don't carryover any state
+ // from the 1.3 resumption attempt.
+ client_->SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->PrepareForRenegotiate();
+ server_->StartRenegotiate();
+ Handshake();
+
+ SendReceive();
+ CheckKeys();
+}
+
+TEST_F(TlsConnectTest, TestTls13ResumptionForcedDowngrade) {
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ Connect();
+
+ SendReceive(); // Need to read so that we absorb the session tickets.
+ CheckKeys();
+
+ // Try resuming the connection.
+ Reset();
+ ConfigureVersion(SSL_LIBRARY_VERSION_TLS_1_3);
+ ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET);
+ // Enable the lower version on the client.
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+
+ // Add filters that set downgrade SH.version to 1.2 and the cipher suite
+ // to one that works with 1.2, so that we don't run into early sanity checks.
+ // We will eventually fail the (sid.version == SH.version) check.
+ std::vector<std::shared_ptr<PacketFilter>> filters;
+ filters.push_back(std::make_shared<SelectedCipherSuiteReplacer>(
+ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256));
+ filters.push_back(
+ std::make_shared<SelectedVersionReplacer>(SSL_LIBRARY_VERSION_TLS_1_2));
+ server_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(filters));
+
+ client_->ExpectSendAlert(kTlsAlertDecodeError);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac); // Server can't read
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_skip_unittest.cc b/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
index 523a374..a130ef7 100644
--- a/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_skip_unittest.cc
@@ -28,9 +28,9 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
protected:
// Takes a record; if it is a handshake record, it removes the first handshake
// message that is of handshake_type_ type.
- virtual PacketFilter::Action FilterRecord(const RecordHeader& record_header,
- const DataBuffer& input,
- DataBuffer* output) {
+ virtual PacketFilter::Action FilterRecord(
+ const TlsRecordHeader& record_header, const DataBuffer& input,
+ DataBuffer* output) {
if (record_header.content_type() != kTlsHandshakeType) {
return KEEP;
}
@@ -78,81 +78,162 @@ class TlsHandshakeSkipFilter : public TlsRecordFilter {
bool skipped_;
};
-class TlsSkipTest
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsSkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
protected:
TlsSkipTest()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
- void ServerSkipTest(PacketFilter* filter,
+ void ServerSkipTest(std::shared_ptr<PacketFilter> filter,
uint8_t alert = kTlsAlertUnexpectedMessage) {
- auto alert_recorder = new TlsAlertRecorder();
- client_->SetPacketFilter(alert_recorder);
- if (filter) {
- server_->SetPacketFilter(filter);
+ server_->SetPacketFilter(filter);
+ ConnectExpectAlert(client_, alert);
+ }
+};
+
+class Tls13SkipTest : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
+ protected:
+ Tls13SkipTest()
+ : TlsConnectTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void ServerSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ EnsureTlsSetup();
+ server_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ client_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ if (variant_ == ssl_variant_stream) {
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ ConnectExpectFail();
+ } else {
+ ConnectExpectFailOneSide(TlsAgent::CLIENT);
+ }
+ client_->CheckErrorCode(error);
+ if (variant_ == ssl_variant_stream) {
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ } else {
+ ASSERT_EQ(TlsAgent::STATE_CONNECTING, server_->state());
}
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(alert, alert_recorder->description());
+ }
+
+ void ClientSkipTest(std::shared_ptr<TlsRecordFilter> filter, int32_t error) {
+ EnsureTlsSetup();
+ client_->SetTlsRecordFilter(filter);
+ filter->EnableDecryption();
+ server_->ExpectSendAlert(kTlsAlertUnexpectedMessage);
+ ConnectExpectFailOneSide(TlsAgent::SERVER);
+
+ server_->CheckErrorCode(error);
+ ASSERT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
+
+ client_->Handshake(); // Make sure to consume the alert the server sends.
}
};
TEST_P(TlsSkipTest, SkipCertificateRsa) {
EnableOnlyStaticRsaCiphers();
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertificateDhe) {
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdhe) {
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipCertificateEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH);
}
TEST_P(TlsSkipTest, SkipServerKeyExchange) {
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExch) {
- auto chain = new ChainedPacketFilter();
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ auto chain = std::make_shared<ChainedPacketFilter>();
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) {
Reset(TlsAgent::kServerEcdsa256);
- auto chain = new ChainedPacketFilter();
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate));
- chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange));
+ auto chain = std::make_shared<ChainedPacketFilter>();
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate));
+ chain->Add(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeServerKeyExchange));
ServerSkipTest(chain);
client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE);
}
-INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
- TlsConnectTestBase::kTlsV10));
+TEST_P(Tls13SkipTest, SkipEncryptedExtensions) {
+ ServerSkipTest(std::make_shared<TlsHandshakeSkipFilter>(
+ kTlsHandshakeEncryptedExtensions),
+ SSL_ERROR_RX_UNEXPECTED_CERTIFICATE);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificate) {
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipServerCertificateVerify) {
+ ServerSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificate) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificate),
+ SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY);
+}
+
+TEST_P(Tls13SkipTest, SkipClientCertificateVerify) {
+ client_->SetupClientAuth();
+ server_->RequestClientAuth(true);
+ client_->ExpectReceiveAlert(kTlsAlertUnexpectedMessage);
+ ClientSkipTest(
+ std::make_shared<TlsHandshakeSkipFilter>(kTlsHandshakeCertificateVerify),
+ SSL_ERROR_RX_UNEXPECTED_FINISHED);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ SkipTls10, TlsSkipTest,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsStream,
+ TlsConnectTestBase::kTlsV10));
INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest,
- ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
TlsConnectTestBase::kTlsV11V12));
-
+INSTANTIATE_TEST_CASE_P(Skip13Variants, Tls13SkipTest,
+ TlsConnectTestBase::kTlsVariantsAll);
} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc b/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
index baf24ed..8db1f30 100644
--- a/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_staticrsa_unittest.cc
@@ -48,28 +48,20 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSA) {
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
- TlsInspectorReplaceHandshakeMessage* i1 =
- new TlsInspectorReplaceHandshakeMessage(
- kTlsHandshakeClientKeyExchange,
- DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
+ auto i1 = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
client_->SetPacketFilter(i1);
- auto alert_recorder = new TlsAlertRecorder();
- server_->SetPacketFilter(alert_recorder);
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
// Test that a PMS with a bogus version number is handled correctly.
// This test is stream so we can catch the bad_record_mac alert.
TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
- auto alert_recorder = new TlsAlertRecorder();
- server_->SetPacketFilter(alert_recorder);
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+ client_->SetPacketFilter(
+ std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
// Test that a PMS with a bogus version number is ignored when
@@ -77,7 +69,8 @@ TEST_P(TlsConnectStreamPre13, ConnectStaticRSABogusPMSVersionDetect) {
// ConnectStaticRSABogusPMSVersionDetect.
TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
- client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
+ client_->SetPacketFilter(
+ std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
server_->DisableRollbackDetection();
Connect();
}
@@ -86,16 +79,11 @@ TEST_P(TlsConnectGenericPre13, ConnectStaticRSABogusPMSVersionIgnore) {
TEST_P(TlsConnectStreamPre13, ConnectExtendedMasterSecretStaticRSABogusCKE) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- TlsInspectorReplaceHandshakeMessage* inspect =
- new TlsInspectorReplaceHandshakeMessage(
- kTlsHandshakeClientKeyExchange,
- DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
+ auto inspect = std::make_shared<TlsInspectorReplaceHandshakeMessage>(
+ kTlsHandshakeClientKeyExchange,
+ DataBuffer(kBogusClientKeyExchange, sizeof(kBogusClientKeyExchange)));
client_->SetPacketFilter(inspect);
- auto alert_recorder = new TlsAlertRecorder();
- server_->SetPacketFilter(alert_recorder);
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
// This test is stream so we can catch the bad_record_mac alert.
@@ -103,19 +91,17 @@ TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
- auto alert_recorder = new TlsAlertRecorder();
- server_->SetPacketFilter(alert_recorder);
- ConnectExpectFail();
- EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
- EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description());
+ client_->SetPacketFilter(
+ std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
+ ConnectExpectAlert(server_, kTlsAlertBadRecordMac);
}
TEST_P(TlsConnectStreamPre13,
ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) {
EnableOnlyStaticRsaCiphers();
EnableExtendedMasterSecret();
- client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger(server_));
+ client_->SetPacketFilter(
+ std::make_shared<TlsInspectorClientHelloVersionChanger>(server_));
server_->DisableRollbackDetection();
Connect();
}
diff --git a/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc b/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
index 8b586be..110e3e0 100644
--- a/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_v2_client_hello_unittest.cc
@@ -23,7 +23,7 @@ namespace nss_test {
// Replaces the client hello with an SSLv2 version once.
class SSLv2ClientHelloFilter : public PacketFilter {
public:
- SSLv2ClientHelloFilter(TlsAgent* client, uint16_t version)
+ SSLv2ClientHelloFilter(std::shared_ptr<TlsAgent>& client, uint16_t version)
: replaced_(false),
client_(client),
version_(version),
@@ -121,7 +121,7 @@ class SSLv2ClientHelloFilter : public PacketFilter {
// Update the client random so that the handshake succeeds.
SECStatus rv = SSLInt_UpdateSSLv2ClientRandom(
- client_->ssl_fd(), challenge.data(), challenge.size(),
+ client_.lock()->ssl_fd(), challenge.data(), challenge.size(),
output->data() + hdr_len, output->len() - hdr_len);
EXPECT_EQ(SECSuccess, rv);
@@ -130,7 +130,7 @@ class SSLv2ClientHelloFilter : public PacketFilter {
private:
bool replaced_;
- TlsAgent* client_;
+ std::weak_ptr<TlsAgent> client_;
uint16_t version_;
uint8_t pad_len_;
uint8_t reported_pad_len_;
@@ -141,14 +141,15 @@ class SSLv2ClientHelloFilter : public PacketFilter {
class SSLv2ClientHelloTestF : public TlsConnectTestBase {
public:
- SSLv2ClientHelloTestF() : TlsConnectTestBase(STREAM, 0), filter_(nullptr) {}
+ SSLv2ClientHelloTestF()
+ : TlsConnectTestBase(ssl_variant_stream, 0), filter_(nullptr) {}
- SSLv2ClientHelloTestF(Mode mode, uint16_t version)
- : TlsConnectTestBase(mode, version), filter_(nullptr) {}
+ SSLv2ClientHelloTestF(SSLProtocolVariant variant, uint16_t version)
+ : TlsConnectTestBase(variant, version), filter_(nullptr) {}
void SetUp() {
TlsConnectTestBase::SetUp();
- filter_ = new SSLv2ClientHelloFilter(client_, version_);
+ filter_ = std::make_shared<SSLv2ClientHelloFilter>(client_, version_);
client_->SetPacketFilter(filter_);
}
@@ -185,7 +186,7 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase {
void SetSendEscape(bool send_escape) { filter_->SetSendEscape(send_escape); }
private:
- SSLv2ClientHelloFilter* filter_;
+ std::shared_ptr<SSLv2ClientHelloFilter> filter_;
};
// Parameterized version of SSLv2ClientHelloTestF we can
@@ -193,7 +194,8 @@ class SSLv2ClientHelloTestF : public TlsConnectTestBase {
class SSLv2ClientHelloTest : public SSLv2ClientHelloTestF,
public ::testing::WithParamInterface<uint16_t> {
public:
- SSLv2ClientHelloTest() : SSLv2ClientHelloTestF(STREAM, GetParam()) {}
+ SSLv2ClientHelloTest()
+ : SSLv2ClientHelloTestF(ssl_variant_stream, GetParam()) {}
};
// Test negotiating TLS 1.0 - 1.2.
@@ -202,6 +204,28 @@ TEST_P(SSLv2ClientHelloTest, Connect) {
Connect();
}
+// Sending a v2 ClientHello after a no-op v3 record must fail.
+TEST_P(SSLv2ClientHelloTest, ConnectAfterEmptyV3Record) {
+ DataBuffer buffer;
+
+ size_t idx = 0;
+ idx = buffer.Write(idx, 0x16, 1); // handshake
+ idx = buffer.Write(idx, 0x0301, 2); // record_version
+ (void)buffer.Write(idx, 0U, 2); // length=0
+
+ SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
+ EnsureTlsSetup();
+ client_->SendDirect(buffer);
+
+ // Need padding so the connection doesn't just time out. With a v2
+ // ClientHello parsed as a v3 record we will use the record version
+ // as the record length.
+ SetPadding(255);
+
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
+ EXPECT_EQ(SSL_ERROR_BAD_CLIENT, server_->error_code());
+}
+
// Test negotiating TLS 1.3.
TEST_F(SSLv2ClientHelloTestF, Connect13) {
EnsureTlsSetup();
@@ -211,7 +235,7 @@ TEST_F(SSLv2ClientHelloTestF, Connect13) {
std::vector<uint16_t> cipher_suites = {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256};
SetAvailableCipherSuites(cipher_suites);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
}
@@ -238,7 +262,7 @@ TEST_P(SSLv2ClientHelloTest, SendSecurityEscape) {
// Set a big padding so that the server fails instead of timing out.
SetPadding(255);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
}
// Invalid SSLv2 client hello padding must fail the handshake.
@@ -248,7 +272,7 @@ TEST_P(SSLv2ClientHelloTest, AddErroneousPadding) {
// Append 5 bytes of padding but say it's only 4.
SetPadding(5, 4);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
}
@@ -259,7 +283,7 @@ TEST_P(SSLv2ClientHelloTest, AddErroneousPadding2) {
// Append 5 bytes of padding but say it's 6.
SetPadding(5, 6);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
}
@@ -270,7 +294,7 @@ TEST_P(SSLv2ClientHelloTest, SmallClientRandom) {
// Send a ClientRandom that's too small.
SetClientRandomLength(15);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
}
@@ -288,7 +312,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) {
// Send a ClientRandom that's too big.
SetClientRandomLength(33);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertIllegalParameter);
EXPECT_EQ(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO, server_->error_code());
}
@@ -297,7 +321,7 @@ TEST_P(SSLv2ClientHelloTest, BigClientRandom) {
TEST_P(SSLv2ClientHelloTest, RequireSafeRenegotiation) {
RequireSafeRenegotiation();
SetAvailableCipherSuite(TLS_DHE_RSA_WITH_AES_128_CBC_SHA);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertHandshakeFailure);
EXPECT_EQ(SSL_ERROR_UNSAFE_NEGOTIATION, server_->error_code());
}
@@ -339,7 +363,7 @@ TEST_F(SSLv2ClientHelloTestF, InappropriateFallbackSCSV) {
TLS_FALLBACK_SCSV};
SetAvailableCipherSuites(cipher_suites);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertInappropriateFallback);
EXPECT_EQ(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT, server_->error_code());
}
diff --git a/nss/gtests/ssl_gtest/ssl_version_unittest.cc b/nss/gtests/ssl_gtest/ssl_version_unittest.cc
index b353849..379a67e 100644
--- a/nss/gtests/ssl_gtest/ssl_version_unittest.cc
+++ b/nss/gtests/ssl_gtest/ssl_version_unittest.cc
@@ -57,7 +57,8 @@ TEST_P(TlsConnectGeneric, ServerNegotiateTls12) {
// SSL_SetDowngradeCheckVersion() API.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
client_->SetPacketFilter(
- new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_1));
+ std::make_shared<TlsInspectorClientHelloVersionSetter>(
+ SSL_LIBRARY_VERSION_TLS_1_1));
ConnectExpectFail();
ASSERT_EQ(SSL_ERROR_RX_MALFORMED_SERVER_HELLO, client_->error_code());
}
@@ -65,7 +66,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls11) {
/* Attempt to negotiate the bogus DTLS 1.1 version. */
TEST_F(DtlsConnectTest, TestDtlsVersion11) {
client_->SetPacketFilter(
- new TlsInspectorClientHelloVersionSetter(((~0x0101) & 0xffff)));
+ std::make_shared<TlsInspectorClientHelloVersionSetter>(
+ ((~0x0101) & 0xffff)));
ConnectExpectFail();
// It's kind of surprising that SSL_ERROR_NO_CYPHER_OVERLAP is
// what is returned here, but this is deliberate in ssl3_HandleAlert().
@@ -77,7 +79,8 @@ TEST_F(DtlsConnectTest, TestDtlsVersion11) {
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
EnsureTlsSetup();
client_->SetPacketFilter(
- new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_2));
+ std::make_shared<TlsInspectorClientHelloVersionSetter>(
+ SSL_LIBRARY_VERSION_TLS_1_2));
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
SSL_LIBRARY_VERSION_TLS_1_3);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
@@ -90,7 +93,8 @@ TEST_F(TlsConnectTest, TestDowngradeDetectionToTls12) {
// instead get a handshake failure alert from the server.
TEST_F(TlsConnectTest, TestDowngradeDetectionToTls10) {
client_->SetPacketFilter(
- new TlsInspectorClientHelloVersionSetter(SSL_LIBRARY_VERSION_TLS_1_0));
+ std::make_shared<TlsInspectorClientHelloVersionSetter>(
+ SSL_LIBRARY_VERSION_TLS_1_0));
client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
SSL_LIBRARY_VERSION_TLS_1_1);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_0,
@@ -123,6 +127,18 @@ TEST_F(TlsConnectTest, TestFallbackFromTls13) {
}
#endif
+TEST_P(TlsConnectGeneric, TestFallbackSCSVVersionMatch) {
+ client_->SetFallbackSCSVEnabled(true);
+ Connect();
+}
+
+TEST_P(TlsConnectGenericPre13, TestFallbackSCSVVersionMismatch) {
+ client_->SetFallbackSCSVEnabled(true);
+ server_->SetVersionRange(version_, version_ + 1);
+ ConnectExpectAlert(server_, kTlsAlertInappropriateFallback);
+ client_->CheckErrorCode(SSL_ERROR_INAPPROPRIATE_FALLBACK_ALERT);
+}
+
// The TLS v1.3 spec section C.4 states that 'Implementations MUST NOT send or
// accept any records with a version less than { 3, 0 }'. Thus we will not
// allow version ranges including both SSL v3 and TLS v1.3.
@@ -161,6 +177,13 @@ TEST_P(TlsConnectStream, ConnectTls10AndServerRenegotiateHigher) {
// doesn't fail.
server_->ResetPreliminaryInfo();
server_->StartRenegotiate();
+
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ }
+
Handshake();
if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
// In TLS 1.3, the server detects this problem.
@@ -194,6 +217,11 @@ TEST_P(TlsConnectStream, ConnectTls10AndClientRenegotiateHigher) {
// doesn't fail.
server_->ResetPreliminaryInfo();
client_->StartRenegotiate();
+ if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ ExpectAlert(server_, kTlsAlertUnexpectedMessage);
+ } else {
+ ExpectAlert(client_, kTlsAlertIllegalParameter);
+ }
Handshake();
if (test_version >= SSL_LIBRARY_VERSION_TLS_1_3) {
// In TLS 1.3, the server detects this problem.
@@ -225,13 +253,14 @@ TEST_F(TlsConnectTest, Tls13RejectsRehandshakeServer) {
TEST_P(TlsConnectGeneric, AlertBeforeServerHello) {
EnsureTlsSetup();
+ client_->ExpectReceiveAlert(kTlsAlertUnrecognizedName, kTlsAlertWarning);
client_->StartConnect();
server_->StartConnect();
client_->Handshake(); // Send ClientHello.
static const uint8_t kWarningAlert[] = {kTlsAlertWarning,
kTlsAlertUnrecognizedName};
DataBuffer alert;
- TlsAgentTestBase::MakeRecord(mode_, kTlsAlertType,
+ TlsAgentTestBase::MakeRecord(variant_, kTlsAlertType,
SSL_LIBRARY_VERSION_TLS_1_0, kWarningAlert,
PR_ARRAY_SIZE(kWarningAlert), &alert);
client_->adapter()->PacketReceived(alert);
@@ -246,11 +275,12 @@ class Tls13NoSupportedVersions : public TlsConnectStreamTls12 {
SSL_LIBRARY_VERSION_TLS_1_2);
server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2, max_server_version);
client_->SetPacketFilter(
- new TlsInspectorClientHelloVersionSetter(overwritten_client_version));
- auto capture =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello);
+ std::make_shared<TlsInspectorClientHelloVersionSetter>(
+ overwritten_client_version));
+ auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerHello);
server_->SetPacketFilter(capture);
- ConnectExpectFail();
+ ConnectExpectAlert(server_, kTlsAlertDecryptError);
client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
const DataBuffer& server_hello = capture->buffer();
@@ -281,11 +311,14 @@ TEST_F(Tls13NoSupportedVersions,
// Offer 1.3 but with ClientHello.legacy_version == TLS 1.4. This
// causes a bad MAC error when we read EncryptedExtensions.
TEST_F(TlsConnectStreamTls13, Tls14ClientHelloWithSupportedVersions) {
- client_->SetPacketFilter(new TlsInspectorClientHelloVersionSetter(
- SSL_LIBRARY_VERSION_TLS_1_3 + 1));
- auto capture =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerHello);
+ client_->SetPacketFilter(
+ std::make_shared<TlsInspectorClientHelloVersionSetter>(
+ SSL_LIBRARY_VERSION_TLS_1_3 + 1));
+ auto capture = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeServerHello);
server_->SetPacketFilter(capture);
+ client_->ExpectSendAlert(kTlsAlertBadRecordMac);
+ server_->ExpectSendAlert(kTlsAlertBadRecordMac);
ConnectExpectFail();
client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
diff --git a/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc b/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
new file mode 100644
index 0000000..eda9683
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_versionpolicy_unittest.cc
@@ -0,0 +1,394 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "nss.h"
+#include "secerr.h"
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include "gtest_utils.h"
+#include "scoped_ptrs.h"
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+#include <iostream>
+
+namespace nss_test {
+
+std::string GetSSLVersionString(uint16_t v) {
+ switch (v) {
+ case SSL_LIBRARY_VERSION_3_0:
+ return "ssl3";
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ return "tls1.0";
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ return "tls1.1";
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ return "tls1.2";
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ return "tls1.3";
+ case SSL_LIBRARY_VERSION_NONE:
+ return "NONE";
+ }
+ if (v < SSL_LIBRARY_VERSION_3_0) {
+ return "undefined-too-low";
+ }
+ return "undefined-too-high";
+}
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const SSLVersionRange& vr) {
+ return stream << GetSSLVersionString(vr.min) << ","
+ << GetSSLVersionString(vr.max);
+}
+
+class VersionRangeWithLabel {
+ public:
+ VersionRangeWithLabel(const std::string& label, const SSLVersionRange& vr)
+ : label_(label), vr_(vr) {}
+ VersionRangeWithLabel(const std::string& label, uint16_t min, uint16_t max)
+ : label_(label) {
+ vr_.min = min;
+ vr_.max = max;
+ }
+ VersionRangeWithLabel(const std::string& label) : label_(label) {
+ vr_.min = vr_.max = SSL_LIBRARY_VERSION_NONE;
+ }
+
+ void WriteStream(std::ostream& stream) const {
+ stream << " " << label_ << ": " << vr_;
+ }
+
+ uint16_t min() const { return vr_.min; }
+ uint16_t max() const { return vr_.max; }
+ SSLVersionRange range() const { return vr_; }
+
+ private:
+ std::string label_;
+ SSLVersionRange vr_;
+};
+
+inline std::ostream& operator<<(std::ostream& stream,
+ const VersionRangeWithLabel& vrwl) {
+ vrwl.WriteStream(stream);
+ return stream;
+}
+
+typedef std::tuple<SSLProtocolVariant, // variant
+ uint16_t, // policy min
+ uint16_t, // policy max
+ uint16_t, // input min
+ uint16_t> // input max
+ PolicyVersionRangeInput;
+
+class TestPolicyVersionRange
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<PolicyVersionRangeInput> {
+ public:
+ TestPolicyVersionRange()
+ : TlsConnectTestBase(std::get<0>(GetParam()), 0),
+ variant_(std::get<0>(GetParam())),
+ policy_("policy", std::get<1>(GetParam()), std::get<2>(GetParam())),
+ input_("input", std::get<3>(GetParam()), std::get<4>(GetParam())),
+ library_("supported-by-library",
+ ((variant_ == ssl_variant_stream)
+ ? SSL_LIBRARY_VERSION_MIN_SUPPORTED_STREAM
+ : SSL_LIBRARY_VERSION_MIN_SUPPORTED_DATAGRAM),
+ SSL_LIBRARY_VERSION_MAX_SUPPORTED) {
+ TlsConnectTestBase::SkipVersionChecks();
+ }
+
+ void SetPolicy(const SSLVersionRange& policy) {
+ NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, NSS_USE_POLICY_IN_SSL, 0);
+
+ SECStatus rv;
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, policy.min);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, policy.max);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, policy.min);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, policy.max);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ void CreateDummySocket(std::shared_ptr<DummyPrSocket>* dummy_socket,
+ ScopedPRFileDesc* ssl_fd) {
+ (*dummy_socket).reset(new DummyPrSocket("dummy", variant_));
+ *ssl_fd = (*dummy_socket)->CreateFD();
+ if (variant_ == ssl_variant_stream) {
+ SSL_ImportFD(nullptr, ssl_fd->get());
+ } else {
+ DTLS_ImportFD(nullptr, ssl_fd->get());
+ }
+ }
+
+ bool GetOverlap(const SSLVersionRange& r1, const SSLVersionRange& r2,
+ SSLVersionRange* overlap) {
+ if (r1.min == SSL_LIBRARY_VERSION_NONE ||
+ r1.max == SSL_LIBRARY_VERSION_NONE ||
+ r2.min == SSL_LIBRARY_VERSION_NONE ||
+ r2.max == SSL_LIBRARY_VERSION_NONE) {
+ return false;
+ }
+
+ SSLVersionRange temp;
+ temp.min = PR_MAX(r1.min, r2.min);
+ temp.max = PR_MIN(r1.max, r2.max);
+
+ if (temp.min > temp.max) {
+ return false;
+ }
+
+ *overlap = temp;
+ return true;
+ }
+
+ bool IsValidInputForVersionRangeSet(SSLVersionRange* expectedEffectiveRange) {
+ if (input_.min() <= SSL_LIBRARY_VERSION_3_0 &&
+ input_.max() >= SSL_LIBRARY_VERSION_TLS_1_3) {
+ // This is always invalid input, independent of policy
+ return false;
+ }
+
+ if (input_.min() < library_.min() || input_.max() > library_.max() ||
+ input_.min() > input_.max()) {
+ // Asking for unsupported ranges is invalid input for VersionRangeSet
+ // APIs, regardless of overlap.
+ return false;
+ }
+
+ SSLVersionRange overlap_with_library;
+ if (!GetOverlap(input_.range(), library_.range(), &overlap_with_library)) {
+ return false;
+ }
+
+ SSLVersionRange overlap_with_library_and_policy;
+ if (!GetOverlap(overlap_with_library, policy_.range(),
+ &overlap_with_library_and_policy)) {
+ return false;
+ }
+
+ RemoveConflictingVersions(variant_, &overlap_with_library_and_policy);
+ *expectedEffectiveRange = overlap_with_library_and_policy;
+ return true;
+ }
+
+ void RemoveConflictingVersions(SSLProtocolVariant variant,
+ SSLVersionRange* r) {
+ ASSERT_TRUE(r != nullptr);
+ if (r->max >= SSL_LIBRARY_VERSION_TLS_1_3 &&
+ r->min < SSL_LIBRARY_VERSION_TLS_1_0) {
+ r->min = SSL_LIBRARY_VERSION_TLS_1_0;
+ }
+ }
+
+ void SetUp() {
+ SetPolicy(policy_.range());
+ TlsConnectTestBase::SetUp();
+ }
+
+ void TearDown() {
+ TlsConnectTestBase::TearDown();
+ saved_version_policy_.RestoreOriginalPolicy();
+ }
+
+ protected:
+ class VersionPolicy {
+ public:
+ VersionPolicy() { SaveOriginalPolicy(); }
+
+ void RestoreOriginalPolicy() {
+ SECStatus rv;
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MIN_POLICY, saved_min_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_TLS_VERSION_MAX_POLICY, saved_max_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MIN_POLICY, saved_min_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionSet(NSS_DTLS_VERSION_MAX_POLICY, saved_max_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ // If it wasn't set initially, clear the bit that we set.
+ if (!(saved_algorithm_policy_ & NSS_USE_POLICY_IN_SSL)) {
+ rv = NSS_SetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY, 0,
+ NSS_USE_POLICY_IN_SSL);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+ }
+
+ private:
+ void SaveOriginalPolicy() {
+ SECStatus rv;
+ rv = NSS_OptionGet(NSS_TLS_VERSION_MIN_POLICY, &saved_min_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_TLS_VERSION_MAX_POLICY, &saved_max_tls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_DTLS_VERSION_MIN_POLICY, &saved_min_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_OptionGet(NSS_DTLS_VERSION_MAX_POLICY, &saved_max_dtls_);
+ ASSERT_EQ(SECSuccess, rv);
+ rv = NSS_GetAlgorithmPolicy(SEC_OID_APPLY_SSL_POLICY,
+ &saved_algorithm_policy_);
+ ASSERT_EQ(SECSuccess, rv);
+ }
+
+ int32_t saved_min_tls_;
+ int32_t saved_max_tls_;
+ int32_t saved_min_dtls_;
+ int32_t saved_max_dtls_;
+ uint32_t saved_algorithm_policy_;
+ };
+
+ VersionPolicy saved_version_policy_;
+
+ SSLProtocolVariant variant_;
+ const VersionRangeWithLabel policy_;
+ const VersionRangeWithLabel input_;
+ const VersionRangeWithLabel library_;
+};
+
+static const uint16_t kExpandedVersionsArr[] = {
+ /* clang-format off */
+ SSL_LIBRARY_VERSION_3_0 - 1,
+ SSL_LIBRARY_VERSION_3_0,
+ SSL_LIBRARY_VERSION_TLS_1_0,
+ SSL_LIBRARY_VERSION_TLS_1_1,
+ SSL_LIBRARY_VERSION_TLS_1_2,
+#ifndef NSS_DISABLE_TLS_1_3
+ SSL_LIBRARY_VERSION_TLS_1_3,
+#endif
+ SSL_LIBRARY_VERSION_MAX_SUPPORTED + 1
+ /* clang-format on */
+};
+static ::testing::internal::ParamGenerator<uint16_t> kExpandedVersions =
+ ::testing::ValuesIn(kExpandedVersionsArr);
+
+TEST_P(TestPolicyVersionRange, TestAllTLSVersionsAndPolicyCombinations) {
+ ASSERT_TRUE(variant_ == ssl_variant_stream ||
+ variant_ == ssl_variant_datagram)
+ << "testing unsupported ssl variant";
+
+ std::cerr << "testing: " << variant_ << policy_ << input_ << library_
+ << std::endl;
+
+ SSLVersionRange supported_range;
+ SECStatus rv = SSL_VersionRangeGetSupported(variant_, &supported_range);
+ VersionRangeWithLabel supported("SSL_VersionRangeGetSupported",
+ supported_range);
+
+ std::cerr << supported << std::endl;
+
+ std::shared_ptr<DummyPrSocket> dummy_socket;
+ ScopedPRFileDesc ssl_fd;
+ CreateDummySocket(&dummy_socket, &ssl_fd);
+
+ SECStatus rv_socket;
+ SSLVersionRange overlap_policy_and_lib;
+ if (!GetOverlap(policy_.range(), library_.range(), &overlap_policy_and_lib)) {
+ EXPECT_EQ(SECFailure, rv)
+ << "expected SSL_VersionRangeGetSupported to fail with invalid policy";
+
+ SSLVersionRange enabled_range;
+ rv = SSL_VersionRangeGetDefault(variant_, &enabled_range);
+ EXPECT_EQ(SECFailure, rv)
+ << "expected SSL_VersionRangeGetDefault to fail with invalid policy";
+
+ SSLVersionRange enabled_range_on_socket;
+ rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &enabled_range_on_socket);
+ EXPECT_EQ(SECFailure, rv_socket)
+ << "expected SSL_VersionRangeGet to fail with invalid policy";
+
+ ConnectExpectFail();
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected SSL_VersionRangeGetSupported to succeed with valid policy";
+
+ EXPECT_TRUE(supported_range.min != SSL_LIBRARY_VERSION_NONE &&
+ supported_range.max != SSL_LIBRARY_VERSION_NONE)
+ << "expected SSL_VersionRangeGetSupported to return real values with "
+ "valid policy";
+
+ RemoveConflictingVersions(variant_, &overlap_policy_and_lib);
+ VersionRangeWithLabel overlap_info("overlap", overlap_policy_and_lib);
+
+ EXPECT_TRUE(supported_range == overlap_policy_and_lib)
+ << "expected range from GetSupported to be identical with calculated "
+ "overlap "
+ << overlap_info;
+
+ // We don't know which versions are "enabled by default" by the library,
+ // therefore we don't know if there's overlap between the default
+ // and the policy, and therefore, we don't if TLS connections should
+ // be successful or fail in this combination.
+ // Therefore we don't test if we can connect, without having configured a
+ // version range explicitly.
+
+ // Now start testing with supplied input.
+
+ SSLVersionRange expected_effective_range;
+ bool is_valid_input =
+ IsValidInputForVersionRangeSet(&expected_effective_range);
+
+ SSLVersionRange temp_input = input_.range();
+ rv = SSL_VersionRangeSetDefault(variant_, &temp_input);
+ rv_socket = SSL_VersionRangeSet(ssl_fd.get(), &temp_input);
+
+ if (!is_valid_input) {
+ EXPECT_EQ(SECFailure, rv)
+ << "expected failure return from SSL_VersionRangeSetDefault";
+
+ EXPECT_EQ(SECFailure, rv_socket)
+ << "expected failure return from SSL_VersionRangeSet";
+ return;
+ }
+
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected successful return from SSL_VersionRangeSetDefault";
+
+ EXPECT_EQ(SECSuccess, rv_socket)
+ << "expected successful return from SSL_VersionRangeSet";
+
+ SSLVersionRange effective;
+ SSLVersionRange effective_socket;
+
+ rv = SSL_VersionRangeGetDefault(variant_, &effective);
+ EXPECT_EQ(SECSuccess, rv)
+ << "expected successful return from SSL_VersionRangeGetDefault";
+
+ rv_socket = SSL_VersionRangeGet(ssl_fd.get(), &effective_socket);
+ EXPECT_EQ(SECSuccess, rv_socket)
+ << "expected successful return from SSL_VersionRangeGet";
+
+ VersionRangeWithLabel expected_info("expectation", expected_effective_range);
+ VersionRangeWithLabel effective_info("effectively-enabled", effective);
+
+ EXPECT_TRUE(expected_effective_range == effective)
+ << "range returned by SSL_VersionRangeGetDefault doesn't match "
+ "expectation: "
+ << expected_info << effective_info;
+
+ EXPECT_TRUE(expected_effective_range == effective_socket)
+ << "range returned by SSL_VersionRangeGet doesn't match "
+ "expectation: "
+ << expected_info << effective_info;
+
+ // Because we found overlap between policy and supported versions,
+ // and because we have used SetDefault to enable at least one version,
+ // it should be possible to execute an SSL/TLS connection.
+ Connect();
+}
+
+INSTANTIATE_TEST_CASE_P(TLSVersionRanges, TestPolicyVersionRange,
+ ::testing::Combine(TlsConnectTestBase::kTlsVariantsAll,
+ kExpandedVersions, kExpandedVersions,
+ kExpandedVersions,
+ kExpandedVersions));
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/test_io.cc b/nss/gtests/ssl_gtest/test_io.cc
index f3fd0b2..b9f0c67 100644
--- a/nss/gtests/ssl_gtest/test_io.cc
+++ b/nss/gtests/ssl_gtest/test_io.cc
@@ -15,314 +15,33 @@
#include "prlog.h"
#include "prthread.h"
-#include "databuffer.h"
-
extern bool g_ssl_gtest_verbose;
namespace nss_test {
-static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER;
-
-#define UNIMPLEMENTED() \
- std::cerr << "Call to unimplemented function " << __FUNCTION__ << std::endl; \
- PR_ASSERT(PR_FALSE); \
- PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0)
-
#define LOG(a) std::cerr << name_ << ": " << a << std::endl
#define LOGV(a) \
do { \
if (g_ssl_gtest_verbose) LOG(a); \
} while (false)
-class Packet : public DataBuffer {
- public:
- Packet(const DataBuffer &buf) : DataBuffer(buf), offset_(0) {}
-
- void Advance(size_t delta) {
- PR_ASSERT(offset_ + delta <= len());
- offset_ = std::min(len(), offset_ + delta);
- }
-
- size_t offset() const { return offset_; }
- size_t remaining() const { return len() - offset_; }
-
- private:
- size_t offset_;
-};
-
-// Implementation of NSPR methods
-static PRStatus DummyClose(PRFileDesc *f) {
- DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
- f->secret = nullptr;
- f->dtor(f);
- delete io;
- return PR_SUCCESS;
-}
-
-static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) {
- DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
- return io->Read(buf, length);
-}
-
-static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) {
- DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
- return io->Write(buf, length);
-}
-
-static int32_t DummyAvailable(PRFileDesc *f) {
- UNIMPLEMENTED();
- return -1;
-}
-
-int64_t DummyAvailable64(PRFileDesc *f) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static PRStatus DummySync(PRFileDesc *f) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size,
- PRIntervalTime to) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr,
- PRIntervalTime to) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr,
- PRIntervalTime to) {
- UNIMPLEMENTED();
- return nullptr;
-}
-
-static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static PRStatus DummyListen(PRFileDesc *f, int32_t depth) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) {
- DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
- io->Reset();
- return PR_SUCCESS;
-}
-
-// This function does not support peek.
-static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen,
- int32_t flags, PRIntervalTime to) {
- PR_ASSERT(flags == 0);
- if (flags != 0) {
- PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
- return -1;
- }
-
- DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret);
-
- if (io->mode() == DGRAM) {
- return io->Recv(buf, buflen);
- } else {
- return io->Read(buf, buflen);
- }
-}
-
-// Note: this is always nonblocking and assumes a zero timeout.
-static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount,
- int32_t flags, PRIntervalTime to) {
- int32_t written = DummyWrite(f, buf, amount);
- return written;
-}
-
-static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount,
- int32_t flags, PRNetAddr *addr,
- PRIntervalTime to) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount,
- int32_t flags, const PRNetAddr *addr,
- PRIntervalTime to) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd,
- PRNetAddr **raddr, void *buf, int32_t amount,
- PRIntervalTime t) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f,
- const void *headers, int32_t hlen,
- PRTransmitFileFlags flags, PRIntervalTime t) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) {
- // TODO: Modify to return unique names for each channel
- // somehow, as opposed to always the same static address. The current
- // implementation messes up the session cache, which is why it's off
- // elsewhere
- addr->inet.family = PR_AF_INET;
- addr->inet.port = 0;
- addr->inet.ip = 0;
-
- return PR_SUCCESS;
-}
-
-static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) {
- switch (opt->option) {
- case PR_SockOpt_Nonblocking:
- opt->value.non_blocking = PR_TRUE;
- return PR_SUCCESS;
- default:
- UNIMPLEMENTED();
- break;
- }
-
- return PR_FAILURE;
-}
-
-// Imitate setting socket options. These are mostly noops.
-static PRStatus DummySetsockoption(PRFileDesc *f,
- const PRSocketOptionData *opt) {
- switch (opt->option) {
- case PR_SockOpt_Nonblocking:
- return PR_SUCCESS;
- case PR_SockOpt_NoDelay:
- return PR_SUCCESS;
- default:
- UNIMPLEMENTED();
- break;
- }
-
- return PR_FAILURE;
-}
-
-static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in,
- PRTransmitFileFlags flags, PRIntervalTime to) {
- UNIMPLEMENTED();
- return -1;
-}
-
-static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) {
- UNIMPLEMENTED();
- return PR_FAILURE;
-}
-
-static int32_t DummyReserved(PRFileDesc *f) {
- UNIMPLEMENTED();
- return -1;
-}
-
-DummyPrSocket::~DummyPrSocket() { Reset(); }
-
-void DummyPrSocket::SetPacketFilter(PacketFilter *filter) {
- if (filter_) {
- delete filter_;
- }
+void DummyPrSocket::SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
filter_ = filter;
}
-void DummyPrSocket::Reset() {
- delete filter_;
- if (peer_) {
- peer_->SetPeer(nullptr);
- peer_ = nullptr;
- }
- while (!input_.empty()) {
- Packet *front = input_.front();
- input_.pop();
- delete front;
- }
-}
-
-static const struct PRIOMethods DummyMethods = {
- PR_DESC_LAYERED, DummyClose,
- DummyRead, DummyWrite,
- DummyAvailable, DummyAvailable64,
- DummySync, DummySeek,
- DummySeek64, DummyFileInfo,
- DummyFileInfo64, DummyWritev,
- DummyConnect, DummyAccept,
- DummyBind, DummyListen,
- DummyShutdown, DummyRecv,
- DummySend, DummyRecvfrom,
- DummySendto, DummyPoll,
- DummyAcceptRead, DummyTransmitFile,
- DummyGetsockname, DummyGetpeername,
- DummyReserved, DummyReserved,
- DummyGetsockoption, DummySetsockoption,
- DummySendfile, DummyConnectContinue,
- DummyReserved, DummyReserved,
- DummyReserved, DummyReserved};
-
-PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) {
- if (test_fd_identity == PR_INVALID_IO_LAYER) {
- test_fd_identity = PR_GetUniqueIdentity("testtransportadapter");
- }
-
- PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods));
- fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode));
-
- return fd;
-}
-
-DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) {
- return reinterpret_cast<DummyPrSocket *>(fd->secret);
+ScopedPRFileDesc DummyPrSocket::CreateFD() {
+ static PRDescIdentity test_fd_identity =
+ PR_GetUniqueIdentity("testtransportadapter");
+ return DummyIOLayerMethods::CreateFD(test_fd_identity, this);
}
void DummyPrSocket::PacketReceived(const DataBuffer &packet) {
- input_.push(new Packet(packet));
+ input_.push(Packet(packet));
}
-int32_t DummyPrSocket::Read(void *data, int32_t len) {
- PR_ASSERT(mode_ == STREAM);
-
- if (mode_ != STREAM) {
+int32_t DummyPrSocket::Read(PRFileDesc *f, void *data, int32_t len) {
+ PR_ASSERT(variant_ == ssl_variant_stream);
+ if (variant_ != ssl_variant_stream) {
PR_SetError(PR_INVALID_METHOD_ERROR, 0);
return -1;
}
@@ -333,45 +52,54 @@ int32_t DummyPrSocket::Read(void *data, int32_t len) {
return -1;
}
- Packet *front = input_.front();
+ auto &front = input_.front();
size_t to_read =
- std::min(static_cast<size_t>(len), front->len() - front->offset());
- memcpy(data, static_cast<const void *>(front->data() + front->offset()),
+ std::min(static_cast<size_t>(len), front.len() - front.offset());
+ memcpy(data, static_cast<const void *>(front.data() + front.offset()),
to_read);
- front->Advance(to_read);
+ front.Advance(to_read);
- if (!front->remaining()) {
+ if (!front.remaining()) {
input_.pop();
- delete front;
}
return static_cast<int32_t>(to_read);
}
-int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) {
+int32_t DummyPrSocket::Recv(PRFileDesc *f, void *buf, int32_t buflen,
+ int32_t flags, PRIntervalTime to) {
+ PR_ASSERT(flags == 0);
+ if (flags != 0) {
+ PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
+ return -1;
+ }
+
+ if (variant() != ssl_variant_datagram) {
+ return Read(f, buf, buflen);
+ }
+
if (input_.empty()) {
PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
return -1;
}
- Packet *front = input_.front();
- if (static_cast<size_t>(buflen) < front->len()) {
+ auto &front = input_.front();
+ if (static_cast<size_t>(buflen) < front.len()) {
PR_ASSERT(false);
PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0);
return -1;
}
- size_t count = front->len();
- memcpy(buf, front->data(), count);
+ size_t count = front.len();
+ memcpy(buf, front.data(), count);
input_.pop();
- delete front;
-
return static_cast<int32_t>(count);
}
-int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
- if (!peer_ || !writeable_) {
+int32_t DummyPrSocket::Write(PRFileDesc *f, const void *buf, int32_t length) {
+ auto peer = peer_.lock();
+ if (!peer || !writeable_) {
PR_SetError(PR_IO_ERROR, 0);
return -1;
}
@@ -387,14 +115,14 @@ int32_t DummyPrSocket::Write(const void *buf, int32_t length) {
case PacketFilter::CHANGE:
LOG("Original packet: " << packet);
LOG("Filtered packet: " << filtered);
- peer_->PacketReceived(filtered);
+ peer->PacketReceived(filtered);
break;
case PacketFilter::DROP:
LOG("Droppped packet: " << packet);
break;
case PacketFilter::KEEP:
LOGV("Packet: " << packet);
- peer_->PacketReceived(packet);
+ peer->PacketReceived(packet);
break;
}
// libssl can't handle it if this reports something other than the length
@@ -415,43 +143,31 @@ void Poller::Shutdown() {
instance = nullptr;
}
-Poller::~Poller() {
- while (!timers_.empty()) {
- Timer *timer = timers_.top();
- timers_.pop();
- delete timer;
- }
-}
+void Poller::Wait(Event event, std::shared_ptr<DummyPrSocket> &adapter,
+ PollTarget *target, PollCallback cb) {
+ assert(event < TIMER_EVENT);
+ if (event >= TIMER_EVENT) return;
-void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target,
- PollCallback cb) {
+ std::unique_ptr<Waiter> waiter;
auto it = waiters_.find(adapter);
- Waiter *waiter;
-
if (it == waiters_.end()) {
- waiter = new Waiter(adapter);
+ waiter.reset(new Waiter(adapter));
} else {
- waiter = it->second;
+ waiter = std::move(it->second);
}
- assert(event < TIMER_EVENT);
- if (event >= TIMER_EVENT) return;
-
waiter->targets_[event] = target;
waiter->callbacks_[event] = cb;
- waiters_[adapter] = waiter;
+ waiters_[adapter] = std::move(waiter);
}
-void Poller::Cancel(Event event, DummyPrSocket *adapter) {
+void Poller::Cancel(Event event, std::shared_ptr<DummyPrSocket> &adapter) {
auto it = waiters_.find(adapter);
- Waiter *waiter;
-
if (it == waiters_.end()) {
return;
}
- waiter = it->second;
-
+ auto &waiter = it->second;
waiter->targets_[event] = nullptr;
waiter->callbacks_[event] = nullptr;
@@ -460,13 +176,12 @@ void Poller::Cancel(Event event, DummyPrSocket *adapter) {
if (waiter->callbacks_[i]) return;
}
- delete waiter;
waiters_.erase(adapter);
}
void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb,
- Timer **timer) {
- Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb);
+ std::shared_ptr<Timer> *timer) {
+ auto t = std::make_shared<Timer>(PR_Now() + timer_ms * 1000, target, cb);
timers_.push(t);
if (timer) *timer = t;
}
@@ -482,7 +197,7 @@ bool Poller::Poll() {
// Figure out the timer for the select.
if (!timers_.empty()) {
- Timer *first_timer = timers_.top();
+ auto first_timer = timers_.top();
if (now >= first_timer->deadline_) {
// Timer expired.
timeout = PR_INTERVAL_NO_WAIT;
@@ -493,7 +208,7 @@ bool Poller::Poll() {
}
for (auto it = waiters_.begin(); it != waiters_.end(); ++it) {
- Waiter *waiter = it->second;
+ auto &waiter = it->second;
if (waiter->callbacks_[READABLE_EVENT]) {
if (waiter->io_->readable()) {
@@ -522,12 +237,11 @@ bool Poller::Poll() {
while (!timers_.empty()) {
if (now < timers_.top()->deadline_) break;
- Timer *timer = timers_.top();
+ auto timer = timers_.top();
timers_.pop();
if (timer->callback_) {
timer->callback_(timer->target_, TIMER_EVENT);
}
- delete timer;
}
return true;
diff --git a/nss/gtests/ssl_gtest/test_io.h b/nss/gtests/ssl_gtest/test_io.h
index b78db0d..ac24972 100644
--- a/nss/gtests/ssl_gtest/test_io.h
+++ b/nss/gtests/ssl_gtest/test_io.h
@@ -14,12 +14,15 @@
#include <queue>
#include <string>
+#include "databuffer.h"
+#include "dummy_io.h"
#include "prio.h"
+#include "scoped_ptrs.h"
+#include "sslt.h"
namespace nss_test {
class DataBuffer;
-class Packet;
class DummyPrSocket; // Fwd decl.
// Allow us to inspect a packet before it is written.
@@ -42,49 +45,59 @@ class PacketFilter {
virtual Action Filter(const DataBuffer& input, DataBuffer* output) = 0;
};
-enum Mode { STREAM, DGRAM };
-
-inline std::ostream& operator<<(std::ostream& os, Mode m) {
- return os << ((m == STREAM) ? "TLS" : "DTLS");
-}
-
-class DummyPrSocket {
+class DummyPrSocket : public DummyIOLayerMethods {
public:
- ~DummyPrSocket();
+ DummyPrSocket(const std::string& name, SSLProtocolVariant variant)
+ : name_(name),
+ variant_(variant),
+ peer_(),
+ input_(),
+ filter_(nullptr),
+ writeable_(true) {}
+ virtual ~DummyPrSocket() {}
- static PRFileDesc* CreateFD(const std::string& name,
- Mode mode); // Returns an FD.
- static DummyPrSocket* GetAdapter(PRFileDesc* fd);
+ // Create a file descriptor that will reference this object. The fd must not
+ // live longer than this adapter; call PR_Close() before.
+ ScopedPRFileDesc CreateFD();
- DummyPrSocket* peer() const { return peer_; }
- void SetPeer(DummyPrSocket* peer) { peer_ = peer; }
- void SetPacketFilter(PacketFilter* filter);
+ std::weak_ptr<DummyPrSocket>& peer() { return peer_; }
+ void SetPeer(const std::shared_ptr<DummyPrSocket>& peer) { peer_ = peer; }
+ void SetPacketFilter(std::shared_ptr<PacketFilter> filter);
// Drops peer, packet filter and any outstanding packets.
void Reset();
void PacketReceived(const DataBuffer& data);
- int32_t Read(void* data, int32_t len);
- int32_t Recv(void* buf, int32_t buflen);
- int32_t Write(const void* buf, int32_t length);
+ int32_t Read(PRFileDesc* f, void* data, int32_t len) override;
+ int32_t Recv(PRFileDesc* f, void* buf, int32_t buflen, int32_t flags,
+ PRIntervalTime to) override;
+ int32_t Write(PRFileDesc* f, const void* buf, int32_t length) override;
void CloseWrites() { writeable_ = false; }
- Mode mode() const { return mode_; }
+ SSLProtocolVariant variant() const { return variant_; }
bool readable() const { return !input_.empty(); }
private:
- DummyPrSocket(const std::string& name, Mode mode)
- : name_(name),
- mode_(mode),
- peer_(nullptr),
- input_(),
- filter_(nullptr),
- writeable_(true) {}
+ class Packet : public DataBuffer {
+ public:
+ Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {}
+
+ void Advance(size_t delta) {
+ PR_ASSERT(offset_ + delta <= len());
+ offset_ = std::min(len(), offset_ + delta);
+ }
+
+ size_t offset() const { return offset_; }
+ size_t remaining() const { return len() - offset_; }
+
+ private:
+ size_t offset_;
+ };
const std::string name_;
- Mode mode_;
- DummyPrSocket* peer_;
- std::queue<Packet*> input_;
- PacketFilter* filter_;
+ SSLProtocolVariant variant_;
+ std::weak_ptr<DummyPrSocket> peer_;
+ std::queue<Packet> input_;
+ std::shared_ptr<PacketFilter> filter_;
bool writeable_;
};
@@ -111,40 +124,44 @@ class Poller {
PollCallback callback_;
};
- void Wait(Event event, DummyPrSocket* adapter, PollTarget* target,
- PollCallback cb);
- void Cancel(Event event, DummyPrSocket* adapter);
+ void Wait(Event event, std::shared_ptr<DummyPrSocket>& adapter,
+ PollTarget* target, PollCallback cb);
+ void Cancel(Event event, std::shared_ptr<DummyPrSocket>& adapter);
void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb,
- Timer** handle);
+ std::shared_ptr<Timer>* handle);
bool Poll();
private:
Poller() : waiters_(), timers_() {}
- ~Poller();
+ ~Poller() {}
class Waiter {
public:
- Waiter(DummyPrSocket* io) : io_(io) {
+ Waiter(std::shared_ptr<DummyPrSocket> io) : io_(io) {
+ memset(&targets_[0], 0, sizeof(targets_));
memset(&callbacks_[0], 0, sizeof(callbacks_));
}
void WaitFor(Event event, PollCallback callback);
- DummyPrSocket* io_;
+ std::shared_ptr<DummyPrSocket> io_;
PollTarget* targets_[TIMER_EVENT];
PollCallback callbacks_[TIMER_EVENT];
};
class TimerComparator {
public:
- bool operator()(const Timer* lhs, const Timer* rhs) {
+ bool operator()(const std::shared_ptr<Timer> lhs,
+ const std::shared_ptr<Timer> rhs) {
return lhs->deadline_ > rhs->deadline_;
}
};
static Poller* instance;
- std::map<DummyPrSocket*, Waiter*> waiters_;
- std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_;
+ std::map<std::shared_ptr<DummyPrSocket>, std::unique_ptr<Waiter>> waiters_;
+ std::priority_queue<std::shared_ptr<Timer>,
+ std::vector<std::shared_ptr<Timer>>, TimerComparator>
+ timers_;
};
} // end of namespace
diff --git a/nss/gtests/ssl_gtest/tls_agent.cc b/nss/gtests/ssl_gtest/tls_agent.cc
index b75bba5..a53cf88 100644
--- a/nss/gtests/ssl_gtest/tls_agent.cc
+++ b/nss/gtests/ssl_gtest/tls_agent.cc
@@ -43,14 +43,14 @@ const std::string TlsAgent::kServerEcdhRsa = "ecdh_rsa";
const std::string TlsAgent::kServerEcdhEcdsa = "ecdh_ecdsa";
const std::string TlsAgent::kServerDsa = "dsa";
-TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
+TlsAgent::TlsAgent(const std::string& name, Role role,
+ SSLProtocolVariant variant)
: name_(name),
- mode_(mode),
+ variant_(variant),
+ role_(role),
server_key_bits_(0),
- pr_fd_(nullptr),
- adapter_(nullptr),
+ adapter_(new DummyPrSocket(role_str(), variant)),
ssl_fd_(nullptr),
- role_(role),
state_(STATE_INIT),
timer_handle_(nullptr),
falsestart_enabled_(false),
@@ -61,6 +61,10 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
can_falsestart_hook_called_(false),
sni_hook_called_(false),
auth_certificate_hook_called_(false),
+ expected_received_alert_(kTlsAlertCloseNotify),
+ expected_received_alert_level_(kTlsAlertWarning),
+ expected_sent_alert_(kTlsAlertCloseNotify),
+ expected_sent_alert_level_(kTlsAlertWarning),
handshake_callback_called_(false),
error_code_(0),
send_ctr_(0),
@@ -69,29 +73,31 @@ TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode)
handshake_callback_(),
auth_certificate_callback_(),
sni_callback_(),
- expect_short_headers_(false) {
+ expect_short_headers_(false),
+ skip_version_checks_(false) {
memset(&info_, 0, sizeof(info_));
memset(&csinfo_, 0, sizeof(csinfo_));
- SECStatus rv = SSL_VersionRangeGetDefault(
- mode_ == STREAM ? ssl_variant_stream : ssl_variant_datagram, &vrange_);
+ SECStatus rv = SSL_VersionRangeGetDefault(variant_, &vrange_);
EXPECT_EQ(SECSuccess, rv);
}
TlsAgent::~TlsAgent() {
- if (adapter_) {
- Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
- // The adapter is closed when the FD closes.
- }
if (timer_handle_) {
timer_handle_->Cancel();
}
- if (pr_fd_) {
- PR_Close(pr_fd_);
+ if (adapter_) {
+ Poller::Instance()->Cancel(READABLE_EVENT, adapter_);
}
- if (ssl_fd_) {
- PR_Close(ssl_fd_);
+ // Add failures manually, if any, so we don't throw in a destructor.
+ if (expected_received_alert_ != kTlsAlertCloseNotify ||
+ expected_received_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_received_alert status";
+ }
+ if (expected_sent_alert_ != kTlsAlertCloseNotify ||
+ expected_sent_alert_level_ != kTlsAlertWarning) {
+ ADD_FAILURE() << "Wrong expected_sent_alert status";
}
}
@@ -102,27 +108,39 @@ void TlsAgent::SetState(State state) {
state_ = state;
}
+/*static*/ bool TlsAgent::LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv) {
+ cert->reset(PK11_FindCertFromNickname(name.c_str(), nullptr));
+ EXPECT_NE(nullptr, cert->get());
+ if (!cert->get()) return false;
+
+ priv->reset(PK11_FindKeyByAnyCert(cert->get(), nullptr));
+ EXPECT_NE(nullptr, priv->get());
+ if (!priv->get()) return false;
+
+ return true;
+}
+
bool TlsAgent::ConfigServerCert(const std::string& name, bool updateKeyBits,
const SSLExtraServerCertData* serverCertData) {
- ScopedCERTCertificate cert(PK11_FindCertFromNickname(name.c_str(), nullptr));
- EXPECT_NE(nullptr, cert.get());
- if (!cert.get()) return false;
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(name, &cert, &priv)) {
+ return false;
+ }
- ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
- EXPECT_NE(nullptr, pub.get());
- if (!pub.get()) return false;
if (updateKeyBits) {
+ ScopedSECKEYPublicKey pub(CERT_ExtractPublicKey(cert.get()));
+ EXPECT_NE(nullptr, pub.get());
+ if (!pub.get()) return false;
server_key_bits_ = SECKEY_PublicKeyStrengthInBits(pub.get());
}
- ScopedSECKEYPrivateKey priv(PK11_FindKeyByAnyCert(cert.get(), nullptr));
- EXPECT_NE(nullptr, priv.get());
- if (!priv.get()) return false;
-
SECStatus rv =
- SSL_ConfigSecureServer(ssl_fd_, nullptr, nullptr, ssl_kea_null);
+ SSL_ConfigSecureServer(ssl_fd(), nullptr, nullptr, ssl_kea_null);
EXPECT_EQ(SECFailure, rv);
- rv = SSL_ConfigServerCert(ssl_fd_, cert.get(), priv.get(), serverCertData,
+ rv = SSL_ConfigServerCert(ssl_fd(), cert.get(), priv.get(), serverCertData,
serverCertData ? sizeof(*serverCertData) : 0);
return rv == SECSuccess;
}
@@ -131,41 +149,59 @@ bool TlsAgent::EnsureTlsSetup(PRFileDesc* modelSocket) {
// Don't set up twice
if (ssl_fd_) return true;
- if (adapter_->mode() == STREAM) {
- ssl_fd_ = SSL_ImportFD(modelSocket, pr_fd_);
+ ScopedPRFileDesc dummy_fd(adapter_->CreateFD());
+ EXPECT_NE(nullptr, dummy_fd);
+ if (!dummy_fd) {
+ return false;
+ }
+ if (adapter_->variant() == ssl_variant_stream) {
+ ssl_fd_.reset(SSL_ImportFD(modelSocket, dummy_fd.get()));
} else {
- ssl_fd_ = DTLS_ImportFD(modelSocket, pr_fd_);
+ ssl_fd_.reset(DTLS_ImportFD(modelSocket, dummy_fd.get()));
}
EXPECT_NE(nullptr, ssl_fd_);
- if (!ssl_fd_) return false;
- pr_fd_ = nullptr;
+ if (!ssl_fd_) {
+ return false;
+ }
+ dummy_fd.release(); // Now subsumed by ssl_fd_.
- SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
- EXPECT_EQ(SECSuccess, rv);
- if (rv != SECSuccess) return false;
+ SECStatus rv;
+ if (!skip_version_checks_) {
+ rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+ }
if (role_ == SERVER) {
EXPECT_TRUE(ConfigServerCert(name_, true));
- rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this);
+ rv = SSL_SNISocketConfigHook(ssl_fd(), SniHook, this);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
ScopedCERTCertList anchors(CERT_NewCertList());
- rv = SSL_SetTrustAnchors(ssl_fd_, anchors.get());
+ rv = SSL_SetTrustAnchors(ssl_fd(), anchors.get());
if (rv != SECSuccess) return false;
} else {
- rv = SSL_SetURL(ssl_fd_, "server");
+ rv = SSL_SetURL(ssl_fd(), "server");
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
}
- rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this);
+ rv = SSL_AuthCertificateHook(ssl_fd(), AuthCertificateHook, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_AlertReceivedCallback(ssl_fd(), AlertReceivedCallback, this);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
- rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this);
+ rv = SSL_AlertSentCallback(ssl_fd(), AlertSentCallback, this);
+ EXPECT_EQ(SECSuccess, rv);
+ if (rv != SECSuccess) return false;
+
+ rv = SSL_HandshakeCallback(ssl_fd(), HandshakeCallback, this);
EXPECT_EQ(SECSuccess, rv);
if (rv != SECSuccess) return false;
@@ -177,38 +213,31 @@ void TlsAgent::SetupClientAuth() {
ASSERT_EQ(CLIENT, role_);
EXPECT_EQ(SECSuccess,
- SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook,
+ SSL_GetClientAuthDataHook(ssl_fd(), GetClientAuthDataHook,
reinterpret_cast<void*>(this)));
}
-bool TlsAgent::GetClientAuthCredentials(CERTCertificate** cert,
- SECKEYPrivateKey** priv) const {
- *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr);
- EXPECT_NE(nullptr, *cert);
- if (!*cert) return false;
-
- *priv = PK11_FindKeyByAnyCert(*cert, nullptr);
- EXPECT_NE(nullptr, *priv);
- if (!*priv) return false; // Leak cert.
-
- return true;
-}
-
SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd,
CERTDistNames* caNames,
- CERTCertificate** cert,
- SECKEYPrivateKey** privKey) {
+ CERTCertificate** clientCert,
+ SECKEYPrivateKey** clientKey) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(self);
ScopedCERTCertificate peerCert(SSL_PeerCertificate(agent->ssl_fd()));
EXPECT_TRUE(peerCert) << "Client should be able to see the server cert";
- if (agent->GetClientAuthCredentials(cert, privKey)) {
- return SECSuccess;
+
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey priv;
+ if (!TlsAgent::LoadCertificate(agent->name(), &cert, &priv)) {
+ return SECFailure;
}
- return SECFailure;
+
+ *clientCert = cert.release();
+ *clientKey = priv.release();
+ return SECSuccess;
}
bool TlsAgent::GetPeerChainLength(size_t* count) {
- CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd_);
+ CERTCertList* chain = SSL_PeerCertificateChain(ssl_fd());
if (!chain) return false;
*count = 0;
@@ -224,17 +253,21 @@ bool TlsAgent::GetPeerChainLength(size_t* count) {
return true;
}
+void TlsAgent::CheckCipherSuite(uint16_t cipher_suite) {
+ EXPECT_EQ(csinfo_.cipherSuite, cipher_suite);
+}
+
void TlsAgent::RequestClientAuth(bool requireAuth) {
EXPECT_TRUE(EnsureTlsSetup());
ASSERT_EQ(SERVER, role_);
EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE));
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE,
+ SSL_OptionSet(ssl_fd(), SSL_REQUEST_CERTIFICATE, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_REQUIRE_CERTIFICATE,
requireAuth ? PR_TRUE : PR_FALSE));
EXPECT_EQ(SECSuccess, SSL_AuthCertificateHook(
- ssl_fd_, &TlsAgent::ClientAuthenticated, this));
+ ssl_fd(), &TlsAgent::ClientAuthenticated, this));
expect_client_auth_ = true;
}
@@ -242,7 +275,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) {
EXPECT_TRUE(EnsureTlsSetup(model));
SECStatus rv;
- rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE);
+ rv = SSL_ResetHandshake(ssl_fd(), role_ == SERVER ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
SetState(STATE_CONNECTING);
}
@@ -250,7 +283,7 @@ void TlsAgent::StartConnect(PRFileDesc* model) {
void TlsAgent::DisableAllCiphers() {
for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) {
SECStatus rv =
- SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE);
+ SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -287,7 +320,7 @@ void TlsAgent::EnableCiphersByKeyExchange(SSLKEAType kea) {
EXPECT_EQ(sizeof(csinfo), csinfo.length);
if ((csinfo.keaType == kea) || (csinfo.keaType == ssl_kea_tls13_any)) {
- rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -325,7 +358,7 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
if ((csinfo.authType == authType) ||
(csinfo.keaType == ssl_kea_tls13_any)) {
- rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_TRUE);
+ rv = SSL_CipherPrefSet(ssl_fd(), SSL_ImplementedCiphers[i], PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -333,20 +366,20 @@ void TlsAgent::EnableCiphersByAuthType(SSLAuthType authType) {
void TlsAgent::EnableSingleCipher(uint16_t cipher) {
DisableAllCiphers();
- SECStatus rv = SSL_CipherPrefSet(ssl_fd_, cipher, PR_TRUE);
+ SECStatus rv = SSL_CipherPrefSet(ssl_fd(), cipher, PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_NamedGroupConfig(ssl_fd_, &groups[0], groups.size());
+ SECStatus rv = SSL_NamedGroupConfig(ssl_fd(), &groups[0], groups.size());
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SetSessionTicketsEnabled(bool en) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
en ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
@@ -354,7 +387,7 @@ void TlsAgent::SetSessionTicketsEnabled(bool en) {
void TlsAgent::SetSessionCacheEnabled(bool en) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE, en ? PR_FALSE : PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
@@ -362,14 +395,22 @@ void TlsAgent::Set0RttEnabled(bool en) {
EXPECT_TRUE(EnsureTlsSetup());
SECStatus rv =
- SSL_OptionSet(ssl_fd_, SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+ SSL_OptionSet(ssl_fd(), SSL_ENABLE_0RTT_DATA, en ? PR_TRUE : PR_FALSE);
+ EXPECT_EQ(SECSuccess, rv);
+}
+
+void TlsAgent::SetFallbackSCSVEnabled(bool en) {
+ EXPECT_TRUE(role_ == CLIENT && EnsureTlsSetup());
+
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALLBACK_SCSV,
+ en ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SetShortHeadersEnabled() {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd_);
+ SECStatus rv = SSLInt_EnableShortHeaders(ssl_fd());
EXPECT_EQ(SECSuccess, rv);
}
@@ -377,8 +418,8 @@ void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) {
vrange_.min = minver;
vrange_.max = maxver;
- if (ssl_fd_) {
- SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_);
+ if (ssl_fd()) {
+ SECStatus rv = SSL_VersionRangeSet(ssl_fd(), &vrange_);
EXPECT_EQ(SECSuccess, rv);
}
}
@@ -398,32 +439,34 @@ void TlsAgent::ExpectReadWriteError() { expect_readwrite_error_ = true; }
void TlsAgent::ExpectShortHeaders() { expect_short_headers_ = true; }
+void TlsAgent::SkipVersionChecks() { skip_version_checks_ = true; }
+
void TlsAgent::SetSignatureSchemes(const SSLSignatureScheme* schemes,
size_t count) {
EXPECT_TRUE(EnsureTlsSetup());
EXPECT_LE(count, SSL_SignatureMaxCount());
EXPECT_EQ(SECSuccess,
- SSL_SignatureSchemePrefSet(ssl_fd_, schemes,
+ SSL_SignatureSchemePrefSet(ssl_fd(), schemes,
static_cast<unsigned int>(count)));
- EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd_, schemes, 0))
+ EXPECT_EQ(SECFailure, SSL_SignatureSchemePrefSet(ssl_fd(), schemes, 0))
<< "setting no schemes should fail and do nothing";
std::vector<SSLSignatureScheme> configuredSchemes(count);
unsigned int configuredCount;
EXPECT_EQ(SECFailure,
- SSL_SignatureSchemePrefGet(ssl_fd_, nullptr, &configuredCount, 1))
+ SSL_SignatureSchemePrefGet(ssl_fd(), nullptr, &configuredCount, 1))
<< "get schemes, schemes is nullptr";
EXPECT_EQ(SECFailure,
- SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0],
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0],
&configuredCount, 0))
<< "get schemes, too little space";
EXPECT_EQ(SECFailure,
- SSL_SignatureSchemePrefGet(ssl_fd_, &configuredSchemes[0], nullptr,
+ SSL_SignatureSchemePrefGet(ssl_fd(), &configuredSchemes[0], nullptr,
configuredSchemes.size()))
<< "get schemes, countOut is nullptr";
EXPECT_EQ(SECSuccess, SSL_SignatureSchemePrefGet(
- ssl_fd_, &configuredSchemes[0], &configuredCount,
+ ssl_fd(), &configuredSchemes[0], &configuredCount,
configuredSchemes.size()));
// SignatureSchemePrefSet drops unsupported algorithms silently, so the
// number that are configured might be fewer.
@@ -524,10 +567,10 @@ void TlsAgent::EnableFalseStart() {
EXPECT_TRUE(EnsureTlsSetup());
falsestart_enabled_ = true;
+ EXPECT_EQ(SECSuccess, SSL_SetCanFalseStartCallback(
+ ssl_fd(), CanFalseStartCallback, this));
EXPECT_EQ(SECSuccess,
- SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this));
- EXPECT_EQ(SECSuccess,
- SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE));
+ SSL_OptionSet(ssl_fd(), SSL_ENABLE_FALSE_START, PR_TRUE));
}
void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
@@ -535,8 +578,8 @@ void TlsAgent::ExpectResumption() { expect_resumption_ = true; }
void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) {
EXPECT_TRUE(EnsureTlsSetup());
- EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE));
- EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len));
+ EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd(), SSL_ENABLE_ALPN, PR_TRUE));
+ EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd(), val, len));
}
void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
@@ -544,7 +587,7 @@ void TlsAgent::CheckAlpn(SSLNextProtoState expected_state,
SSLNextProtoState state;
char chosen[10];
unsigned int chosen_len;
- SECStatus rv = SSL_GetNextProto(ssl_fd_, &state,
+ SECStatus rv = SSL_GetNextProto(ssl_fd(), &state,
reinterpret_cast<unsigned char*>(chosen),
&chosen_len, sizeof(chosen));
EXPECT_EQ(SECSuccess, rv);
@@ -562,12 +605,12 @@ void TlsAgent::EnableSrtp() {
const uint16_t ciphers[] = {SRTP_AES128_CM_HMAC_SHA1_80,
SRTP_AES128_CM_HMAC_SHA1_32};
EXPECT_EQ(SECSuccess,
- SSL_SetSRTPCiphers(ssl_fd_, ciphers, PR_ARRAY_SIZE(ciphers)));
+ SSL_SetSRTPCiphers(ssl_fd(), ciphers, PR_ARRAY_SIZE(ciphers)));
}
void TlsAgent::CheckSrtp() const {
uint16_t actual;
- EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual));
+ EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd(), &actual));
EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual);
}
@@ -578,6 +621,55 @@ void TlsAgent::CheckErrorCode(int32_t expected) const {
<< PORT_ErrorToName(expected) << std::endl;
}
+static uint8_t GetExpectedAlertLevel(uint8_t alert) {
+ switch (alert) {
+ case kTlsAlertCloseNotify:
+ case kTlsAlertEndOfEarlyData:
+ return kTlsAlertWarning;
+ default:
+ break;
+ }
+ return kTlsAlertFatal;
+}
+
+void TlsAgent::ExpectReceiveAlert(uint8_t alert, uint8_t level) {
+ expected_received_alert_ = alert;
+ if (level == 0) {
+ expected_received_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_received_alert_level_ = level;
+ }
+}
+
+void TlsAgent::ExpectSendAlert(uint8_t alert, uint8_t level) {
+ expected_sent_alert_ = alert;
+ if (level == 0) {
+ expected_sent_alert_level_ = GetExpectedAlertLevel(alert);
+ } else {
+ expected_sent_alert_level_ = level;
+ }
+}
+
+void TlsAgent::CheckAlert(bool sent, const SSLAlert* alert) {
+ LOG(((alert->level == kTlsAlertWarning) ? "Warning" : "Fatal")
+ << " alert " << (sent ? "sent" : "received") << ": "
+ << static_cast<int>(alert->description));
+
+ auto& expected = sent ? expected_sent_alert_ : expected_received_alert_;
+ auto& expected_level =
+ sent ? expected_sent_alert_level_ : expected_received_alert_level_;
+ /* Silently pass close_notify in case the test has already ended. */
+ if (expected == kTlsAlertCloseNotify && expected_level == kTlsAlertWarning &&
+ alert->description == expected && alert->level == expected_level) {
+ return;
+ }
+
+ EXPECT_EQ(expected, alert->description);
+ EXPECT_EQ(expected_level, alert->level);
+ expected = kTlsAlertCloseNotify;
+ expected_level = kTlsAlertWarning;
+}
+
void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
ASSERT_EQ(0, error_code_);
WAIT_(error_code_ != 0, delay);
@@ -589,7 +681,7 @@ void TlsAgent::WaitForErrorCode(int32_t expected, uint32_t delay) const {
void TlsAgent::CheckPreliminaryInfo() {
SSLPreliminaryChannelInfo info;
EXPECT_EQ(SECSuccess,
- SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info)));
+ SSL_GetPreliminaryChannelInfo(ssl_fd(), &info, sizeof(info)));
EXPECT_EQ(sizeof(info), info.length);
EXPECT_TRUE(info.valuesSet & ssl_preinfo_version);
EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite);
@@ -619,7 +711,7 @@ void TlsAgent::CheckCallbacks() const {
// These callbacks shouldn't fire if we are resuming, except on TLS 1.3.
if (role_ == SERVER) {
- PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd_, ssl_server_name_xtn);
+ PRBool have_sni = SSLInt_ExtensionNegotiated(ssl_fd(), ssl_server_name_xtn);
EXPECT_EQ(((!expect_resumption_ && have_sni) ||
expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3),
sni_hook_called_);
@@ -639,11 +731,15 @@ void TlsAgent::ResetPreliminaryInfo() {
}
void TlsAgent::Connected() {
+ if (state_ == STATE_CONNECTED) {
+ return;
+ }
+
LOG("Handshake success");
CheckPreliminaryInfo();
CheckCallbacks();
- SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_));
+ SECStatus rv = SSL_GetChannelInfo(ssl_fd(), &info_, sizeof(info_));
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ(sizeof(info_), info_.length);
@@ -658,18 +754,19 @@ void TlsAgent::Connected() {
EXPECT_EQ(sizeof(csinfo_), csinfo_.length);
if (expected_version_ >= SSL_LIBRARY_VERSION_TLS_1_3) {
- PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd_);
+ PRInt32 cipherSuites = SSLInt_CountTls13CipherSpecs(ssl_fd());
// We use one ciphersuite in each direction, plus one that's kept around
// by DTLS for retransmission.
- PRInt32 expected = ((mode_ == DGRAM) && (role_ == CLIENT)) ? 3 : 2;
+ PRInt32 expected =
+ ((variant_ == ssl_variant_datagram) && (role_ == CLIENT)) ? 3 : 2;
EXPECT_EQ(expected, cipherSuites);
if (expected != cipherSuites) {
- SSLInt_PrintTls13CipherSpecs(ssl_fd_);
+ SSLInt_PrintTls13CipherSpecs(ssl_fd());
}
}
PRBool short_headers;
- rv = SSLInt_UsingShortHeaders(ssl_fd_, &short_headers);
+ rv = SSLInt_UsingShortHeaders(ssl_fd(), &short_headers);
EXPECT_EQ(SECSuccess, rv);
EXPECT_EQ((PRBool)expect_short_headers_, short_headers);
SetState(STATE_CONNECTED);
@@ -679,7 +776,7 @@ void TlsAgent::EnableExtendedMasterSecret() {
ASSERT_TRUE(EnsureTlsSetup());
SECStatus rv =
- SSL_OptionSet(ssl_fd_, SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
+ SSL_OptionSet(ssl_fd(), SSL_ENABLE_EXTENDED_MASTER_SECRET, PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
}
@@ -701,13 +798,13 @@ void TlsAgent::CheckEarlyDataAccepted(bool expected) {
}
void TlsAgent::CheckSecretsDestroyed() {
- ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd_));
+ ASSERT_EQ(PR_TRUE, SSLInt_CheckSecretsDestroyed(ssl_fd()));
}
void TlsAgent::DisableRollbackDetection() {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ROLLBACK_DETECTION, PR_FALSE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ROLLBACK_DETECTION, PR_FALSE);
ASSERT_EQ(SECSuccess, rv);
}
@@ -715,23 +812,22 @@ void TlsAgent::DisableRollbackDetection() {
void TlsAgent::EnableCompression() {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_DEFLATE, PR_TRUE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_DEFLATE, PR_TRUE);
ASSERT_EQ(SECSuccess, rv);
}
void TlsAgent::SetDowngradeCheckVersion(uint16_t version) {
ASSERT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd_, version);
+ SECStatus rv = SSL_SetDowngradeCheckVersion(ssl_fd(), version);
ASSERT_EQ(SECSuccess, rv);
}
void TlsAgent::Handshake() {
LOGV("Handshake");
- SECStatus rv = SSL_ForceHandshake(ssl_fd_);
+ SECStatus rv = SSL_ForceHandshake(ssl_fd());
if (rv == SECSuccess) {
Connected();
-
Poller::Instance()->Wait(READABLE_EVENT, adapter_, this,
&TlsAgent::ReadableCallback);
return;
@@ -740,14 +836,14 @@ void TlsAgent::Handshake() {
int32_t err = PR_GetError();
if (err == PR_WOULD_BLOCK_ERROR) {
LOGV("Would have blocked");
- if (mode_ == DGRAM) {
+ if (variant_ == ssl_variant_datagram) {
if (timer_handle_) {
timer_handle_->Cancel();
timer_handle_ = nullptr;
}
PRIntervalTime timeout;
- rv = DTLS_GetHandshakeTimeout(ssl_fd_, &timeout);
+ rv = DTLS_GetHandshakeTimeout(ssl_fd(), &timeout);
if (rv == SECSuccess) {
Poller::Instance()->SetTimer(
timeout + 1, this, &TlsAgent::ReadableCallback, &timer_handle_);
@@ -773,13 +869,18 @@ void TlsAgent::PrepareForRenegotiate() {
void TlsAgent::StartRenegotiate() {
PrepareForRenegotiate();
- SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE);
+ SECStatus rv = SSL_ReHandshake(ssl_fd(), PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::SendDirect(const DataBuffer& buf) {
LOG("Send Direct " << buf);
- adapter_->peer()->PacketReceived(buf);
+ auto peer = adapter_->peer().lock();
+ if (peer) {
+ peer->PacketReceived(buf);
+ } else {
+ LOG("Send Direct peer absent");
+ }
}
static bool ErrorIsNonFatal(PRErrorCode code) {
@@ -806,7 +907,7 @@ void TlsAgent::SendData(size_t bytes, size_t blocksize) {
void TlsAgent::SendBuffer(const DataBuffer& buf) {
LOGV("Writing " << buf.len() << " bytes");
- int32_t rv = PR_Write(ssl_fd_, buf.data(), buf.len());
+ int32_t rv = PR_Write(ssl_fd(), buf.data(), buf.len());
if (expect_readwrite_error_) {
EXPECT_GT(0, rv);
EXPECT_NE(PR_WOULD_BLOCK_ERROR, error_code_);
@@ -820,7 +921,7 @@ void TlsAgent::SendBuffer(const DataBuffer& buf) {
void TlsAgent::ReadBytes() {
uint8_t block[1024];
- int32_t rv = PR_Read(ssl_fd_, block, sizeof(block));
+ int32_t rv = PR_Read(ssl_fd(), block, sizeof(block));
LOGV("ReadBytes " << rv);
int32_t err;
@@ -853,18 +954,19 @@ void TlsAgent::ResetSentBytes() { send_ctr_ = 0; }
void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) {
EXPECT_TRUE(EnsureTlsSetup());
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE,
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_NO_CACHE,
mode & RESUME_SESSIONID ? PR_FALSE : PR_TRUE);
EXPECT_EQ(SECSuccess, rv);
- rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS,
+ rv = SSL_OptionSet(ssl_fd(), SSL_ENABLE_SESSION_TICKETS,
mode & RESUME_TICKET ? PR_TRUE : PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
void TlsAgent::DisableECDHEServerKeyReuse() {
+ ASSERT_TRUE(EnsureTlsSetup());
ASSERT_EQ(TlsAgent::SERVER, role_);
- SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
+ SECStatus rv = SSL_OptionSet(ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE);
EXPECT_EQ(SECSuccess, rv);
}
@@ -877,29 +979,25 @@ void TlsAgentTestBase::SetUp() {
}
void TlsAgentTestBase::TearDown() {
- delete agent_;
+ agent_ = nullptr;
SSL_ClearSessionCache();
SSL_ShutdownServerSessionIDCache();
}
void TlsAgentTestBase::Reset(const std::string& server_name) {
- delete agent_;
- Init(server_name);
-}
-
-void TlsAgentTestBase::Init(const std::string& server_name) {
- agent_ =
+ agent_.reset(
new TlsAgent(role_ == TlsAgent::CLIENT ? TlsAgent::kClient : server_name,
- role_, mode_);
- agent_->Init();
- fd_ = DummyPrSocket::CreateFD(agent_->role_str(), mode_);
- agent_->adapter()->SetPeer(DummyPrSocket::GetAdapter(fd_));
+ role_, variant_));
+ if (version_) {
+ agent_->SetVersionRange(version_, version_);
+ }
+ agent_->adapter()->SetPeer(sink_adapter_);
agent_->StartConnect();
}
void TlsAgentTestBase::EnsureInit() {
if (!agent_) {
- Init();
+ Reset();
}
const std::vector<SSLNamedGroup> groups = {
ssl_grp_ec_curve25519, ssl_grp_ec_secp256r1, ssl_grp_ec_secp384r1,
@@ -907,6 +1005,11 @@ void TlsAgentTestBase::EnsureInit() {
agent_->ConfigNamedGroups(groups);
}
+void TlsAgentTestBase::ExpectAlert(uint8_t alert) {
+ EnsureInit();
+ agent_->ExpectSendAlert(alert);
+}
+
void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
TlsAgent::State expected_state,
int32_t error_code) {
@@ -922,14 +1025,16 @@ void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer,
}
}
-void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
- const uint8_t* buf, size_t len,
- DataBuffer* out, uint64_t seq_num) {
+void TlsAgentTestBase::MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf,
+ size_t len, DataBuffer* out,
+ uint64_t seq_num) {
size_t index = 0;
index = out->Write(index, type, 1);
- index = out->Write(
- index, mode == STREAM ? version : TlsVersionToDtlsVersion(version), 2);
- if (mode == DGRAM) {
+ if (variant == ssl_variant_stream) {
+ index = out->Write(index, version, 2);
+ } else {
+ index = out->Write(index, TlsVersionToDtlsVersion(version), 2);
index = out->Write(index, seq_num >> 32, 4);
index = out->Write(index, seq_num & PR_UINT32_MAX, 4);
}
@@ -940,7 +1045,7 @@ void TlsAgentTestBase::MakeRecord(Mode mode, uint8_t type, uint16_t version,
void TlsAgentTestBase::MakeRecord(uint8_t type, uint16_t version,
const uint8_t* buf, size_t len,
DataBuffer* out, uint64_t seq_num) const {
- MakeRecord(mode_, type, version, buf, len, out, seq_num);
+ MakeRecord(variant_, type, version, buf, len, out, seq_num);
}
void TlsAgentTestBase::MakeHandshakeMessage(uint8_t hs_type,
@@ -959,7 +1064,7 @@ void TlsAgentTestBase::MakeHandshakeMessageFragment(
if (!fragment_length) fragment_length = hs_len;
index = out->Write(index, hs_type, 1); // Handshake record type.
index = out->Write(index, hs_len, 3); // Handshake length
- if (mode_ == DGRAM) {
+ if (variant_ == ssl_variant_datagram) {
index = out->Write(index, seq_num, 2);
index = out->Write(index, fragment_offset, 3);
index = out->Write(index, fragment_length, 3);
diff --git a/nss/gtests/ssl_gtest/tls_agent.h b/nss/gtests/ssl_gtest/tls_agent.h
index 78923c9..32f6175 100644
--- a/nss/gtests/ssl_gtest/tls_agent.h
+++ b/nss/gtests/ssl_gtest/tls_agent.h
@@ -14,9 +14,11 @@
#include <iostream>
#include "test_io.h"
+#include "tls_filter.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
+#include "scoped_ptrs.h"
extern bool g_ssl_gtest_verbose;
@@ -42,6 +44,8 @@ const extern std::vector<SSLNamedGroup> kECDHEGroups;
const extern std::vector<SSLNamedGroup> kFFDHEGroups;
const extern std::vector<SSLNamedGroup> kFasterDHEGroups;
+// These functions are called from callbacks. They use bare pointers because
+// TlsAgent sets up the callback and it doesn't know who owns it.
typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>
AuthCertificateCallbackFunction;
@@ -70,25 +74,24 @@ class TlsAgent : public PollTarget {
static const std::string kServerEcdhRsa;
static const std::string kServerDsa;
- TlsAgent(const std::string& name, Role role, Mode mode);
+ TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
virtual ~TlsAgent();
- bool Init() {
- pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_);
- if (!pr_fd_) return false;
-
- adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
- if (!adapter_) return false;
-
- return true;
+ void SetPeer(std::shared_ptr<TlsAgent>& peer) {
+ adapter_->SetPeer(peer->adapter_);
}
- void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
+ void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
+ filter->SetAgent(this);
+ adapter_->SetPacketFilter(filter);
+ }
- void SetPacketFilter(PacketFilter* filter) {
+ void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
adapter_->SetPacketFilter(filter);
}
+ void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }
+
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
size_t kea_size = 0) const;
@@ -107,6 +110,9 @@ class TlsAgent : public PollTarget {
void PrepareForRenegotiate();
// Prepares for renegotiation, then actually triggers it.
void StartRenegotiate();
+ static bool LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv);
bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
const SSLExtraServerCertData* serverCertData = nullptr);
bool ConfigServerCertWithChain(const std::string& name);
@@ -114,13 +120,12 @@ class TlsAgent : public PollTarget {
void SetupClientAuth();
void RequestClientAuth(bool requireAuth);
- bool GetClientAuthCredentials(CERTCertificate** cert,
- SECKEYPrivateKey** priv) const;
void ConfigureSessionCache(SessionResumptionMode mode);
void SetSessionTicketsEnabled(bool en);
void SetSessionCacheEnabled(bool en);
void Set0RttEnabled(bool en);
+ void SetFallbackSCSVEnabled(bool en);
void SetShortHeadersEnabled();
void SetVersionRange(uint16_t minver, uint16_t maxver);
void GetVersionRange(uint16_t* minver, uint16_t* maxver);
@@ -132,6 +137,7 @@ class TlsAgent : public PollTarget {
void EnableFalseStart();
void ExpectResumption();
void ExpectShortHeaders();
+ void SkipVersionChecks();
void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
void EnableAlpn(const uint8_t* val, size_t len);
void CheckAlpn(SSLNextProtoState expected_state,
@@ -157,6 +163,7 @@ class TlsAgent : public PollTarget {
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
void DisableECDHEServerKeyReuse();
bool GetPeerChainLength(size_t* count);
+ void CheckCipherSuite(uint16_t cipher_suite);
const std::string& name() const { return name_; }
@@ -166,15 +173,15 @@ class TlsAgent : public PollTarget {
State state() const { return state_; }
const CERTCertificate* peer_cert() const {
- return SSL_PeerCertificate(ssl_fd_);
+ return SSL_PeerCertificate(ssl_fd_.get());
}
const char* state_str() const { return state_str(state()); }
static const char* state_str(State state) { return states[state]; }
- PRFileDesc* ssl_fd() { return ssl_fd_; }
- DummyPrSocket* adapter() { return adapter_; }
+ PRFileDesc* ssl_fd() const { return ssl_fd_.get(); }
+ std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }
bool is_compressed() const {
return info_.compressionMethod != ssl_compression_null;
@@ -239,6 +246,9 @@ class TlsAgent : public PollTarget {
sni_callback_ = sni_callback;
}
+ void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
+ void ExpectSendAlert(uint8_t alert, uint8_t level = 0);
+
private:
const static char* states[];
@@ -320,6 +330,18 @@ class TlsAgent : public PollTarget {
return SECSuccess;
}
+ void CheckAlert(bool sent, const SSLAlert* alert);
+
+ static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
+ const SSLAlert* alert) {
+ reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert);
+ }
+
+ static void AlertSentCallback(const PRFileDesc* fd, void* arg,
+ const SSLAlert* alert) {
+ reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert);
+ }
+
static void HandshakeCallback(PRFileDesc* fd, void* arg) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
agent->handshake_callback_called_ = true;
@@ -336,14 +358,13 @@ class TlsAgent : public PollTarget {
void Connected();
const std::string name_;
- Mode mode_;
- uint16_t server_key_bits_;
- PRFileDesc* pr_fd_;
- DummyPrSocket* adapter_;
- PRFileDesc* ssl_fd_;
+ SSLProtocolVariant variant_;
Role role_;
+ uint16_t server_key_bits_;
+ std::shared_ptr<DummyPrSocket> adapter_;
+ ScopedPRFileDesc ssl_fd_;
State state_;
- Poller::Timer* timer_handle_;
+ std::shared_ptr<Poller::Timer> timer_handle_;
bool falsestart_enabled_;
uint16_t expected_version_;
uint16_t expected_cipher_suite_;
@@ -352,6 +373,10 @@ class TlsAgent : public PollTarget {
bool can_falsestart_hook_called_;
bool sni_hook_called_;
bool auth_certificate_hook_called_;
+ uint8_t expected_received_alert_;
+ uint8_t expected_received_alert_level_;
+ uint8_t expected_sent_alert_;
+ uint8_t expected_sent_alert_level_;
bool handshake_callback_called_;
SSLChannelInfo info_;
SSLCipherSuiteInfo csinfo_;
@@ -364,6 +389,7 @@ class TlsAgent : public PollTarget {
AuthCertificateCallbackFunction auth_certificate_callback_;
SniCallbackFunction sni_callback_;
bool expect_short_headers_;
+ bool skip_version_checks_;
};
inline std::ostream& operator<<(std::ostream& stream,
@@ -375,20 +401,23 @@ class TlsAgentTestBase : public ::testing::Test {
public:
static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
- TlsAgentTestBase(TlsAgent::Role role, Mode mode)
- : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {}
- ~TlsAgentTestBase() {
- if (fd_) {
- PR_Close(fd_);
- }
- }
+ TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
+ uint16_t version = 0)
+ : agent_(nullptr),
+ role_(role),
+ variant_(variant),
+ version_(version),
+ sink_adapter_(new DummyPrSocket("sink", variant)) {}
+ virtual ~TlsAgentTestBase() {}
void SetUp();
void TearDown();
- static void MakeRecord(Mode mode, uint8_t type, uint16_t version,
- const uint8_t* buf, size_t len, DataBuffer* out,
- uint64_t seq_num = 0);
+ void ExpectAlert(uint8_t alert);
+
+ static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num = 0);
void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
@@ -403,10 +432,6 @@ class TlsAgentTestBase : public ::testing::Test {
return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
}
- static inline Mode ToMode(const std::string& str) {
- return str == "TLS" ? STREAM : DGRAM;
- }
-
void Init(const std::string& server_name = TlsAgent::kServerRsa);
void Reset(const std::string& server_name = TlsAgent::kServerRsa);
@@ -415,43 +440,57 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
- TlsAgent* agent_;
- PRFileDesc* fd_;
+ std::unique_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
- Mode mode_;
+ SSLProtocolVariant variant_;
+ uint16_t version_;
+ // This adapter is here just to accept packets from this agent.
+ std::shared_ptr<DummyPrSocket> sink_adapter_;
};
-class TlsAgentTest : public TlsAgentTestBase,
- public ::testing::WithParamInterface<
- std::tuple<std::string, std::string>> {
+class TlsAgentTest
+ : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
public:
TlsAgentTest()
: TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
- ToMode(std::get<1>(GetParam()))) {}
+ std::get<1>(GetParam()), std::get<2>(GetParam())) {}
};
class TlsAgentTestClient : public TlsAgentTestBase,
- public ::testing::WithParamInterface<std::string> {
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsAgentTestClient()
- : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {}
+ : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()),
+ std::get<1>(GetParam())) {}
};
+class TlsAgentTestClient13 : public TlsAgentTestClient {};
+
class TlsAgentStreamTestClient : public TlsAgentTestBase {
public:
- TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {}
+ TlsAgentStreamTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {}
};
class TlsAgentStreamTestServer : public TlsAgentTestBase {
public:
- TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {}
+ TlsAgentStreamTestServer()
+ : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {}
};
class TlsAgentDgramTestClient : public TlsAgentTestBase {
public:
- TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {}
+ TlsAgentDgramTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {}
};
+inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
+ return vr1.min == vr2.min && vr1.max == vr2.max;
+}
+
} // namespace nss_test
#endif
diff --git a/nss/gtests/ssl_gtest/tls_connect.cc b/nss/gtests/ssl_gtest/tls_connect.cc
index d025499..861d162 100644
--- a/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/nss/gtests/ssl_gtest/tls_connect.cc
@@ -13,23 +13,27 @@ extern "C" {
#include "databuffer.h"
#include "gtest_utils.h"
+#include "scoped_ptrs.h"
#include "sslproto.h"
extern std::string g_working_dir_path;
namespace nss_test {
-static const std::string kTlsModesStreamArr[] = {"TLS"};
-::testing::internal::ParamGenerator<std::string>
- TlsConnectTestBase::kTlsModesStream =
- ::testing::ValuesIn(kTlsModesStreamArr);
-static const std::string kTlsModesDatagramArr[] = {"DTLS"};
-::testing::internal::ParamGenerator<std::string>
- TlsConnectTestBase::kTlsModesDatagram =
- ::testing::ValuesIn(kTlsModesDatagramArr);
-static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"};
-::testing::internal::ParamGenerator<std::string>
- TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr);
+static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsStream =
+ ::testing::ValuesIn(kTlsVariantsStreamArr);
+static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
+ ssl_variant_datagram};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsDatagram =
+ ::testing::ValuesIn(kTlsVariantsDatagramArr);
+static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
+ ssl_variant_datagram};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsAll =
+ ::testing::ValuesIn(kTlsVariantsAllArr);
static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
@@ -99,30 +103,29 @@ std::string VersionString(uint16_t version) {
}
}
-TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version)
- : mode_(mode),
- client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)),
- server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)),
+TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
+ uint16_t version)
+ : variant_(variant),
+ client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)),
+ server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)),
client_model_(nullptr),
server_model_(nullptr),
version_(version),
expected_resumption_mode_(RESUME_NONE),
session_ids_(),
expect_extended_master_secret_(false),
- expect_early_data_accepted_(false) {
+ expect_early_data_accepted_(false),
+ skip_version_checks_(false) {
std::string v;
- if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
+ if (variant_ == ssl_variant_datagram &&
+ version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
v = "1.0";
} else {
v = VersionString(version_);
}
- std::cerr << "Version: " << mode_ << " " << v << std::endl;
+ std::cerr << "Version: " << variant_ << " " << v << std::endl;
}
-TlsConnectTestBase::TlsConnectTestBase(const std::string& mode,
- uint16_t version)
- : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {}
-
TlsConnectTestBase::~TlsConnectTestBase() {}
// Check the group of each of the supported groups
@@ -173,18 +176,15 @@ void TlsConnectTestBase::ClearServerCache() {
void TlsConnectTestBase::SetUp() {
SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
SSLInt_ClearSessionTicketKey();
+ SSLInt_SetTicketLifetime(30);
+ SSLInt_SetMaxEarlyDataSize(1024);
ClearStats();
Init();
}
void TlsConnectTestBase::TearDown() {
- delete client_;
- delete server_;
- if (client_model_) {
- ASSERT_NE(server_model_, nullptr);
- delete client_model_;
- delete server_model_;
- }
+ client_ = nullptr;
+ server_ = nullptr;
SSL_ClearSessionCache();
SSLInt_ClearSessionTicketKey();
@@ -192,9 +192,6 @@ void TlsConnectTestBase::TearDown() {
}
void TlsConnectTestBase::Init() {
- EXPECT_TRUE(client_->Init());
- EXPECT_TRUE(server_->Init());
-
client_->SetPeer(server_);
server_->SetPeer(client_);
@@ -212,11 +209,12 @@ void TlsConnectTestBase::Reset() {
void TlsConnectTestBase::Reset(const std::string& server_name,
const std::string& client_name) {
- delete client_;
- delete server_;
-
- client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_);
- server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_);
+ client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
+ server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
+ if (skip_version_checks_) {
+ client_->SkipVersionChecks();
+ server_->SkipVersionChecks();
+ }
Init();
}
@@ -276,10 +274,12 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
}
void TlsConnectTestBase::CheckConnected() {
- // Check the version is as expected
EXPECT_EQ(client_->version(), server_->version());
- EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
- client_->version());
+ if (!skip_version_checks_) {
+ // Check the version is as expected
+ EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
+ client_->version());
+ }
EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
@@ -345,6 +345,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
scheme = ssl_sig_none;
break;
case ssl_auth_rsa_sign:
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) {
+ scheme = ssl_sig_rsa_pss_sha256;
+ } else {
+ scheme = ssl_sig_rsa_pkcs1_sha256;
+ }
+ break;
+ case ssl_auth_rsa_pss:
scheme = ssl_sig_rsa_pss_sha256;
break;
case ssl_auth_ecdsa:
@@ -373,7 +380,36 @@ void TlsConnectTestBase::ConnectExpectFail() {
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
}
+void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender,
+ uint8_t alert) {
+ EnsureTlsSetup();
+ auto receiver = (sender == client_) ? server_ : client_;
+ sender->ExpectSendAlert(alert);
+ receiver->ExpectReceiveAlert(alert);
+}
+
+void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
+ uint8_t alert) {
+ ExpectAlert(sender, alert);
+ ConnectExpectFail();
+}
+
+void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
+ server_->StartConnect();
+ client_->StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+ client_->Handshake();
+ server_->Handshake();
+
+ auto failing_agent = server_;
+ if (failing_side == TlsAgent::CLIENT) {
+ failing_agent = client_;
+ }
+ ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000);
+}
+
void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
+ version_ = version;
client_->SetVersionRange(version, version);
server_->SetVersionRange(version, version);
}
@@ -424,10 +460,16 @@ void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
client_->ConfigureSessionCache(client);
server_->ConfigureSessionCache(server);
if ((server & RESUME_TICKET) != 0) {
- // This is an abomination. NSS encrypts session tickets with the server's
- // RSA public key. That means we need the server to have an RSA certificate
- // even if it won't be used for the connection.
- server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt);
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
+ &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
}
}
@@ -472,13 +514,15 @@ void TlsConnectTestBase::EnsureModelSockets() {
// Make sure models agents are available.
if (!client_model_) {
ASSERT_EQ(server_model_, nullptr);
- client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_);
- server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_);
+ client_model_.reset(
+ new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_));
+ server_model_.reset(
+ new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_));
+ if (skip_version_checks_) {
+ client_model_->SkipVersionChecks();
+ server_model_->SkipVersionChecks();
+ }
}
-
- // Initialise agents.
- ASSERT_TRUE(client_model_->Init());
- ASSERT_TRUE(server_model_->Init());
}
void TlsConnectTestBase::CheckAlpn(const std::string& val) {
@@ -540,6 +584,10 @@ void TlsConnectTestBase::ZeroRttSendReceive(
const char* k0RttData = "ABCDEF";
const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ if (expect_writable && expect_readable) {
+ ExpectAlert(client_, kTlsAlertEndOfEarlyData);
+ }
+
client_->Handshake(); // Send ClientHello.
if (post_clienthello_check) {
if (!post_clienthello_check()) return;
@@ -599,6 +647,12 @@ void TlsConnectTestBase::DisableECDHEServerKeyReuse() {
server_->DisableECDHEServerKeyReuse();
}
+void TlsConnectTestBase::SkipVersionChecks() {
+ skip_version_checks_ = true;
+ client_->SkipVersionChecks();
+ server_->SkipVersionChecks();
+}
+
TlsConnectGeneric::TlsConnectGeneric()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
@@ -616,16 +670,17 @@ TlsConnectTls13::TlsConnectTls13()
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
EnsureTlsSetup();
- groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn);
- shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
- shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true);
- std::vector<PacketFilter*> captures;
- captures.push_back(groups_capture_);
- captures.push_back(shares_capture_);
- captures.push_back(shares_capture2_);
- client_->SetPacketFilter(new ChainedPacketFilter(captures));
- capture_hrr_ =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ groups_capture_ =
+ std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ shares_capture_ =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ shares_capture2_ =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {
+ groups_capture_, shares_capture_, shares_capture2_};
+ client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeHelloRetryRequest);
server_->SetPacketFilter(capture_hrr_);
}
diff --git a/nss/gtests/ssl_gtest/tls_connect.h b/nss/gtests/ssl_gtest/tls_connect.h
index aa4a32d..73e8dc8 100644
--- a/nss/gtests/ssl_gtest/tls_connect.h
+++ b/nss/gtests/ssl_gtest/tls_connect.h
@@ -25,9 +25,12 @@ extern std::string VersionString(uint16_t version);
// A generic TLS connection test base.
class TlsConnectTestBase : public ::testing::Test {
public:
- static ::testing::internal::ParamGenerator<std::string> kTlsModesStream;
- static ::testing::internal::ParamGenerator<std::string> kTlsModesDatagram;
- static ::testing::internal::ParamGenerator<std::string> kTlsModesAll;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsStream;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsDatagram;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsAll;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
@@ -39,8 +42,7 @@ class TlsConnectTestBase : public ::testing::Test {
static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
- TlsConnectTestBase(Mode mode, uint16_t version);
- TlsConnectTestBase(const std::string& mode, uint16_t version);
+ TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
virtual ~TlsConnectTestBase();
void SetUp();
@@ -68,6 +70,9 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckConnected();
// Connect and expect it to fail.
void ConnectExpectFail();
+ void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
void ConnectWithCipherSuite(uint16_t cipher_suite);
// Check that the keys used in the handshake match expectations.
void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
@@ -108,13 +113,14 @@ class TlsConnectTestBase : public ::testing::Test {
void ExpectExtendedMasterSecret(bool expected);
void ExpectEarlyDataAccepted(bool expected);
void DisableECDHEServerKeyReuse();
+ void SkipVersionChecks();
protected:
- Mode mode_;
- TlsAgent* client_;
- TlsAgent* server_;
- TlsAgent* client_model_;
- TlsAgent* server_model_;
+ SSLProtocolVariant variant_;
+ std::shared_ptr<TlsAgent> client_;
+ std::shared_ptr<TlsAgent> server_;
+ std::unique_ptr<TlsAgent> client_model_;
+ std::unique_ptr<TlsAgent> server_model_;
uint16_t version_;
SessionResumptionMode expected_resumption_mode_;
std::vector<std::vector<uint8_t>> session_ids_;
@@ -126,16 +132,13 @@ class TlsConnectTestBase : public ::testing::Test {
const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
private:
- static inline Mode ToMode(const std::string& str) {
- return str == "TLS" ? STREAM : DGRAM;
- }
-
void CheckResumption(SessionResumptionMode expected);
void CheckExtendedMasterSecret();
void CheckEarlyDataAccepted();
bool expect_extended_master_secret_;
bool expect_early_data_accepted_;
+ bool skip_version_checks_;
// Track groups and make sure that there are no duplicates.
class DuplicateGroupChecker {
@@ -154,20 +157,20 @@ class TlsConnectTestBase : public ::testing::Test {
// A non-parametrized TLS test base.
class TlsConnectTest : public TlsConnectTestBase {
public:
- TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {}
+ TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
};
// A non-parametrized DTLS-only test base.
class DtlsConnectTest : public TlsConnectTestBase {
public:
- DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {}
+ DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {}
};
// A TLS-only test base.
class TlsConnectStream : public TlsConnectTestBase,
public ::testing::WithParamInterface<uint16_t> {
public:
- TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {}
+ TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {}
};
// A TLS-only test base for tests before 1.3
@@ -177,30 +180,30 @@ class TlsConnectStreamPre13 : public TlsConnectStream {};
class TlsConnectDatagram : public TlsConnectTestBase,
public ::testing::WithParamInterface<uint16_t> {
public:
- TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {}
+ TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {}
};
-// A generic test class that can be either STREAM or DGRAM and a single version
-// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this
-// should use TEST_P().
-class TlsConnectGeneric
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+// A generic test class that can be either stream or datagram and a single
+// version of TLS. This is configured in ssl_loopback_unittest.cc.
+class TlsConnectGeneric : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsConnectGeneric();
};
// A Pre TLS 1.2 generic test.
-class TlsConnectPre12
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsConnectPre12 : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsConnectPre12();
};
// A TLS 1.2 only generic test.
-class TlsConnectTls12 : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::string> {
+class TlsConnectTls12
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
public:
TlsConnectTls12();
};
@@ -209,20 +212,21 @@ class TlsConnectTls12 : public TlsConnectTestBase,
class TlsConnectStreamTls12 : public TlsConnectTestBase {
public:
TlsConnectStreamTls12()
- : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {}
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {}
};
// A TLS 1.2+ generic test.
-class TlsConnectTls12Plus
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsConnectTls12Plus : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsConnectTls12Plus();
};
// A TLS 1.3 only generic test.
-class TlsConnectTls13 : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::string> {
+class TlsConnectTls13
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
public:
TlsConnectTls13();
};
@@ -231,13 +235,13 @@ class TlsConnectTls13 : public TlsConnectTestBase,
class TlsConnectStreamTls13 : public TlsConnectTestBase {
public:
TlsConnectStreamTls13()
- : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
class TlsConnectDatagram13 : public TlsConnectTestBase {
public:
TlsConnectDatagram13()
- : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+ : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
// A variant that is used only with Pre13.
@@ -245,10 +249,10 @@ class TlsConnectGenericPre13 : public TlsConnectGeneric {};
class TlsKeyExchangeTest : public TlsConnectGeneric {
protected:
- TlsExtensionCapture* groups_capture_;
- TlsExtensionCapture* shares_capture_;
- TlsExtensionCapture* shares_capture2_;
- TlsInspectorRecordHandshakeMessage* capture_hrr_;
+ std::shared_ptr<TlsExtensionCapture> groups_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture2_;
+ std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;
void EnsureKeyShareSetup();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
diff --git a/nss/gtests/ssl_gtest/tls_filter.cc b/nss/gtests/ssl_gtest/tls_filter.cc
index 4f7d195..76d9aaa 100644
--- a/nss/gtests/ssl_gtest/tls_filter.cc
+++ b/nss/gtests/ssl_gtest/tls_filter.cc
@@ -15,9 +15,62 @@ extern "C" {
#include <iostream>
#include "gtest_utils.h"
#include "tls_agent.h"
+#include "tls_filter.h"
+#include "tls_protect.h"
namespace nss_test {
+void TlsVersioned::WriteStream(std::ostream& stream) const {
+ stream << (is_dtls() ? "DTLS " : "TLS ");
+ switch (version()) {
+ case 0:
+ stream << "(no version)";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_0:
+ stream << "1.0";
+ break;
+ case SSL_LIBRARY_VERSION_DTLS_1_0_WIRE:
+ case SSL_LIBRARY_VERSION_TLS_1_1:
+ stream << (is_dtls() ? "1.0" : "1.1");
+ break;
+ case SSL_LIBRARY_VERSION_DTLS_1_2_WIRE:
+ case SSL_LIBRARY_VERSION_TLS_1_2:
+ stream << "1.2";
+ break;
+ case SSL_LIBRARY_VERSION_TLS_1_3:
+ stream << "1.3";
+ break;
+ default:
+ stream << "Invalid version: " << version();
+ break;
+ }
+}
+
+void TlsRecordFilter::EnableDecryption() {
+ SSLInt_SetCipherSpecChangeFunc(agent()->ssl_fd(), CipherSpecChanged,
+ (void*)this);
+}
+
+void TlsRecordFilter::CipherSpecChanged(void* arg, PRBool sending,
+ ssl3CipherSpec* newSpec) {
+ TlsRecordFilter* self = static_cast<TlsRecordFilter*>(arg);
+ PRBool isServer = self->agent()->role() == TlsAgent::SERVER;
+
+ if (g_ssl_gtest_verbose) {
+ std::cerr << "Cipher spec changed. Role="
+ << (isServer ? "server" : "client")
+ << " direction=" << (sending ? "send" : "receive") << std::endl;
+ }
+ if (!sending) return;
+
+ self->cipher_spec_.reset(new TlsCipherSpec());
+ bool ret =
+ self->cipher_spec_->Init(SSLInt_CipherSpecToAlgorithm(isServer, newSpec),
+ SSLInt_CipherSpecToKey(isServer, newSpec),
+ SSLInt_CipherSpecToIv(isServer, newSpec));
+ EXPECT_EQ(true, ret);
+}
+
PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
bool changed = false;
@@ -25,10 +78,13 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
output->Allocate(input.len());
TlsParser parser(input);
+
while (parser.remaining()) {
- RecordHeader header;
+ TlsRecordHeader header;
DataBuffer record;
+
if (!header.Parse(&parser, &record)) {
+ ADD_FAILURE() << "not a valid record";
return KEEP;
}
@@ -49,12 +105,21 @@ PacketFilter::Action TlsRecordFilter::Filter(const DataBuffer& input,
return KEEP;
}
-PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header,
- const DataBuffer& record,
- size_t* offset,
- DataBuffer* output) {
+PacketFilter::Action TlsRecordFilter::FilterRecord(
+ const TlsRecordHeader& header, const DataBuffer& record, size_t* offset,
+ DataBuffer* output) {
DataBuffer filtered;
- PacketFilter::Action action = FilterRecord(header, record, &filtered);
+ uint8_t inner_content_type;
+ DataBuffer plaintext;
+
+ if (!Unprotect(header, record, &inner_content_type, &plaintext)) {
+ return KEEP;
+ }
+
+ TlsRecordHeader real_header = {header.version(), inner_content_type,
+ header.sequence_number()};
+
+ PacketFilter::Action action = FilterRecord(real_header, plaintext, &filtered);
if (action == KEEP) {
return KEEP;
}
@@ -64,19 +129,21 @@ PacketFilter::Action TlsRecordFilter::FilterRecord(const RecordHeader& header,
return DROP;
}
- const DataBuffer* source = &record;
- if (action == CHANGE) {
- EXPECT_GT(0x10000U, filtered.len());
- std::cerr << "record old: " << record << std::endl;
- std::cerr << "record new: " << filtered << std::endl;
- source = &filtered;
- }
+ EXPECT_GT(0x10000U, filtered.len());
+ std::cerr << "record old: " << plaintext << std::endl;
+ std::cerr << "record new: " << filtered << std::endl;
- *offset = header.Write(output, *offset, *source);
+ DataBuffer ciphertext;
+ bool rv = Protect(header, inner_content_type, filtered, &ciphertext);
+ EXPECT_TRUE(rv);
+ if (!rv) {
+ return KEEP;
+ }
+ *offset = header.Write(output, *offset, ciphertext);
return CHANGE;
}
-bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
+bool TlsRecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
if (!parser->Read(&content_type_)) {
return false;
}
@@ -102,8 +169,8 @@ bool TlsRecordFilter::RecordHeader::Parse(TlsParser* parser, DataBuffer* body) {
return parser->ReadVariable(body, 2);
}
-size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset,
- const DataBuffer& body) const {
+size_t TlsRecordHeader::Write(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body) const {
offset = buffer->Write(offset, content_type_, 1);
offset = buffer->Write(offset, version_, 2);
if (is_dtls()) {
@@ -116,8 +183,48 @@ size_t TlsRecordFilter::RecordHeader::Write(DataBuffer* buffer, size_t offset,
return offset;
}
+bool TlsRecordFilter::Unprotect(const TlsRecordHeader& header,
+ const DataBuffer& ciphertext,
+ uint8_t* inner_content_type,
+ DataBuffer* plaintext) {
+ if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) {
+ *inner_content_type = header.content_type();
+ *plaintext = ciphertext;
+ return true;
+ }
+
+ if (!cipher_spec_->Unprotect(header, ciphertext, plaintext)) return false;
+
+ size_t len = plaintext->len();
+ while (len > 0 && !plaintext->data()[len - 1]) {
+ --len;
+ }
+ if (!len) {
+ // Bogus padding.
+ return false;
+ }
+
+ *inner_content_type = plaintext->data()[len - 1];
+ plaintext->Truncate(len - 1);
+
+ return true;
+}
+
+bool TlsRecordFilter::Protect(const TlsRecordHeader& header,
+ uint8_t inner_content_type,
+ const DataBuffer& plaintext,
+ DataBuffer* ciphertext) {
+ if (!cipher_spec_ || header.content_type() != kTlsApplicationDataType) {
+ *ciphertext = plaintext;
+ return true;
+ }
+ DataBuffer padded = plaintext;
+ padded.Write(padded.len(), inner_content_type, 1);
+ return cipher_spec_->Protect(header, padded, ciphertext);
+}
+
PacketFilter::Action TlsHandshakeFilter::FilterRecord(
- const RecordHeader& record_header, const DataBuffer& input,
+ const TlsRecordHeader& record_header, const DataBuffer& input,
DataBuffer* output) {
// Check that the first byte is as requested.
if (record_header.content_type() != kTlsHandshakeType) {
@@ -159,9 +266,8 @@ PacketFilter::Action TlsHandshakeFilter::FilterRecord(
return changed ? (offset ? CHANGE : DROP) : KEEP;
}
-bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
- const RecordHeader& header,
- uint32_t* length) {
+bool TlsHandshakeFilter::HandshakeHeader::ReadLength(
+ TlsParser* parser, const TlsRecordHeader& header, uint32_t* length) {
if (!parser->Read(length, 3)) {
return false; // malformed
}
@@ -192,7 +298,7 @@ bool TlsHandshakeFilter::HandshakeHeader::ReadLength(TlsParser* parser,
}
bool TlsHandshakeFilter::HandshakeHeader::Parse(
- TlsParser* parser, const RecordHeader& record_header, DataBuffer* body) {
+ TlsParser* parser, const TlsRecordHeader& record_header, DataBuffer* body) {
version_ = record_header.version();
if (!parser->Read(&handshake_type_)) {
return false; // malformed
@@ -205,15 +311,28 @@ bool TlsHandshakeFilter::HandshakeHeader::Parse(
return parser->Read(body, length);
}
-size_t TlsHandshakeFilter::HandshakeHeader::Write(
- DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
+size_t TlsHandshakeFilter::HandshakeHeader::WriteFragment(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body,
+ size_t fragment_offset, size_t fragment_length) const {
+ EXPECT_TRUE(is_dtls());
+ EXPECT_GE(body.len(), fragment_offset + fragment_length);
offset = buffer->Write(offset, handshake_type(), 1);
offset = buffer->Write(offset, body.len(), 3);
+ offset = buffer->Write(offset, message_seq_, 2);
+ offset = buffer->Write(offset, fragment_offset, 3);
+ offset = buffer->Write(offset, fragment_length, 3);
+ offset =
+ buffer->Write(offset, body.data() + fragment_offset, fragment_length);
+ return offset;
+}
+
+size_t TlsHandshakeFilter::HandshakeHeader::Write(
+ DataBuffer* buffer, size_t offset, const DataBuffer& body) const {
if (is_dtls()) {
- offset = buffer->Write(offset, message_seq_, 2);
- offset = buffer->Write(offset, 0U, 3); // fragment_offset
- offset = buffer->Write(offset, body.len(), 3);
+ return WriteFragment(buffer, offset, body, 0U, body.len());
}
+ offset = buffer->Write(offset, handshake_type(), 1);
+ offset = buffer->Write(offset, body.len(), 3);
offset = buffer->Write(offset, body);
return offset;
}
@@ -244,42 +363,12 @@ PacketFilter::Action TlsInspectorReplaceHandshakeMessage::FilterHandshake(
}
PacketFilter::Action TlsConversationRecorder::FilterRecord(
- const RecordHeader& header, const DataBuffer& input, DataBuffer* output) {
+ const TlsRecordHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
buffer_.Append(input);
return KEEP;
}
-PacketFilter::Action TlsAlertRecorder::FilterRecord(const RecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output) {
- if (level_ == kTlsAlertFatal) { // already fatal
- return KEEP;
- }
- if (header.content_type() != kTlsAlertType) {
- return KEEP;
- }
-
- std::cerr << "Alert: " << input << std::endl;
-
- TlsParser parser(input);
- uint8_t lvl;
- if (!parser.Read(&lvl)) {
- return KEEP;
- }
- if (lvl == kTlsAlertWarning) { // not strong enough
- return KEEP;
- }
- level_ = lvl;
- (void)parser.Read(&description_);
- return KEEP;
-}
-
-ChainedPacketFilter::~ChainedPacketFilter() {
- for (auto it = filters_.begin(); it != filters_.end(); ++it) {
- delete *it;
- }
-}
-
PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
DataBuffer* output) {
DataBuffer in(input);
@@ -297,28 +386,7 @@ PacketFilter::Action ChainedPacketFilter::Filter(const DataBuffer& input,
return changed ? CHANGE : KEEP;
}
-PacketFilter::Action TlsExtensionFilter::FilterHandshake(
- const HandshakeHeader& header, const DataBuffer& input,
- DataBuffer* output) {
- if (header.handshake_type() == kTlsHandshakeClientHello) {
- TlsParser parser(input);
- if (!FindClientHelloExtensions(&parser, header)) {
- return KEEP;
- }
- return FilterExtensions(&parser, input, output);
- }
- if (header.handshake_type() == kTlsHandshakeServerHello) {
- TlsParser parser(input);
- if (!FindServerHelloExtensions(&parser)) {
- return KEEP;
- }
- return FilterExtensions(&parser, input, output);
- }
- return KEEP;
-}
-
-bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser,
- const Versioned& header) {
+bool FindClientHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
if (!parser->Skip(2 + 32)) { // version + random
return false;
}
@@ -337,7 +405,7 @@ bool TlsExtensionFilter::FindClientHelloExtensions(TlsParser* parser,
return true;
}
-bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) {
+bool FindServerHelloExtensions(TlsParser* parser, const TlsVersioned& header) {
uint32_t vtmp;
if (!parser->Read(&vtmp, 2)) {
return false;
@@ -362,6 +430,92 @@ bool TlsExtensionFilter::FindServerHelloExtensions(TlsParser* parser) {
return true;
}
+static bool FindHelloRetryExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ // TODO for -19 add cipher suite
+ if (!parser->Skip(2)) { // version
+ return false;
+ }
+ return true;
+}
+
+bool FindEncryptedExtensions(TlsParser* parser, const TlsVersioned& header) {
+ return true;
+}
+
+static bool FindCertReqExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ // TODO remove the next two for -19
+ if (!parser->SkipVariable(2)) { // signature_algorithms
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // certificate_authorities
+ return false;
+ }
+ return true;
+}
+
+// Only look at the EE cert for this one.
+static bool FindCertificateExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->SkipVariable(1)) { // request context
+ return false;
+ }
+ if (!parser->Skip(3)) { // length of certificate list
+ return false;
+ }
+ if (!parser->SkipVariable(3)) { // ASN1Cert
+ return false;
+ }
+ return true;
+}
+
+static bool FindNewSessionTicketExtensions(TlsParser* parser,
+ const TlsVersioned& header) {
+ if (!parser->Skip(8)) { // lifetime, age add
+ return false;
+ }
+ if (!parser->SkipVariable(2)) { // ticket
+ return false;
+ }
+ return true;
+}
+
+static const std::map<uint16_t, TlsExtensionFinder> kExtensionFinders = {
+ {kTlsHandshakeClientHello, FindClientHelloExtensions},
+ {kTlsHandshakeServerHello, FindServerHelloExtensions},
+ {kTlsHandshakeHelloRetryRequest, FindHelloRetryExtensions},
+ {kTlsHandshakeEncryptedExtensions, FindEncryptedExtensions},
+ {kTlsHandshakeCertificateRequest, FindCertReqExtensions},
+ {kTlsHandshakeCertificate, FindCertificateExtensions},
+ {kTlsHandshakeNewSessionTicket, FindNewSessionTicketExtensions}};
+
+bool TlsExtensionFilter::FindExtensions(TlsParser* parser,
+ const HandshakeHeader& header) {
+ auto it = kExtensionFinders.find(header.handshake_type());
+ if (it == kExtensionFinders.end()) {
+ return false;
+ }
+ return (it->second)(parser, header);
+}
+
+PacketFilter::Action TlsExtensionFilter::FilterHandshake(
+ const HandshakeHeader& header, const DataBuffer& input,
+ DataBuffer* output) {
+ if (handshake_types_.count(header.handshake_type()) == 0) {
+ return KEEP;
+ }
+
+ TlsParser parser(input);
+ if (!FindExtensions(&parser, header)) {
+ return KEEP;
+ }
+ return FilterExtensions(&parser, input, output);
+}
+
PacketFilter::Action TlsExtensionFilter::FilterExtensions(
TlsParser* parser, const DataBuffer& input, DataBuffer* output) {
size_t length_offset = parser->consumed();
@@ -456,14 +610,14 @@ PacketFilter::Action TlsExtensionDropper::FilterExtension(
return KEEP;
}
-PacketFilter::Action AfterRecordN::FilterRecord(const RecordHeader& header,
+PacketFilter::Action AfterRecordN::FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) {
if (counter_++ == record_) {
DataBuffer buf;
header.Write(&buf, 0, body);
- src_->SendDirect(buf);
- dest_->Handshake();
+ src_.lock()->SendDirect(buf);
+ dest_.lock()->Handshake();
func_();
return DROP;
}
@@ -476,7 +630,7 @@ PacketFilter::Action TlsInspectorClientHelloVersionChanger::FilterHandshake(
DataBuffer* output) {
if (header.handshake_type() == kTlsHandshakeClientKeyExchange) {
EXPECT_EQ(SECSuccess,
- SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd()));
+ SSLInt_IncrementClientHandshakeVersion(server_.lock()->ssl_fd()));
}
return KEEP;
}
diff --git a/nss/gtests/ssl_gtest/tls_filter.h b/nss/gtests/ssl_gtest/tls_filter.h
index fa2e387..e4030e2 100644
--- a/nss/gtests/ssl_gtest/tls_filter.h
+++ b/nss/gtests/ssl_gtest/tls_filter.h
@@ -9,17 +9,67 @@
#include <functional>
#include <memory>
+#include <set>
#include <vector>
#include "test_io.h"
#include "tls_parser.h"
+#include "tls_protect.h"
+
+extern "C" {
+#include "libssl_internals.h"
+}
namespace nss_test {
+class TlsCipherSpec;
+class TlsAgent;
+
+class TlsVersioned {
+ public:
+ TlsVersioned() : version_(0) {}
+ explicit TlsVersioned(uint16_t version) : version_(version) {}
+
+ bool is_dtls() const { return IsDtls(version_); }
+ uint16_t version() const { return version_; }
+
+ void WriteStream(std::ostream& stream) const;
+
+ protected:
+ uint16_t version_;
+};
+
+class TlsRecordHeader : public TlsVersioned {
+ public:
+ TlsRecordHeader() : TlsVersioned(), content_type_(0), sequence_number_(0) {}
+ TlsRecordHeader(uint16_t version, uint8_t content_type,
+ uint64_t sequence_number)
+ : TlsVersioned(version),
+ content_type_(content_type),
+ sequence_number_(sequence_number) {}
+
+ uint8_t content_type() const { return content_type_; }
+ uint64_t sequence_number() const { return sequence_number_; }
+ size_t header_length() const { return is_dtls() ? 11 : 3; }
+
+ // Parse the header; return true if successful; body in an outparam if OK.
+ bool Parse(TlsParser* parser, DataBuffer* body);
+ // Write the header and body to a buffer at the given offset.
+ // Return the offset of the end of the write.
+ size_t Write(DataBuffer* buffer, size_t offset, const DataBuffer& body) const;
+
+ private:
+ uint8_t content_type_;
+ uint64_t sequence_number_;
+};
+
// Abstract filter that operates on entire (D)TLS records.
class TlsRecordFilter : public PacketFilter {
public:
- TlsRecordFilter() : count_(0) {}
+ TlsRecordFilter() : agent_(nullptr), count_(0), cipher_spec_() {}
+
+ void SetAgent(const TlsAgent* agent) { agent_ = agent; }
+ const TlsAgent* agent() const { return agent_; }
// External interface. Overrides PacketFilter.
PacketFilter::Action Filter(const DataBuffer& input, DataBuffer* output);
@@ -27,42 +77,14 @@ class TlsRecordFilter : public PacketFilter {
// Report how many packets were altered by the filter.
size_t filtered_packets() const { return count_; }
- class Versioned {
- public:
- Versioned() : version_(0) {}
- explicit Versioned(uint16_t version) : version_(version) {}
-
- bool is_dtls() const { return IsDtls(version_); }
- uint16_t version() const { return version_; }
-
- protected:
- uint16_t version_;
- };
-
- class RecordHeader : public Versioned {
- public:
- RecordHeader() : Versioned(), content_type_(0), sequence_number_(0) {}
- RecordHeader(uint16_t version, uint8_t content_type,
- uint64_t sequence_number)
- : Versioned(version),
- content_type_(content_type),
- sequence_number_(sequence_number) {}
-
- uint8_t content_type() const { return content_type_; }
- uint64_t sequence_number() const { return sequence_number_; }
- size_t header_length() const { return is_dtls() ? 11 : 3; }
-
- // Parse the header; return true if successful; body in an outparam if OK.
- bool Parse(TlsParser* parser, DataBuffer* body);
- // Write the header and body to a buffer at the given offset.
- // Return the offset of the end of the write.
- size_t Write(DataBuffer* buffer, size_t offset,
- const DataBuffer& body) const;
-
- private:
- uint8_t content_type_;
- uint64_t sequence_number_;
- };
+ // Enable decryption. This only works properly for TLS 1.3 and above.
+ // Enabling it for lower version tests will cause undefined
+ // behavior.
+ void EnableDecryption();
+ bool Unprotect(const TlsRecordHeader& header, const DataBuffer& cipherText,
+ uint8_t* inner_content_type, DataBuffer* plaintext);
+ bool Protect(const TlsRecordHeader& header, uint8_t inner_content_type,
+ const DataBuffer& plaintext, DataBuffer* ciphertext);
protected:
// There are two filter functions which can be overriden. Both are
@@ -72,7 +94,7 @@ class TlsRecordFilter : public PacketFilter {
// just lets you change the record contents. By default, the
// outer one calls the inner one, so if you override the outer
// one, the inner one is never called unless you call it yourself.
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& record,
size_t* offset, DataBuffer* output);
@@ -80,16 +102,49 @@ class TlsRecordFilter : public PacketFilter {
// sequence number (which is zero for TLS), plus the existing record payload.
// It returns an action (KEEP, CHANGE, DROP). It writes to the `changed`
// outparam with the new record contents if it chooses to CHANGE the record.
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& data,
DataBuffer* changed) {
return KEEP;
}
private:
+ static void CipherSpecChanged(void* arg, PRBool sending,
+ ssl3CipherSpec* newSpec);
+
+ const TlsAgent* agent_;
size_t count_;
+ std::unique_ptr<TlsCipherSpec> cipher_spec_;
};
+inline std::ostream& operator<<(std::ostream& stream, TlsVersioned v) {
+ v.WriteStream(stream);
+ return stream;
+}
+
+inline std::ostream& operator<<(std::ostream& stream, TlsRecordHeader& hdr) {
+ hdr.WriteStream(stream);
+ stream << ' ';
+ switch (hdr.content_type()) {
+ case kTlsChangeCipherSpecType:
+ stream << "CCS";
+ break;
+ case kTlsAlertType:
+ stream << "Alert";
+ break;
+ case kTlsHandshakeType:
+ stream << "Handshake";
+ break;
+ case kTlsApplicationDataType:
+ stream << "Data";
+ break;
+ default:
+ stream << '<' << hdr.content_type() << '>';
+ break;
+ }
+ return stream << ' ' << std::hex << hdr.sequence_number() << std::dec;
+}
+
// Abstract filter that operates on handshake messages rather than records.
// This assumes that the handshake messages are written in a block as entire
// records and that they don't span records or anything crazy like that.
@@ -97,20 +152,23 @@ class TlsHandshakeFilter : public TlsRecordFilter {
public:
TlsHandshakeFilter() {}
- class HandshakeHeader : public Versioned {
+ class HandshakeHeader : public TlsVersioned {
public:
- HandshakeHeader() : Versioned(), handshake_type_(0), message_seq_(0) {}
+ HandshakeHeader() : TlsVersioned(), handshake_type_(0), message_seq_(0) {}
uint8_t handshake_type() const { return handshake_type_; }
- bool Parse(TlsParser* parser, const RecordHeader& record_header,
+ bool Parse(TlsParser* parser, const TlsRecordHeader& record_header,
DataBuffer* body);
size_t Write(DataBuffer* buffer, size_t offset,
const DataBuffer& body) const;
+ size_t WriteFragment(DataBuffer* buffer, size_t offset,
+ const DataBuffer& body, size_t fragment_offset,
+ size_t fragment_length) const;
private:
// Reads the length from the record header.
// This also reads the DTLS fragment information and checks it.
- bool ReadLength(TlsParser* parser, const RecordHeader& header,
+ bool ReadLength(TlsParser* parser, const TlsRecordHeader& header,
uint32_t* length);
uint8_t handshake_type_;
@@ -119,7 +177,7 @@ class TlsHandshakeFilter : public TlsRecordFilter {
};
protected:
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
@@ -167,7 +225,7 @@ class TlsConversationRecorder : public TlsRecordFilter {
public:
TlsConversationRecorder(DataBuffer& buffer) : buffer_(buffer) {}
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& input,
DataBuffer* output);
@@ -175,43 +233,39 @@ class TlsConversationRecorder : public TlsRecordFilter {
DataBuffer& buffer_;
};
-// Records an alert. If an alert has already been recorded, it won't save the
-// new alert unless the old alert is a warning and the new one is fatal.
-class TlsAlertRecorder : public TlsRecordFilter {
- public:
- TlsAlertRecorder() : level_(255), description_(255) {}
-
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
- const DataBuffer& input,
- DataBuffer* output);
-
- uint8_t level() const { return level_; }
- uint8_t description() const { return description_; }
-
- private:
- uint8_t level_;
- uint8_t description_;
-};
-
// Runs multiple packet filters in series.
class ChainedPacketFilter : public PacketFilter {
public:
ChainedPacketFilter() {}
- ChainedPacketFilter(const std::vector<PacketFilter*> filters)
+ ChainedPacketFilter(const std::vector<std::shared_ptr<PacketFilter>> filters)
: filters_(filters.begin(), filters.end()) {}
- virtual ~ChainedPacketFilter();
+ virtual ~ChainedPacketFilter() {}
virtual PacketFilter::Action Filter(const DataBuffer& input,
DataBuffer* output);
// Takes ownership of the filter.
- void Add(PacketFilter* filter) { filters_.push_back(filter); }
+ void Add(std::shared_ptr<PacketFilter> filter) { filters_.push_back(filter); }
private:
- std::vector<PacketFilter*> filters_;
+ std::vector<std::shared_ptr<PacketFilter>> filters_;
};
+typedef std::function<bool(TlsParser* parser, const TlsVersioned& header)>
+ TlsExtensionFinder;
+
class TlsExtensionFilter : public TlsHandshakeFilter {
+ public:
+ TlsExtensionFilter() : handshake_types_() {
+ handshake_types_.insert(kTlsHandshakeClientHello);
+ handshake_types_.insert(kTlsHandshakeServerHello);
+ }
+
+ TlsExtensionFilter(const std::set<uint8_t>& types)
+ : handshake_types_(types) {}
+
+ static bool FindExtensions(TlsParser* parser, const HandshakeHeader& header);
+
protected:
PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
@@ -221,15 +275,12 @@ class TlsExtensionFilter : public TlsHandshakeFilter {
const DataBuffer& input,
DataBuffer* output) = 0;
- public:
- static bool FindClientHelloExtensions(TlsParser* parser,
- const Versioned& header);
- static bool FindServerHelloExtensions(TlsParser* parser);
-
private:
PacketFilter::Action FilterExtensions(TlsParser* parser,
const DataBuffer& input,
DataBuffer* output);
+
+ std::set<uint8_t> handshake_types_;
};
class TlsExtensionCapture : public TlsExtensionFilter {
@@ -280,17 +331,17 @@ typedef std::function<void(void)> VoidFunction;
class AfterRecordN : public TlsRecordFilter {
public:
- AfterRecordN(TlsAgent* src, TlsAgent* dest, unsigned int record,
- VoidFunction func)
+ AfterRecordN(std::shared_ptr<TlsAgent>& src, std::shared_ptr<TlsAgent>& dest,
+ unsigned int record, VoidFunction func)
: src_(src), dest_(dest), record_(record), func_(func), counter_(0) {}
- virtual PacketFilter::Action FilterRecord(const RecordHeader& header,
+ virtual PacketFilter::Action FilterRecord(const TlsRecordHeader& header,
const DataBuffer& body,
DataBuffer* out) override;
private:
- TlsAgent* src_;
- TlsAgent* dest_;
+ std::weak_ptr<TlsAgent> src_;
+ std::weak_ptr<TlsAgent> dest_;
unsigned int record_;
VoidFunction func_;
unsigned int counter_;
@@ -300,14 +351,15 @@ class AfterRecordN : public TlsRecordFilter {
// ClientHelloVersion on |server|.
class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter {
public:
- TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {}
+ TlsInspectorClientHelloVersionChanger(std::shared_ptr<TlsAgent>& server)
+ : server_(server) {}
virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
const DataBuffer& input,
DataBuffer* output);
private:
- TlsAgent* server_;
+ std::weak_ptr<TlsAgent> server_;
};
// This class selectively drops complete writes. This relies on the fact that
@@ -338,6 +390,27 @@ class TlsInspectorClientHelloVersionSetter : public TlsHandshakeFilter {
uint16_t version_;
};
+// Damages the last byte of a handshake message.
+class TlsLastByteDamager : public TlsHandshakeFilter {
+ public:
+ TlsLastByteDamager(uint8_t type) : type_(type) {}
+ PacketFilter::Action FilterHandshake(
+ const TlsHandshakeFilter::HandshakeHeader& header,
+ const DataBuffer& input, DataBuffer* output) override {
+ if (header.handshake_type() != type_) {
+ return KEEP;
+ }
+
+ *output = input;
+
+ output->data()[output->len() - 1]++;
+ return CHANGE;
+ }
+
+ private:
+ uint8_t type_;
+};
+
} // namespace nss_test
#endif
diff --git a/nss/gtests/ssl_gtest/tls_parser.cc b/nss/gtests/ssl_gtest/tls_parser.cc
deleted file mode 100644
index e4c06aa..0000000
--- a/nss/gtests/ssl_gtest/tls_parser.cc
+++ /dev/null
@@ -1,73 +0,0 @@
-/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
-/* vim: set ts=2 et sw=2 tw=80: */
-/* This Source Code Form is subject to the terms of the Mozilla Public
- * License, v. 2.0. If a copy of the MPL was not distributed with this file,
- * You can obtain one at http://mozilla.org/MPL/2.0/. */
-
-#include "tls_parser.h"
-
-namespace nss_test {
-
-bool TlsParser::Read(uint8_t* val) {
- if (remaining() < 1) {
- return false;
- }
- *val = *ptr();
- consume(1);
- return true;
-}
-
-bool TlsParser::Read(uint32_t* val, size_t size) {
- if (size > sizeof(uint32_t)) {
- return false;
- }
-
- uint32_t v = 0;
- for (size_t i = 0; i < size; ++i) {
- uint8_t tmp;
- if (!Read(&tmp)) {
- return false;
- }
-
- v = (v << 8) | tmp;
- }
-
- *val = v;
- return true;
-}
-
-bool TlsParser::Read(DataBuffer* val, size_t len) {
- if (remaining() < len) {
- return false;
- }
-
- val->Assign(ptr(), len);
- consume(len);
- return true;
-}
-
-bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) {
- uint32_t len;
- if (!Read(&len, len_size)) {
- return false;
- }
- return Read(val, len);
-}
-
-bool TlsParser::Skip(size_t len) {
- if (len > remaining()) {
- return false;
- }
- consume(len);
- return true;
-}
-
-bool TlsParser::SkipVariable(size_t len_size) {
- uint32_t len;
- if (!Read(&len, len_size)) {
- return false;
- }
- return Skip(len);
-}
-
-} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_parser.h b/nss/gtests/ssl_gtest/tls_parser.h
deleted file mode 100644
index c79d45a..0000000
--- a/nss/gtests/ssl_gtest/tls_parser.h
+++ /dev/null
@@ -1,131 +0,0 @@
-/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
-/* vim: set ts=2 et sw=2 tw=80: */
-/* This Source Code Form is subject to the terms of the Mozilla Public
- * License, v. 2.0. If a copy of the MPL was not distributed with this file,
- * You can obtain one at http://mozilla.org/MPL/2.0/. */
-
-#ifndef tls_parser_h_
-#define tls_parser_h_
-
-#include <cstdint>
-#include <cstring>
-#include <memory>
-#if defined(WIN32) || defined(WIN64)
-#include <winsock2.h>
-#else
-#include <arpa/inet.h>
-#endif
-#include "databuffer.h"
-
-namespace nss_test {
-
-const uint8_t kTlsChangeCipherSpecType = 20;
-const uint8_t kTlsAlertType = 21;
-const uint8_t kTlsHandshakeType = 22;
-const uint8_t kTlsApplicationDataType = 23;
-
-const uint8_t kTlsHandshakeClientHello = 1;
-const uint8_t kTlsHandshakeServerHello = 2;
-const uint8_t kTlsHandshakeHelloRetryRequest = 6;
-const uint8_t kTlsHandshakeEncryptedExtensions = 8;
-const uint8_t kTlsHandshakeCertificate = 11;
-const uint8_t kTlsHandshakeServerKeyExchange = 12;
-const uint8_t kTlsHandshakeCertificateVerify = 15;
-const uint8_t kTlsHandshakeClientKeyExchange = 16;
-const uint8_t kTlsHandshakeFinished = 20;
-
-const uint8_t kTlsAlertWarning = 1;
-const uint8_t kTlsAlertFatal = 2;
-
-const uint8_t kTlsAlertUnexpectedMessage = 10;
-const uint8_t kTlsAlertBadRecordMac = 20;
-const uint8_t kTlsAlertHandshakeFailure = 40;
-const uint8_t kTlsAlertIllegalParameter = 47;
-const uint8_t kTlsAlertDecodeError = 50;
-const uint8_t kTlsAlertDecryptError = 51;
-const uint8_t kTlsAlertMissingExtension = 109;
-const uint8_t kTlsAlertUnsupportedExtension = 110;
-const uint8_t kTlsAlertUnrecognizedName = 112;
-const uint8_t kTlsAlertNoApplicationProtocol = 120;
-
-const uint8_t kTlsFakeChangeCipherSpec[] = {
- kTlsChangeCipherSpecType, // Type
- 0xfe,
- 0xff, // Version
- 0x00,
- 0x00,
- 0x00,
- 0x00,
- 0x00,
- 0x00,
- 0x00,
- 0x10, // Fictitious sequence #
- 0x00,
- 0x01, // Length
- 0x01 // Value
-};
-
-static const uint8_t kTls13PskKe = 0;
-static const uint8_t kTls13PskDhKe = 1;
-static const uint8_t kTls13PskAuth = 0;
-static const uint8_t kTls13PskSignAuth = 1;
-
-inline bool IsDtls(uint16_t version) { return (version & 0x8000) == 0x8000; }
-
-inline uint16_t NormalizeTlsVersion(uint16_t version) {
- if (version == 0xfeff) {
- return 0x0302; // special: DTLS 1.0 == TLS 1.1
- }
- if (IsDtls(version)) {
- return (version ^ 0xffff) + 0x0201;
- }
- return version;
-}
-
-inline uint16_t TlsVersionToDtlsVersion(uint16_t version) {
- if (version == 0x0302) {
- return 0xfeff;
- }
- if (version == 0x0304) {
- return version;
- }
- return 0xffff - version + 0x0201;
-}
-
-inline size_t WriteVariable(DataBuffer* target, size_t index,
- const DataBuffer& buf, size_t len_size) {
- index = target->Write(index, static_cast<uint32_t>(buf.len()), len_size);
- return target->Write(index, buf.data(), buf.len());
-}
-
-class TlsParser {
- public:
- TlsParser(const uint8_t* data, size_t len) : buffer_(data, len), offset_(0) {}
- explicit TlsParser(const DataBuffer& buf) : buffer_(buf), offset_(0) {}
-
- bool Read(uint8_t* val);
- // Read an integral type of specified width.
- bool Read(uint32_t* val, size_t size);
- // Reads len bytes into dest buffer, overwriting it.
- bool Read(DataBuffer* dest, size_t len);
- // Reads bytes into dest buffer, overwriting it. The number of bytes is
- // determined by reading from len_size bytes from the stream first.
- bool ReadVariable(DataBuffer* dest, size_t len_size);
-
- bool Skip(size_t len);
- bool SkipVariable(size_t len_size);
-
- size_t consumed() const { return offset_; }
- size_t remaining() const { return buffer_.len() - offset_; }
-
- private:
- void consume(size_t len) { offset_ += len; }
- const uint8_t* ptr() const { return buffer_.data() + offset_; }
-
- DataBuffer buffer_;
- size_t offset_;
-};
-
-} // namespace nss_test
-
-#endif
diff --git a/nss/gtests/ssl_gtest/tls_protect.cc b/nss/gtests/ssl_gtest/tls_protect.cc
new file mode 100644
index 0000000..efcd89e
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_protect.cc
@@ -0,0 +1,145 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "tls_protect.h"
+#include "tls_filter.h"
+
+namespace nss_test {
+
+AeadCipher::~AeadCipher() {
+ if (key_) {
+ PK11_FreeSymKey(key_);
+ }
+}
+
+bool AeadCipher::Init(PK11SymKey *key, const uint8_t *iv) {
+ key_ = PK11_ReferenceSymKey(key);
+ if (!key_) return false;
+
+ memcpy(iv_, iv, sizeof(iv_));
+ return true;
+}
+
+void AeadCipher::FormatNonce(uint64_t seq, uint8_t *nonce) {
+ memcpy(nonce, iv_, 12);
+
+ for (size_t i = 0; i < 8; ++i) {
+ nonce[12 - (i + 1)] ^= seq & 0xff;
+ seq >>= 8;
+ }
+
+ DataBuffer d(nonce, 12);
+ std::cerr << "Nonce " << d << std::endl;
+}
+
+bool AeadCipher::AeadInner(bool decrypt, void *params, size_t param_length,
+ const uint8_t *in, size_t inlen, uint8_t *out,
+ size_t *outlen, size_t maxlen) {
+ SECStatus rv;
+ unsigned int uoutlen = 0;
+ SECItem param = {
+ siBuffer, static_cast<unsigned char *>(params),
+ static_cast<unsigned int>(param_length),
+ };
+
+ if (decrypt) {
+ rv = PK11_Decrypt(key_, mech_, &param, out, &uoutlen, maxlen, in, inlen);
+ } else {
+ rv = PK11_Encrypt(key_, mech_, &param, out, &uoutlen, maxlen, in, inlen);
+ }
+ *outlen = (int)uoutlen;
+
+ return rv == SECSuccess;
+}
+
+bool AeadCipherAesGcm::Aead(bool decrypt, uint64_t seq, const uint8_t *in,
+ size_t inlen, uint8_t *out, size_t *outlen,
+ size_t maxlen) {
+ CK_GCM_PARAMS aeadParams;
+ unsigned char nonce[12];
+
+ memset(&aeadParams, 0, sizeof(aeadParams));
+ aeadParams.pIv = nonce;
+ aeadParams.ulIvLen = sizeof(nonce);
+ aeadParams.pAAD = NULL;
+ aeadParams.ulAADLen = 0;
+ aeadParams.ulTagBits = 128;
+
+ FormatNonce(seq, nonce);
+ return AeadInner(decrypt, (unsigned char *)&aeadParams, sizeof(aeadParams),
+ in, inlen, out, outlen, maxlen);
+}
+
+bool AeadCipherChacha20Poly1305::Aead(bool decrypt, uint64_t seq,
+ const uint8_t *in, size_t inlen,
+ uint8_t *out, size_t *outlen,
+ size_t maxlen) {
+ CK_NSS_AEAD_PARAMS aeadParams;
+ unsigned char nonce[12];
+
+ memset(&aeadParams, 0, sizeof(aeadParams));
+ aeadParams.pNonce = nonce;
+ aeadParams.ulNonceLen = sizeof(nonce);
+ aeadParams.pAAD = NULL;
+ aeadParams.ulAADLen = 0;
+ aeadParams.ulTagLen = 16;
+
+ FormatNonce(seq, nonce);
+ return AeadInner(decrypt, (unsigned char *)&aeadParams, sizeof(aeadParams),
+ in, inlen, out, outlen, maxlen);
+}
+
+bool TlsCipherSpec::Init(SSLCipherAlgorithm cipher, PK11SymKey *key,
+ const uint8_t *iv) {
+ switch (cipher) {
+ case ssl_calg_aes_gcm:
+ aead_.reset(new AeadCipherAesGcm());
+ break;
+ case ssl_calg_chacha20:
+ aead_.reset(new AeadCipherChacha20Poly1305());
+ break;
+ default:
+ return false;
+ }
+
+ return aead_->Init(key, iv);
+}
+
+bool TlsCipherSpec::Unprotect(const TlsRecordHeader &header,
+ const DataBuffer &ciphertext,
+ DataBuffer *plaintext) {
+ // Make space.
+ plaintext->Allocate(ciphertext.len());
+
+ size_t len;
+ bool ret =
+ aead_->Aead(true, header.sequence_number(), ciphertext.data(),
+ ciphertext.len(), plaintext->data(), &len, plaintext->len());
+ if (!ret) return false;
+
+ plaintext->Truncate(len);
+
+ return true;
+}
+
+bool TlsCipherSpec::Protect(const TlsRecordHeader &header,
+ const DataBuffer &plaintext,
+ DataBuffer *ciphertext) {
+ // Make a padded buffer.
+
+ ciphertext->Allocate(plaintext.len() +
+ 32); // Room for any plausible auth tag
+ size_t len;
+ bool ret =
+ aead_->Aead(false, header.sequence_number(), plaintext.data(),
+ plaintext.len(), ciphertext->data(), &len, ciphertext->len());
+ if (!ret) return false;
+ ciphertext->Truncate(len);
+
+ return true;
+}
+
+} // namespace nss_test
diff --git a/nss/gtests/ssl_gtest/tls_protect.h b/nss/gtests/ssl_gtest/tls_protect.h
new file mode 100644
index 0000000..4efbd6e
--- /dev/null
+++ b/nss/gtests/ssl_gtest/tls_protect.h
@@ -0,0 +1,76 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#ifndef tls_protection_h_
+#define tls_protection_h_
+
+#include <cstdint>
+#include <memory>
+
+#include "databuffer.h"
+#include "pk11pub.h"
+#include "sslt.h"
+
+namespace nss_test {
+class TlsRecordHeader;
+
+class AeadCipher {
+ public:
+ AeadCipher(CK_MECHANISM_TYPE mech) : mech_(mech), key_(nullptr) {}
+ ~AeadCipher();
+
+ bool Init(PK11SymKey *key, const uint8_t *iv);
+ virtual bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen,
+ uint8_t *out, size_t *outlen, size_t maxlen) = 0;
+
+ protected:
+ void FormatNonce(uint64_t seq, uint8_t *nonce);
+ bool AeadInner(bool decrypt, void *params, size_t param_length,
+ const uint8_t *in, size_t inlen, uint8_t *out, size_t *outlen,
+ size_t maxlen);
+
+ CK_MECHANISM_TYPE mech_;
+ PK11SymKey *key_;
+ uint8_t iv_[12];
+};
+
+class AeadCipherChacha20Poly1305 : public AeadCipher {
+ public:
+ AeadCipherChacha20Poly1305() : AeadCipher(CKM_NSS_CHACHA20_POLY1305) {}
+
+ protected:
+ bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen,
+ uint8_t *out, size_t *outlen, size_t maxlen);
+};
+
+class AeadCipherAesGcm : public AeadCipher {
+ public:
+ AeadCipherAesGcm() : AeadCipher(CKM_AES_GCM) {}
+
+ protected:
+ bool Aead(bool decrypt, uint64_t seq, const uint8_t *in, size_t inlen,
+ uint8_t *out, size_t *outlen, size_t maxlen);
+};
+
+// Our analog of ssl3CipherSpec
+class TlsCipherSpec {
+ public:
+ TlsCipherSpec() : aead_() {}
+
+ bool Init(SSLCipherAlgorithm cipher, PK11SymKey *key, const uint8_t *iv);
+
+ bool Protect(const TlsRecordHeader &header, const DataBuffer &plaintext,
+ DataBuffer *ciphertext);
+ bool Unprotect(const TlsRecordHeader &header, const DataBuffer &ciphertext,
+ DataBuffer *plaintext);
+
+ private:
+ std::unique_ptr<AeadCipher> aead_;
+};
+
+} // namespace nss_test
+
+#endif