summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/tls_connect.cc
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/tls_connect.cc')
-rw-r--r--nss/gtests/ssl_gtest/tls_connect.cc175
1 files changed, 115 insertions, 60 deletions
diff --git a/nss/gtests/ssl_gtest/tls_connect.cc b/nss/gtests/ssl_gtest/tls_connect.cc
index d025499..861d162 100644
--- a/nss/gtests/ssl_gtest/tls_connect.cc
+++ b/nss/gtests/ssl_gtest/tls_connect.cc
@@ -13,23 +13,27 @@ extern "C" {
#include "databuffer.h"
#include "gtest_utils.h"
+#include "scoped_ptrs.h"
#include "sslproto.h"
extern std::string g_working_dir_path;
namespace nss_test {
-static const std::string kTlsModesStreamArr[] = {"TLS"};
-::testing::internal::ParamGenerator<std::string>
- TlsConnectTestBase::kTlsModesStream =
- ::testing::ValuesIn(kTlsModesStreamArr);
-static const std::string kTlsModesDatagramArr[] = {"DTLS"};
-::testing::internal::ParamGenerator<std::string>
- TlsConnectTestBase::kTlsModesDatagram =
- ::testing::ValuesIn(kTlsModesDatagramArr);
-static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"};
-::testing::internal::ParamGenerator<std::string>
- TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr);
+static const SSLProtocolVariant kTlsVariantsStreamArr[] = {ssl_variant_stream};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsStream =
+ ::testing::ValuesIn(kTlsVariantsStreamArr);
+static const SSLProtocolVariant kTlsVariantsDatagramArr[] = {
+ ssl_variant_datagram};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsDatagram =
+ ::testing::ValuesIn(kTlsVariantsDatagramArr);
+static const SSLProtocolVariant kTlsVariantsAllArr[] = {ssl_variant_stream,
+ ssl_variant_datagram};
+::testing::internal::ParamGenerator<SSLProtocolVariant>
+ TlsConnectTestBase::kTlsVariantsAll =
+ ::testing::ValuesIn(kTlsVariantsAllArr);
static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0};
::testing::internal::ParamGenerator<uint16_t> TlsConnectTestBase::kTlsV10 =
@@ -99,30 +103,29 @@ std::string VersionString(uint16_t version) {
}
}
-TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version)
- : mode_(mode),
- client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_)),
- server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_)),
+TlsConnectTestBase::TlsConnectTestBase(SSLProtocolVariant variant,
+ uint16_t version)
+ : variant_(variant),
+ client_(new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_)),
+ server_(new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_)),
client_model_(nullptr),
server_model_(nullptr),
version_(version),
expected_resumption_mode_(RESUME_NONE),
session_ids_(),
expect_extended_master_secret_(false),
- expect_early_data_accepted_(false) {
+ expect_early_data_accepted_(false),
+ skip_version_checks_(false) {
std::string v;
- if (mode_ == DGRAM && version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
+ if (variant_ == ssl_variant_datagram &&
+ version_ == SSL_LIBRARY_VERSION_TLS_1_1) {
v = "1.0";
} else {
v = VersionString(version_);
}
- std::cerr << "Version: " << mode_ << " " << v << std::endl;
+ std::cerr << "Version: " << variant_ << " " << v << std::endl;
}
-TlsConnectTestBase::TlsConnectTestBase(const std::string& mode,
- uint16_t version)
- : TlsConnectTestBase(TlsConnectTestBase::ToMode(mode), version) {}
-
TlsConnectTestBase::~TlsConnectTestBase() {}
// Check the group of each of the supported groups
@@ -173,18 +176,15 @@ void TlsConnectTestBase::ClearServerCache() {
void TlsConnectTestBase::SetUp() {
SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str());
SSLInt_ClearSessionTicketKey();
+ SSLInt_SetTicketLifetime(30);
+ SSLInt_SetMaxEarlyDataSize(1024);
ClearStats();
Init();
}
void TlsConnectTestBase::TearDown() {
- delete client_;
- delete server_;
- if (client_model_) {
- ASSERT_NE(server_model_, nullptr);
- delete client_model_;
- delete server_model_;
- }
+ client_ = nullptr;
+ server_ = nullptr;
SSL_ClearSessionCache();
SSLInt_ClearSessionTicketKey();
@@ -192,9 +192,6 @@ void TlsConnectTestBase::TearDown() {
}
void TlsConnectTestBase::Init() {
- EXPECT_TRUE(client_->Init());
- EXPECT_TRUE(server_->Init());
-
client_->SetPeer(server_);
server_->SetPeer(client_);
@@ -212,11 +209,12 @@ void TlsConnectTestBase::Reset() {
void TlsConnectTestBase::Reset(const std::string& server_name,
const std::string& client_name) {
- delete client_;
- delete server_;
-
- client_ = new TlsAgent(client_name, TlsAgent::CLIENT, mode_);
- server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_);
+ client_.reset(new TlsAgent(client_name, TlsAgent::CLIENT, variant_));
+ server_.reset(new TlsAgent(server_name, TlsAgent::SERVER, variant_));
+ if (skip_version_checks_) {
+ client_->SkipVersionChecks();
+ server_->SkipVersionChecks();
+ }
Init();
}
@@ -276,10 +274,12 @@ void TlsConnectTestBase::ConnectWithCipherSuite(uint16_t cipher_suite) {
}
void TlsConnectTestBase::CheckConnected() {
- // Check the version is as expected
EXPECT_EQ(client_->version(), server_->version());
- EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
- client_->version());
+ if (!skip_version_checks_) {
+ // Check the version is as expected
+ EXPECT_EQ(std::min(client_->max_version(), server_->max_version()),
+ client_->version());
+ }
EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state());
EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state());
@@ -345,6 +345,13 @@ void TlsConnectTestBase::CheckKeys(SSLKEAType kea_type,
scheme = ssl_sig_none;
break;
case ssl_auth_rsa_sign:
+ if (version_ >= SSL_LIBRARY_VERSION_TLS_1_2) {
+ scheme = ssl_sig_rsa_pss_sha256;
+ } else {
+ scheme = ssl_sig_rsa_pkcs1_sha256;
+ }
+ break;
+ case ssl_auth_rsa_pss:
scheme = ssl_sig_rsa_pss_sha256;
break;
case ssl_auth_ecdsa:
@@ -373,7 +380,36 @@ void TlsConnectTestBase::ConnectExpectFail() {
ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state());
}
+void TlsConnectTestBase::ExpectAlert(std::shared_ptr<TlsAgent>& sender,
+ uint8_t alert) {
+ EnsureTlsSetup();
+ auto receiver = (sender == client_) ? server_ : client_;
+ sender->ExpectSendAlert(alert);
+ receiver->ExpectReceiveAlert(alert);
+}
+
+void TlsConnectTestBase::ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender,
+ uint8_t alert) {
+ ExpectAlert(sender, alert);
+ ConnectExpectFail();
+}
+
+void TlsConnectTestBase::ConnectExpectFailOneSide(TlsAgent::Role failing_side) {
+ server_->StartConnect();
+ client_->StartConnect();
+ client_->SetServerKeyBits(server_->server_key_bits());
+ client_->Handshake();
+ server_->Handshake();
+
+ auto failing_agent = server_;
+ if (failing_side == TlsAgent::CLIENT) {
+ failing_agent = client_;
+ }
+ ASSERT_TRUE_WAIT(failing_agent->state() == TlsAgent::STATE_ERROR, 5000);
+}
+
void TlsConnectTestBase::ConfigureVersion(uint16_t version) {
+ version_ = version;
client_->SetVersionRange(version, version);
server_->SetVersionRange(version, version);
}
@@ -424,10 +460,16 @@ void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client,
client_->ConfigureSessionCache(client);
server_->ConfigureSessionCache(server);
if ((server & RESUME_TICKET) != 0) {
- // This is an abomination. NSS encrypts session tickets with the server's
- // RSA public key. That means we need the server to have an RSA certificate
- // even if it won't be used for the connection.
- server_->ConfigServerCert(TlsAgent::kServerRsaDecrypt);
+ ScopedCERTCertificate cert;
+ ScopedSECKEYPrivateKey privKey;
+ ASSERT_TRUE(TlsAgent::LoadCertificate(TlsAgent::kServerRsaDecrypt, &cert,
+ &privKey));
+
+ ScopedSECKEYPublicKey pubKey(CERT_ExtractPublicKey(cert.get()));
+ ASSERT_TRUE(pubKey);
+
+ EXPECT_EQ(SECSuccess,
+ SSL_SetSessionTicketKeyPair(pubKey.get(), privKey.get()));
}
}
@@ -472,13 +514,15 @@ void TlsConnectTestBase::EnsureModelSockets() {
// Make sure models agents are available.
if (!client_model_) {
ASSERT_EQ(server_model_, nullptr);
- client_model_ = new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, mode_);
- server_model_ = new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, mode_);
+ client_model_.reset(
+ new TlsAgent(TlsAgent::kClient, TlsAgent::CLIENT, variant_));
+ server_model_.reset(
+ new TlsAgent(TlsAgent::kServerRsa, TlsAgent::SERVER, variant_));
+ if (skip_version_checks_) {
+ client_model_->SkipVersionChecks();
+ server_model_->SkipVersionChecks();
+ }
}
-
- // Initialise agents.
- ASSERT_TRUE(client_model_->Init());
- ASSERT_TRUE(server_model_->Init());
}
void TlsConnectTestBase::CheckAlpn(const std::string& val) {
@@ -540,6 +584,10 @@ void TlsConnectTestBase::ZeroRttSendReceive(
const char* k0RttData = "ABCDEF";
const PRInt32 k0RttDataLen = static_cast<PRInt32>(strlen(k0RttData));
+ if (expect_writable && expect_readable) {
+ ExpectAlert(client_, kTlsAlertEndOfEarlyData);
+ }
+
client_->Handshake(); // Send ClientHello.
if (post_clienthello_check) {
if (!post_clienthello_check()) return;
@@ -599,6 +647,12 @@ void TlsConnectTestBase::DisableECDHEServerKeyReuse() {
server_->DisableECDHEServerKeyReuse();
}
+void TlsConnectTestBase::SkipVersionChecks() {
+ skip_version_checks_ = true;
+ client_->SkipVersionChecks();
+ server_->SkipVersionChecks();
+}
+
TlsConnectGeneric::TlsConnectGeneric()
: TlsConnectTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {}
@@ -616,16 +670,17 @@ TlsConnectTls13::TlsConnectTls13()
void TlsKeyExchangeTest::EnsureKeyShareSetup() {
EnsureTlsSetup();
- groups_capture_ = new TlsExtensionCapture(ssl_supported_groups_xtn);
- shares_capture_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn);
- shares_capture2_ = new TlsExtensionCapture(ssl_tls13_key_share_xtn, true);
- std::vector<PacketFilter*> captures;
- captures.push_back(groups_capture_);
- captures.push_back(shares_capture_);
- captures.push_back(shares_capture2_);
- client_->SetPacketFilter(new ChainedPacketFilter(captures));
- capture_hrr_ =
- new TlsInspectorRecordHandshakeMessage(kTlsHandshakeHelloRetryRequest);
+ groups_capture_ =
+ std::make_shared<TlsExtensionCapture>(ssl_supported_groups_xtn);
+ shares_capture_ =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn);
+ shares_capture2_ =
+ std::make_shared<TlsExtensionCapture>(ssl_tls13_key_share_xtn, true);
+ std::vector<std::shared_ptr<PacketFilter>> captures = {
+ groups_capture_, shares_capture_, shares_capture2_};
+ client_->SetPacketFilter(std::make_shared<ChainedPacketFilter>(captures));
+ capture_hrr_ = std::make_shared<TlsInspectorRecordHandshakeMessage>(
+ kTlsHandshakeHelloRetryRequest);
server_->SetPacketFilter(capture_hrr_);
}