/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ /* vim: set ts=2 et sw=2 tw=80: */ /* This Source Code Form is subject to the terms of the Mozilla Public * License, v. 2.0. If a copy of the MPL was not distributed with this file, * You can obtain one at http://mozilla.org/MPL/2.0/. */ #include "tls_connect.h" #include #include "sslproto.h" #include "gtest_utils.h" extern std::string g_working_dir_path; namespace nss_test { static const std::string kTlsModesStreamArr[] = {"TLS"}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsModesStream = ::testing::ValuesIn(kTlsModesStreamArr); static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsV10 = ::testing::ValuesIn(kTlsV10Arr); static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsV11 = ::testing::ValuesIn(kTlsV11Arr); static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, SSL_LIBRARY_VERSION_TLS_1_2}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsV11V12 = ::testing::ValuesIn(kTlsV11V12Arr); // TODO: add TLS 1.3 static const uint16_t kTlsV12PlusArr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; ::testing::internal::ParamGenerator TlsConnectTestBase::kTlsV12Plus = ::testing::ValuesIn(kTlsV12PlusArr); static std::string VersionString(uint16_t version) { switch(version) { case 0: return "(no version)"; case SSL_LIBRARY_VERSION_TLS_1_0: return "1.0"; case SSL_LIBRARY_VERSION_TLS_1_1: return "1.1"; case SSL_LIBRARY_VERSION_TLS_1_2: return "1.2"; default: std::cerr << "Invalid version: " << version << std::endl; EXPECT_TRUE(false); return ""; } } TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) : mode_(mode), client_(new TlsAgent("client", TlsAgent::CLIENT, mode_, ssl_kea_rsa)), server_(new TlsAgent("server", TlsAgent::SERVER, mode_, ssl_kea_rsa)), version_(version), expected_resumption_mode_(RESUME_NONE), session_ids_(), expect_extended_master_secret_(false) { std::cerr << "Version: " << mode_ << " " << VersionString(version_) << std::endl; } TlsConnectTestBase::~TlsConnectTestBase() { } void TlsConnectTestBase::SetUp() { // Configure a fresh session cache. SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); // Clear statistics. SSL3Statistics* stats = SSL_GetStatistics(); memset(stats, 0, sizeof(*stats)); Init(); } void TlsConnectTestBase::TearDown() { delete client_; delete server_; SSL_ClearSessionCache(); SSL_ShutdownServerSessionIDCache(); } void TlsConnectTestBase::Init() { EXPECT_TRUE(client_->Init()); EXPECT_TRUE(server_->Init()); client_->SetPeer(server_); server_->SetPeer(client_); if (version_) { client_->SetVersionRange(version_, version_); server_->SetVersionRange(version_, version_); } } void TlsConnectTestBase::Reset(const std::string& server_name, SSLKEAType kea) { delete client_; delete server_; client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_, kea); server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_, kea); Init(); } void TlsConnectTestBase::ResetRsa() { Reset("server", ssl_kea_rsa); } void TlsConnectTestBase::ResetEcdsa() { Reset("ecdsa", ssl_kea_ecdh); } void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { expected_resumption_mode_ = expected; if (expected != RESUME_NONE) { client_->ExpectResumption(); server_->ExpectResumption(); } } void TlsConnectTestBase::EnsureTlsSetup() { EXPECT_TRUE(client_->EnsureTlsSetup()); EXPECT_TRUE(server_->EnsureTlsSetup()); } void TlsConnectTestBase::Handshake() { client_->Handshake(); server_->Handshake(); ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && (server_->state() != TlsAgent::STATE_CONNECTING), 5000); } void TlsConnectTestBase::EnableExtendedMasterSecret() { client_->EnableExtendedMasterSecret(); server_->EnableExtendedMasterSecret(); ExpectExtendedMasterSecret(true); } void TlsConnectTestBase::Connect() { server_->StartConnect(); client_->StartConnect(); Handshake(); CheckConnected(); } void TlsConnectTestBase::CheckConnected() { // Check the version is as expected EXPECT_EQ(client_->version(), server_->version()); EXPECT_EQ(std::min(client_->max_version(), server_->max_version()), client_->version()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); int16_t cipher_suite1, cipher_suite2; bool ret = client_->cipher_suite(&cipher_suite1); EXPECT_TRUE(ret); ret = server_->cipher_suite(&cipher_suite2); EXPECT_TRUE(ret); EXPECT_EQ(cipher_suite1, cipher_suite2); std::cerr << "Connected with version " << client_->version() << " cipher suite " << client_->cipher_suite_name() << std::endl; // Check and store session ids. std::vector sid_c1 = client_->session_id(); EXPECT_EQ(32U, sid_c1.size()); std::vector sid_s1 = server_->session_id(); EXPECT_EQ(32U, sid_s1.size()); EXPECT_EQ(sid_c1, sid_s1); session_ids_.push_back(sid_c1); CheckResumption(expected_resumption_mode_); // Check whether the extended master secret extension was negotiated. CheckExtendedMasterSecret(); } void TlsConnectTestBase::ConnectExpectFail() { server_->StartConnect(); client_->StartConnect(); Handshake(); ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); } void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { client_->SetExpectedVersion(version); server_->SetExpectedVersion(version); } void TlsConnectTestBase::DisableDheCiphers() { client_->DisableCiphersByKeyExchange(ssl_kea_dh); server_->DisableCiphersByKeyExchange(ssl_kea_dh); } void TlsConnectTestBase::DisableEcdheCiphers() { client_->DisableCiphersByKeyExchange(ssl_kea_ecdh); server_->DisableCiphersByKeyExchange(ssl_kea_ecdh); } void TlsConnectTestBase::DisableDheAndEcdheCiphers() { DisableDheCiphers(); DisableEcdheCiphers(); } void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, SessionResumptionMode server) { client_->ConfigureSessionCache(client); server_->ConfigureSessionCache(server); } void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { EXPECT_NE(RESUME_BOTH, expected); int resume_ct = expected ? 1 : 0; int stateless_ct = (expected & RESUME_TICKET) ? 1 : 0; SSL3Statistics* stats = SSL_GetStatistics(); EXPECT_EQ(resume_ct, stats->hch_sid_cache_hits); EXPECT_EQ(resume_ct, stats->hsh_sid_cache_hits); EXPECT_EQ(stateless_ct, stats->hch_sid_stateless_resumes); EXPECT_EQ(stateless_ct, stats->hsh_sid_stateless_resumes); if (resume_ct) { // Check that the last two session ids match. EXPECT_GE(2U, session_ids_.size()); EXPECT_EQ(session_ids_[session_ids_.size()-1], session_ids_[session_ids_.size()-2]); } } void TlsConnectTestBase::EnableAlpn() { // A simple value of "a", "b". Note that the preferred value of "a" is placed // at the end, because the NSS API follows the now defunct NPN specification, // which places the preferred (and default) entry at the end of the list. // NSS will move this final entry to the front when used with ALPN. static const uint8_t val[] = { 0x01, 0x62, 0x01, 0x61 }; client_->EnableAlpn(val, sizeof(val)); server_->EnableAlpn(val, sizeof(val)); } void TlsConnectTestBase::EnableSrtp() { client_->EnableSrtp(); server_->EnableSrtp(); } void TlsConnectTestBase::CheckSrtp() const { client_->CheckSrtp(); server_->CheckSrtp(); } void TlsConnectTestBase::SendReceive() { client_->SendData(50); server_->SendData(50); WAIT_(client_->received_bytes() == 50U && server_->received_bytes() == 50U, 2000); ASSERT_EQ(50U, client_->received_bytes()); ASSERT_EQ(50U, server_->received_bytes()); } void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { expect_extended_master_secret_ = expected; } void TlsConnectTestBase::CheckExtendedMasterSecret() { client_->CheckExtendedMasterSecret(expect_extended_master_secret_); server_->CheckExtendedMasterSecret(expect_extended_master_secret_); } TlsConnectGeneric::TlsConnectGeneric() : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())), std::get<1>(GetParam())) {} TlsConnectPre12::TlsConnectPre12() : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())), std::get<1>(GetParam())) {} TlsConnectTls12::TlsConnectTls12() : TlsConnectTestBase(TlsConnectTestBase::ToMode(GetParam()), SSL_LIBRARY_VERSION_TLS_1_2) {} } // namespace nss_test