diff options
Diffstat (limited to 'nss/external_tests/ssl_gtest')
22 files changed, 0 insertions, 4660 deletions
diff --git a/nss/external_tests/ssl_gtest/Makefile b/nss/external_tests/ssl_gtest/Makefile deleted file mode 100644 index e3bf89d..0000000 --- a/nss/external_tests/ssl_gtest/Makefile +++ /dev/null @@ -1,60 +0,0 @@ -#! gmake -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. - -####################################################################### -# (1) Include initial platform-independent assignments (MANDATORY). # -####################################################################### - -include manifest.mn - -####################################################################### -# (2) Include "global" configuration information. (OPTIONAL) # -####################################################################### - -include $(CORE_DEPTH)/coreconf/config.mk - -####################################################################### -# (3) Include "component" configuration information. (OPTIONAL) # -####################################################################### - - -####################################################################### -# (4) Include "local" platform-dependent assignments (OPTIONAL). # -####################################################################### - -include ../../cmd/platlibs.mk - -####################################################################### -# (5) Execute "global" rules. (OPTIONAL) # -####################################################################### - -include $(CORE_DEPTH)/coreconf/rules.mk - -####################################################################### -# (6) Execute "component" rules. (OPTIONAL) # -####################################################################### - - -####################################################################### -# (7) Execute "local" rules. (OPTIONAL). # -####################################################################### - -MKPROG = $(CCC) -CFLAGS += -I$(CORE_DEPTH)/lib/ssl - -include ../../cmd/platrules.mk - -ifeq (WINNT,$(OS_ARCH)) - # -EHsc because gtest has exception handlers - OS_CFLAGS += -EHsc -nologo - # http://www.suodenjoki.dk/us/archive/2010/min-max.htm - OS_CFLAGS += -DNOMINMAX - - # Linking to winsock to get htonl - OS_LIBS += Ws2_32.lib -else - CXXFLAGS += -std=c++0x -endif diff --git a/nss/external_tests/ssl_gtest/databuffer.h b/nss/external_tests/ssl_gtest/databuffer.h deleted file mode 100644 index 832b8c3..0000000 --- a/nss/external_tests/ssl_gtest/databuffer.h +++ /dev/null @@ -1,171 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef databuffer_h__ -#define databuffer_h__ - -#include <algorithm> -#include <cassert> -#include <cstring> -#include <iomanip> -#include <iostream> -#if defined(WIN32) || defined(WIN64) -#include <winsock2.h> -#else -#include <arpa/inet.h> -#endif - -namespace nss_test { - -class DataBuffer { - public: - DataBuffer() : data_(nullptr), len_(0) {} - DataBuffer(const uint8_t *data, size_t len) : data_(nullptr), len_(0) { - Assign(data, len); - } - explicit DataBuffer(const DataBuffer& other) : data_(nullptr), len_(0) { - Assign(other); - } - ~DataBuffer() { delete[] data_; } - - DataBuffer& operator=(const DataBuffer& other) { - if (&other != this) { - Assign(other); - } - return *this; - } - - void Allocate(size_t len) { - delete[] data_; - data_ = new uint8_t[len ? len : 1]; // Don't depend on new [0]. - len_ = len; - } - - void Truncate(size_t len) { - len_ = std::min(len_, len); - } - - void Assign(const DataBuffer& other) { - Assign(other.data(), other.len()); - } - void Assign(const uint8_t* data, size_t len) { - Allocate(len); - memcpy(static_cast<void *>(data_), static_cast<const void *>(data), len); - } - - // Write will do a new allocation and expand the size of the buffer if needed. - void Write(size_t index, const uint8_t* val, size_t count) { - if (index + count > len_) { - size_t newlen = index + count; - uint8_t* tmp = new uint8_t[newlen]; // Always > 0. - memcpy(static_cast<void*>(tmp), - static_cast<const void*>(data_), len_); - if (index > len_) { - memset(static_cast<void*>(tmp + len_), 0, index - len_); - } - delete[] data_; - data_ = tmp; - len_ = newlen; - } - memcpy(static_cast<void*>(data_ + index), - static_cast<const void*>(val), count); - } - - void Write(size_t index, const DataBuffer& buf) { - Write(index, buf.data(), buf.len()); - } - - // Write an integer, also performing host-to-network order conversion. - void Write(size_t index, uint32_t val, size_t count) { - assert(count <= sizeof(uint32_t)); - uint32_t nvalue = htonl(val); - auto* addr = reinterpret_cast<const uint8_t*>(&nvalue); - Write(index, addr + sizeof(uint32_t) - count, count); - } - - // This can't use the same trick as Write(), since we might be reading from a - // smaller data source. - bool Read(size_t index, size_t count, uint32_t* val) const { - assert(count < sizeof(uint32_t)); - assert(val); - if ((index > len()) || (count > (len() - index))) { - return false; - } - *val = 0; - for (size_t i = 0; i < count; ++i) { - *val = (*val << 8) | data()[index + i]; - } - return true; - } - - // Starting at |index|, remove |remove| bytes and replace them with the - // contents of |buf|. - void Splice(const DataBuffer& buf, size_t index, size_t remove = 0) { - Splice(buf.data(), buf.len(), index, remove); - } - - void Splice(const uint8_t* ins, size_t ins_len, size_t index, size_t remove = 0) { - uint8_t* old_value = data_; - size_t old_len = len_; - - // The amount of stuff remaining from the tail of the old. - size_t tail_len = old_len - std::min(old_len, index + remove); - // The new length: the head of the old, the new, and the tail of the old. - len_ = index + ins_len + tail_len; - data_ = new uint8_t[len_ ? len_ : 1]; - - // The head of the old. - Write(0, old_value, std::min(old_len, index)); - // Maybe a gap. - if (index > old_len) { - memset(old_value + index, 0, index - old_len); - } - // The new. - Write(index, ins, ins_len); - // The tail of the old. - if (tail_len > 0) { - Write(index + ins_len, - old_value + index + remove, tail_len); - } - - delete[] old_value; - } - - void Append(const DataBuffer& buf) { Splice(buf, len_); } - - const uint8_t *data() const { return data_; } - uint8_t* data() { return data_; } - size_t len() const { return len_; } - bool empty() const { return len_ == 0; } - - private: - uint8_t* data_; - size_t len_; -}; - -#ifdef DEBUG -static const size_t kMaxBufferPrint = 10000; -#else -static const size_t kMaxBufferPrint = 32; -#endif - -inline std::ostream& operator<<(std::ostream& stream, const DataBuffer& buf) { - stream << "[" << buf.len() << "] "; - for (size_t i = 0; i < buf.len(); ++i) { - if (i >= kMaxBufferPrint) { - stream << "..."; - break; - } - stream << std::hex << std::setfill('0') << std::setw(2) - << static_cast<unsigned>(buf.data()[i]); - } - stream << std::dec; - return stream; -} - -} // namespace nss_test - -#endif diff --git a/nss/external_tests/ssl_gtest/gtest_utils.h b/nss/external_tests/ssl_gtest/gtest_utils.h deleted file mode 100644 index 019ccfd..0000000 --- a/nss/external_tests/ssl_gtest/gtest_utils.h +++ /dev/null @@ -1,58 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef gtest_utils_h__ -#define gtest_utils_h__ - -#define GTEST_HAS_RTTI 0 -#include "gtest/gtest.h" -#include "test_io.h" - -namespace nss_test { - -// Gtest utilities -class Timeout : public PollTarget { - public: - Timeout(int32_t timer_ms) : handle_(nullptr), timed_out_(false) { - Poller::Instance()->SetTimer(timer_ms, this, &Timeout::ExpiredCallback, - &handle_); - } - ~Timeout() { - Cancel(); - } - - static void ExpiredCallback(PollTarget* target, Event event) { - Timeout* timeout = static_cast<Timeout*>(target); - timeout->timed_out_ = true; - } - - void Cancel() { handle_->Cancel(); } - - bool timed_out() const { return timed_out_; } - - private: - Poller::Timer* handle_; - bool timed_out_; -}; - -} // namespace nss_test - -#define WAIT_(expression, timeout) \ - do { \ - Timeout tm(timeout); \ - while (!(expression)) { \ - Poller::Instance()->Poll(); \ - if (tm.timed_out()) break; \ - } \ - } while (0) - -#define ASSERT_TRUE_WAIT(expression, timeout) \ - do { \ - WAIT_(expression, timeout); \ - ASSERT_TRUE(expression); \ - } while (0) - -#endif diff --git a/nss/external_tests/ssl_gtest/libssl_internals.c b/nss/external_tests/ssl_gtest/libssl_internals.c deleted file mode 100644 index db83ef6..0000000 --- a/nss/external_tests/ssl_gtest/libssl_internals.c +++ /dev/null @@ -1,26 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -/* This file contains functions for frobbing the internals of libssl */ -#include "libssl_internals.h" - -#include "seccomon.h" -#include "ssl.h" -#include "sslimpl.h" - -SECStatus -SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd) -{ - sslSocket *ss = (sslSocket *)fd->secret; - - if (!ss) { - return SECFailure; - } - - ++ss->clientHelloVersion; - - return SECSuccess; -} diff --git a/nss/external_tests/ssl_gtest/libssl_internals.h b/nss/external_tests/ssl_gtest/libssl_internals.h deleted file mode 100644 index db6d0af..0000000 --- a/nss/external_tests/ssl_gtest/libssl_internals.h +++ /dev/null @@ -1,17 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef libssl_internals_h_ -#define libssl_internals_h_ - -#include "prio.h" -#include "seccomon.h" - -SECStatus SSLInt_IncrementClientHandshakeVersion(PRFileDesc *fd); - -#endif - - diff --git a/nss/external_tests/ssl_gtest/manifest.mn b/nss/external_tests/ssl_gtest/manifest.mn deleted file mode 100644 index 6d70c0b..0000000 --- a/nss/external_tests/ssl_gtest/manifest.mn +++ /dev/null @@ -1,33 +0,0 @@ -# -# This Source Code Form is subject to the terms of the Mozilla Public -# License, v. 2.0. If a copy of the MPL was not distributed with this -# file, You can obtain one at http://mozilla.org/MPL/2.0/. -CORE_DEPTH = ../.. -DEPTH = ../.. -MODULE = nss - -# These sources have access to libssl internals -CSRCS = \ - libssl_internals.c \ - $(NULL) - -CPPSRCS = \ - ssl_agent_unittest.cc \ - ssl_loopback_unittest.cc \ - ssl_extension_unittest.cc \ - ssl_prf_unittest.cc \ - ssl_skip_unittest.cc \ - ssl_gtest.cc \ - test_io.cc \ - tls_agent.cc \ - tls_connect.cc \ - tls_filter.cc \ - tls_parser.cc \ - $(NULL) - -INCLUDES += -I$(CORE_DEPTH)/external_tests/google_test/gtest/include - -REQUIRES = nspr nss libdbm gtest - -PROGRAM = ssl_gtest -EXTRA_LIBS = $(DIST)/lib/$(LIB_PREFIX)gtest.$(LIB_SUFFIX) diff --git a/nss/external_tests/ssl_gtest/ssl_agent_unittest.cc b/nss/external_tests/ssl_gtest/ssl_agent_unittest.cc deleted file mode 100644 index d67bf56..0000000 --- a/nss/external_tests/ssl_gtest/ssl_agent_unittest.cc +++ /dev/null @@ -1,58 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "ssl.h" -#include "sslerr.h" -#include "sslproto.h" - -#include <memory> - -#include "databuffer.h" -#include "tls_agent.h" -#include "tls_connect.h" -#include "tls_parser.h" - -namespace nss_test { - -void MakeTrivialHandshakeMessage(uint8_t hs_type, size_t hs_len, - DataBuffer* out) { - size_t total_len = 5 + 4 + hs_len; - - out->Allocate(total_len); - - size_t index = 0; - out->Write(index, kTlsHandshakeType, 1); ++index; // Content Type - out->Write(index, 3, 1); ++index; // Version high - out->Write(index, 1, 1); ++index; // Version low - out->Write(index, 4 + hs_len, 2); index += 2; // Length - - out->Write(index, hs_type, 1); ++index; // Handshake record type. - out->Write(index, hs_len, 3); index += 3; // Handshake length - for (; index < total_len; ++index) { - out->Write(index, 1, 1); - } -} - -TEST_P(TlsAgentTest, EarlyFinished) { - DataBuffer buffer; - MakeTrivialHandshakeMessage(kTlsHandshakeFinished, 0, &buffer); - ProcessMessage(buffer, TlsAgent::STATE_ERROR, - SSL_ERROR_RX_UNEXPECTED_FINISHED); -} - -TEST_P(TlsAgentTest, EarlyCertificateVerify) { - DataBuffer buffer; - MakeTrivialHandshakeMessage(kTlsHandshakeCertificateVerify, 0, &buffer); - ProcessMessage(buffer, TlsAgent::STATE_ERROR, - SSL_ERROR_RX_UNEXPECTED_CERT_VERIFY); -} - -INSTANTIATE_TEST_CASE_P(AgentTests, TlsAgentTest, - ::testing::Combine( - TlsAgentTestBase::kTlsRolesAll, - TlsConnectTestBase::kTlsModesStream)); - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/ssl_extension_unittest.cc b/nss/external_tests/ssl_gtest/ssl_extension_unittest.cc deleted file mode 100644 index b8e0adf..0000000 --- a/nss/external_tests/ssl_gtest/ssl_extension_unittest.cc +++ /dev/null @@ -1,625 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "ssl.h" -#include "sslproto.h" - -#include <memory> - -#include "tls_parser.h" -#include "tls_filter.h" -#include "tls_connect.h" - -namespace nss_test { - -class TlsExtensionFilter : public TlsHandshakeFilter { - protected: - virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output) { - if (handshake_type == kTlsHandshakeClientHello) { - TlsParser parser(input); - if (!FindClientHelloExtensions(parser, version)) { - return false; - } - return FilterExtensions(parser, input, output); - } - if (handshake_type == kTlsHandshakeServerHello) { - TlsParser parser(input); - if (!FindServerHelloExtensions(parser, version)) { - return false; - } - return FilterExtensions(parser, input, output); - } - return false; - } - - virtual bool FilterExtension(uint16_t extension_type, - const DataBuffer& input, DataBuffer* output) = 0; - - public: - static bool FindClientHelloExtensions(TlsParser& parser, uint16_t version) { - if (!parser.Skip(2 + 32)) { // version + random - return false; - } - if (!parser.SkipVariable(1)) { // session ID - return false; - } - if (IsDtls(version) && !parser.SkipVariable(1)) { // DTLS cookie - return false; - } - if (!parser.SkipVariable(2)) { // cipher suites - return false; - } - if (!parser.SkipVariable(1)) { // compression methods - return false; - } - return true; - } - - static bool FindServerHelloExtensions(TlsParser& parser, uint16_t version) { - if (!parser.Skip(2 + 32)) { // version + random - return false; - } - if (!parser.SkipVariable(1)) { // session ID - return false; - } - if (!parser.Skip(2)) { // cipher suite - return false; - } - if (NormalizeTlsVersion(version) <= SSL_LIBRARY_VERSION_TLS_1_2) { - if (!parser.Skip(1)) { // compression method - return false; - } - } - return true; - } - - private: - bool FilterExtensions(TlsParser& parser, - const DataBuffer& input, DataBuffer* output) { - size_t length_offset = parser.consumed(); - uint32_t all_extensions; - if (!parser.Read(&all_extensions, 2)) { - return false; // no extensions, odd but OK - } - if (all_extensions != parser.remaining()) { - return false; // malformed - } - - bool changed = false; - - // Write out the start of the message. - output->Allocate(input.len()); - output->Write(0, input.data(), parser.consumed()); - size_t output_offset = parser.consumed(); - - while (parser.remaining()) { - uint32_t extension_type; - if (!parser.Read(&extension_type, 2)) { - return false; // malformed - } - - // Copy extension type. - output->Write(output_offset, extension_type, 2); - - DataBuffer extension; - if (!parser.ReadVariable(&extension, 2)) { - return false; // malformed - } - output_offset = ApplyFilter(static_cast<uint16_t>(extension_type), extension, - output, output_offset + 2, &changed); - } - output->Truncate(output_offset); - - if (changed) { - size_t newlen = output->len() - length_offset - 2; - if (newlen >= 0x10000) { - return false; // bad: size increased too much - } - output->Write(length_offset, newlen, 2); - } - return changed; - } - - size_t ApplyFilter(uint16_t extension_type, const DataBuffer& extension, - DataBuffer* output, size_t offset, bool* changed) { - const DataBuffer* source = &extension; - DataBuffer filtered; - if (FilterExtension(extension_type, extension, &filtered) && - filtered.len() < 0x10000) { - *changed = true; - std::cerr << "extension old: " << extension << std::endl; - std::cerr << "extension new: " << filtered << std::endl; - source = &filtered; - } - - output->Write(offset, source->len(), 2); - output->Write(offset + 2, *source); - return offset + 2 + source->len(); - } -}; - -class TlsExtensionTruncator : public TlsExtensionFilter { - public: - TlsExtensionTruncator(uint16_t extension, size_t length) - : extension_(extension), length_(length) {} - virtual bool FilterExtension(uint16_t extension_type, - const DataBuffer& input, DataBuffer* output) { - if (extension_type != extension_) { - return false; - } - if (input.len() <= length_) { - return false; - } - - output->Assign(input.data(), length_); - return true; - } - private: - uint16_t extension_; - size_t length_; -}; - -class TlsExtensionDamager : public TlsExtensionFilter { - public: - TlsExtensionDamager(uint16_t extension, size_t index) - : extension_(extension), index_(index) {} - virtual bool FilterExtension(uint16_t extension_type, - const DataBuffer& input, DataBuffer* output) { - if (extension_type != extension_) { - return false; - } - - *output = input; - output->data()[index_] += 73; // Increment selected for maximum damage - return true; - } - private: - uint16_t extension_; - size_t index_; -}; - -class TlsExtensionReplacer : public TlsExtensionFilter { - public: - TlsExtensionReplacer(uint16_t extension, const DataBuffer& data) - : extension_(extension), data_(data) {} - virtual bool FilterExtension(uint16_t extension_type, - const DataBuffer& input, DataBuffer* output) { - if (extension_type != extension_) { - return false; - } - - *output = data_; - return true; - } - private: - const uint16_t extension_; - const DataBuffer data_; -}; - -class TlsExtensionInjector : public TlsHandshakeFilter { - public: - TlsExtensionInjector(uint16_t ext, DataBuffer& data) - : extension_(ext), data_(data) {} - - virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output) { - size_t offset; - if (handshake_type == kTlsHandshakeClientHello) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindClientHelloExtensions(parser, version)) { - return false; - } - offset = parser.consumed(); - } else if (handshake_type == kTlsHandshakeServerHello) { - TlsParser parser(input); - if (!TlsExtensionFilter::FindServerHelloExtensions(parser, version)) { - return false; - } - offset = parser.consumed(); - } else { - return false; - } - - *output = input; - - std::cerr << "Pre:" << input << std::endl; - std::cerr << "Lof:" << offset << std::endl; - - // Increase the size of the extensions. - uint16_t* len_addr = reinterpret_cast<uint16_t*>(output->data() + offset); - std::cerr << "L-p:" << ntohs(*len_addr) << std::endl; - *len_addr = htons(ntohs(*len_addr) + data_.len() + 4); - std::cerr << "L-i:" << ntohs(*len_addr) << std::endl; - - - // Insert the extension type and length. - DataBuffer type_length; - type_length.Allocate(4); - type_length.Write(0, extension_, 2); - type_length.Write(2, data_.len(), 2); - output->Splice(type_length, offset + 2); - - // Insert the payload. - output->Splice(data_, offset + 6); - - std::cerr << "Aft:" << *output << std::endl; - return true; - } - - private: - const uint16_t extension_; - const DataBuffer data_; -}; - -class TlsExtensionCapture : public TlsExtensionFilter { - public: - TlsExtensionCapture(uint16_t ext) - : extension_(ext), data_() {} - - virtual bool FilterExtension(uint16_t extension_type, - const DataBuffer& input, DataBuffer* output) { - if (extension_type == extension_) { - data_.Assign(input); - } - return false; - } - - const DataBuffer& extension() const { return data_; } - - private: - const uint16_t extension_; - DataBuffer data_; -}; - -class TlsExtensionTestBase : public TlsConnectTestBase { - protected: - TlsExtensionTestBase(Mode mode, uint16_t version) - : TlsConnectTestBase(mode, version) {} - - void ClientHelloErrorTest(PacketFilter* filter, - uint8_t alert = kTlsAlertDecodeError) { - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - if (filter) { - client_->SetPacketFilter(filter); - } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); - } - - void ServerHelloErrorTest(PacketFilter* filter, - uint8_t alert = kTlsAlertDecodeError) { - auto alert_recorder = new TlsAlertRecorder(); - client_->SetPacketFilter(alert_recorder); - if (filter) { - server_->SetPacketFilter(filter); - } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); - } - - static void InitSimpleSni(DataBuffer* extension) { - const char* name = "host.name"; - const size_t namelen = PL_strlen(name); - extension->Allocate(namelen + 5); - extension->Write(0, namelen + 3, 2); - extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname - extension->Write(3, namelen, 2); - extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen); - } -}; - -class TlsExtensionTestDtls - : public TlsExtensionTestBase, - public ::testing::WithParamInterface<uint16_t> { - public: - TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {} -}; - -class TlsExtensionTest12Plus - : public TlsExtensionTestBase, - public ::testing::WithParamInterface<std::string> { - public: - TlsExtensionTest12Plus() - : TlsExtensionTestBase(TlsConnectTestBase::ToMode(GetParam()), - SSL_LIBRARY_VERSION_TLS_1_2) {} -}; - -class TlsExtensionTestGeneric - : public TlsExtensionTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { - public: - TlsExtensionTestGeneric() - : TlsExtensionTestBase(TlsConnectTestBase::ToMode((std::get<0>(GetParam()))), - std::get<1>(GetParam())) {} -}; - -TEST_P(TlsExtensionTestGeneric, DamageSniLength) { - ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1)); -} - -TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) { - ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4)); -} - -TEST_P(TlsExtensionTestGeneric, TruncateSni) { - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7)); -} - -// A valid extension that appears twice will be reported as unsupported. -TEST_P(TlsExtensionTestGeneric, RepeatSni) { - DataBuffer extension; - InitSimpleSni(&extension); - ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension), - kTlsAlertIllegalParameter); -} - -// An SNI entry with zero length is considered invalid (strangely, not if it is -// the last entry, which is probably a bug). -TEST_P(TlsExtensionTestGeneric, BadSni) { - DataBuffer simple; - InitSimpleSni(&simple); - DataBuffer extension; - extension.Allocate(simple.len() + 3); - extension.Write(0, static_cast<uint32_t>(0), 3); - extension.Write(3, simple); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_server_name_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, EmptySni) { - DataBuffer extension; - extension.Allocate(2); - extension.Write(0, static_cast<uint32_t>(0), 2); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_server_name_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) { - EnableAlpn(); - DataBuffer extension; - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), - kTlsAlertIllegalParameter); -} - -// An empty ALPN isn't considered bad, though it does lead to there being no -// protocol for the server to select. -TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension), - kTlsAlertNoApplicationProtocol); -} - -TEST_P(TlsExtensionTestGeneric, OneByteAlpn) { - EnableAlpn(); - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) { - EnableAlpn(); - // This will leave the length of the second entry, but no value. - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) { - EnableAlpn(); - const uint8_t val[] = { 0x01, 0x61, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnMismatch) { - const uint8_t client_alpn[] = { 0x01, 0x61 }; - client_->EnableAlpn(client_alpn, sizeof(client_alpn)); - const uint8_t server_alpn[] = { 0x02, 0x61, 0x62 }; - server_->EnableAlpn(server_alpn, sizeof(server_alpn)); - - ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol); -} - -TEST_P(TlsExtensionTestGeneric, AlpnReturnedEmptyList) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnReturnedEmptyName) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x01, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnReturnedListTrailingData) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x02, 0x01, 0x61, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnReturnedExtraEntry) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x04, 0x01, 0x61, 0x01, 0x62 }; - DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnReturnedBadListLength) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x99, 0x01, 0x61, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestGeneric, AlpnReturnedBadNameLength) { - EnableAlpn(); - const uint8_t val[] = { 0x00, 0x02, 0x99, 0x61 }; - DataBuffer extension(val, sizeof(val)); - ServerHelloErrorTest(new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension)); -} - -TEST_P(TlsExtensionTestDtls, SrtpShort) { - EnableSrtp(); - ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3)); -} - -TEST_P(TlsExtensionTestDtls, SrtpOdd) { - EnableSrtp(); - const uint8_t val[] = { 0x00, 0x01, 0xff, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension)); -} - -TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) { - const uint8_t val[] = { 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, - extension)); -} - -TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) { - const uint8_t val[] = { 0x00, 0x02, 0x04, 0x01, 0x00 }; // sha-256, rsa - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, - extension)); -} - -TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) { - const uint8_t val[] = { 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, - extension)); -} - -TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) { - const uint8_t val[] = { 0x00, 0x01, 0x04 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, - extension)); -} - -// The extension handling ignores unsupported hashes, so breaking that has no -// effect on success rates. However, ssl3_SendServerKeyExchange catches an -// unsupported signature algorithm. - -// This actually fails with a decryption error (fatal alert 51). That's a bad -// to fail, since any tampering with the handshake will trigger that alert when -// verifying the Finished message. Thus, this test is disabled until this error -// is turned into an alert. -TEST_P(TlsExtensionTest12Plus, DISABLED_SignatureAlgorithmsSigUnsupported) { - const uint8_t val[] = { 0x00, 0x02, 0x04, 0x99 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_signature_algorithms_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) { - const uint8_t val[] = { 0x00, 0x01, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) { - const uint8_t val[] = { 0x09, 0x99, 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) { - const uint8_t val[] = { 0x00, 0x02, 0x00, 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_elliptic_curves_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, SupportedPointsEmpty) { - const uint8_t val[] = { 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, SupportedPointsBadLength) { - const uint8_t val[] = { 0x99, 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, SupportedPointsTrailingData) { - const uint8_t val[] = { 0x01, 0x00, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_ec_point_formats_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, RenegotiationInfoBadLength) { - const uint8_t val[] = { 0x99 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn, - extension)); -} - -TEST_P(TlsExtensionTestGeneric, RenegotiationInfoMismatch) { - const uint8_t val[] = { 0x01, 0x00 }; - DataBuffer extension(val, sizeof(val)); - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn, - extension)); -} - -// The extension has to contain a length. -TEST_P(TlsExtensionTestGeneric, RenegotiationInfoExtensionEmpty) { - DataBuffer extension; - ClientHelloErrorTest(new TlsExtensionReplacer(ssl_renegotiation_info_xtn, - extension)); -} - -TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmConfiguration) { - const SSLSignatureAndHashAlg algorithms[] = { - {ssl_hash_sha512, ssl_sign_rsa}, - {ssl_hash_sha384, ssl_sign_ecdsa} - }; - - TlsExtensionCapture *capture = - new TlsExtensionCapture(ssl_signature_algorithms_xtn); - client_->SetSignatureAlgorithms(algorithms, PR_ARRAY_SIZE(algorithms)); - client_->SetPacketFilter(capture); - DisableDheAndEcdheCiphers(); - Connect(); - - const DataBuffer& ext = capture->extension(); - EXPECT_EQ(2 + PR_ARRAY_SIZE(algorithms) * 2, ext.len()); - for (size_t i = 0, cursor = 2; - i < PR_ARRAY_SIZE(algorithms) && cursor < ext.len(); - ++i) { - uint32_t v; - EXPECT_TRUE(ext.Read(cursor++, 1, &v)); - EXPECT_EQ(algorithms[i].hashAlg, static_cast<SSLHashType>(v)); - EXPECT_TRUE(ext.Read(cursor++, 1, &v)); - EXPECT_EQ(algorithms[i].sigAlg, static_cast<SSLSignType>(v)); - } -} - -INSTANTIATE_TEST_CASE_P(ExtensionTls10, TlsExtensionTestGeneric, - ::testing::Combine( - TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); -INSTANTIATE_TEST_CASE_P(ExtensionVariants, TlsExtensionTestGeneric, - ::testing::Combine( - TlsConnectTestBase::kTlsModesAll, - TlsConnectTestBase::kTlsV11V12)); -INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus, - TlsConnectTestBase::kTlsModesAll); -INSTANTIATE_TEST_CASE_P(ExtensionDgram, TlsExtensionTestDtls, - TlsConnectTestBase::kTlsV11V12); - -} // namespace nspr_test diff --git a/nss/external_tests/ssl_gtest/ssl_gtest.cc b/nss/external_tests/ssl_gtest/ssl_gtest.cc deleted file mode 100644 index ee1c40c..0000000 --- a/nss/external_tests/ssl_gtest/ssl_gtest.cc +++ /dev/null @@ -1,39 +0,0 @@ -#include "nspr.h" -#include "nss.h" -#include "ssl.h" - -#include <cstdlib> - -#include "test_io.h" - -#define GTEST_HAS_RTTI 0 -#include "gtest/gtest.h" - -std::string g_working_dir_path; - -int main(int argc, char **argv) { - // Start the tests - ::testing::InitGoogleTest(&argc, argv); - g_working_dir_path = "."; - - char* workdir = getenv("NSS_GTEST_WORKDIR"); - if (workdir) - g_working_dir_path = workdir; - - for (int i = 0; i < argc; i++) { - if (!strcmp(argv[i], "-d")) { - g_working_dir_path = argv[i + 1]; - ++i; - } - } - - NSS_Initialize(g_working_dir_path.c_str(), "", "", SECMOD_DB, NSS_INIT_READONLY); - NSS_SetDomesticPolicy(); - int rv = RUN_ALL_TESTS(); - - NSS_Shutdown(); - - nss_test::Poller::Shutdown(); - - return rv; -} diff --git a/nss/external_tests/ssl_gtest/ssl_loopback_unittest.cc b/nss/external_tests/ssl_gtest/ssl_loopback_unittest.cc deleted file mode 100644 index 13b50a7..0000000 --- a/nss/external_tests/ssl_gtest/ssl_loopback_unittest.cc +++ /dev/null @@ -1,669 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "ssl.h" -#include "sslerr.h" -#include "sslproto.h" -#include <memory> - -extern "C" { -// This is not something that should make you happy. -#include "libssl_internals.h" -} - -#include "tls_parser.h" -#include "tls_filter.h" -#include "tls_connect.h" -#include "gtest_utils.h" - -namespace nss_test { - -uint8_t kBogusClientKeyExchange[] = { - 0x01, 0x00, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, -}; - -// When we see the ClientKeyExchange from |client|, increment the -// ClientHelloVersion on |server|. -class TlsInspectorClientHelloVersionChanger : public TlsHandshakeFilter { - public: - TlsInspectorClientHelloVersionChanger(TlsAgent* server) : server_(server) {} - - virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output) { - if (handshake_type == kTlsHandshakeClientKeyExchange) { - EXPECT_EQ( - SECSuccess, - SSLInt_IncrementClientHandshakeVersion(server_->ssl_fd())); - } - return false; - } - - private: - TlsAgent* server_; -}; - -class TlsServerKeyExchangeEcdhe { - public: - bool Parse(const DataBuffer& buffer) { - TlsParser parser(buffer); - - uint8_t curve_type; - if (!parser.Read(&curve_type)) { - return false; - } - - if (curve_type != 3) { // named_curve - return false; - } - - uint32_t named_curve; - if (!parser.Read(&named_curve, 2)) { - return false; - } - - return parser.ReadVariable(&public_key_, 1); - } - - DataBuffer public_key_; -}; - -TEST_P(TlsConnectGeneric, SetupOnly) {} - -TEST_P(TlsConnectGeneric, Connect) { - SetExpectedVersion(std::get<1>(GetParam())); - Connect(); - client_->CheckAuthType(ssl_auth_rsa); -} - -TEST_P(TlsConnectGeneric, ConnectEcdsa) { - SetExpectedVersion(std::get<1>(GetParam())); - ResetEcdsa(); - Connect(); - client_->CheckAuthType(ssl_auth_ecdsa); -} - -TEST_P(TlsConnectGeneric, ConnectFalseStart) { - client_->EnableFalseStart(); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectResumed) { - ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); - Connect(); - - ResetRsa(); - ExpectResumption(RESUME_SESSIONID); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectClientCacheDisabled) { - ConfigureSessionCache(RESUME_NONE, RESUME_SESSIONID); - Connect(); - ResetRsa(); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectServerCacheDisabled) { - ConfigureSessionCache(RESUME_SESSIONID, RESUME_NONE); - Connect(); - ResetRsa(); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectSessionCacheDisabled) { - ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - Connect(); - ResetRsa(); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectResumeSupportBoth) { - // This prefers tickets. - ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_BOTH, RESUME_BOTH); - ExpectResumption(RESUME_TICKET); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectResumeClientTicketServerBoth) { - // This causes no resumption because the client needs the - // session cache to resume even with tickets. - ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_TICKET, RESUME_BOTH); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectResumeClientBothTicketServerTicket) { - // This causes a ticket resumption. - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - ExpectResumption(RESUME_TICKET); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectClientServerTicketOnly) { - // This causes no resumption because the client needs the - // session cache to resume even with tickets. - ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_TICKET, RESUME_TICKET); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectClientBothServerNone) { - ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_BOTH, RESUME_NONE); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectClientNoneServerBoth) { - ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_NONE, RESUME_BOTH); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ResumeWithHigherVersion) { - EnsureTlsSetup(); - SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_1); - ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_1); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_1); - Connect(); - - ResetRsa(); - EnsureTlsSetup(); - SetExpectedVersion(SSL_LIBRARY_VERSION_TLS_1_2); - client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_2); - server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_2); - ExpectResumption(RESUME_NONE); - Connect(); -} - -TEST_P(TlsConnectGeneric, ClientAuth) { - client_->SetupClientAuth(); - server_->RequestClientAuth(true); - Connect(); - server_->CheckAuthType(ssl_auth_rsa); -} - -TEST_P(TlsConnectGeneric, ClientAuthEcdsa) { - ResetEcdsa(); - client_->SetupClientAuth(); - server_->RequestClientAuth(true); - Connect(); - server_->CheckAuthType(ssl_auth_ecdsa); -} - -static const SSLSignatureAndHashAlg SignatureEcdsaSha384[] = { - {ssl_hash_sha384, ssl_sign_ecdsa} -}; -static const SSLSignatureAndHashAlg SignatureEcdsaSha256[] = { - {ssl_hash_sha256, ssl_sign_ecdsa} -}; -static const SSLSignatureAndHashAlg SignatureRsaSha384[] = { - {ssl_hash_sha384, ssl_sign_rsa} -}; -static const SSLSignatureAndHashAlg SignatureRsaSha256[] = { - {ssl_hash_sha256, ssl_sign_rsa} -}; - -// When signature algorithms match up, this should connect successfully; even -// for TLS 1.1 and 1.0, where they should be ignored. -TEST_P(TlsConnectGeneric, SignatureAlgorithmServerAuth) { - client_->SetSignatureAlgorithms(SignatureEcdsaSha384, - PR_ARRAY_SIZE(SignatureEcdsaSha384)); - server_->SetSignatureAlgorithms(SignatureEcdsaSha384, - PR_ARRAY_SIZE(SignatureEcdsaSha384)); - ResetEcdsa(); - Connect(); -} - -// Here the client picks a single option, which should work in all versions. -// Defaults on the server include the first option. -TEST_P(TlsConnectGeneric, SignatureAlgorithmClientOnly) { - const SSLSignatureAndHashAlg clientAlgorithms[] = { - {ssl_hash_sha384, ssl_sign_ecdsa}, - {ssl_hash_sha384, ssl_sign_rsa}, // supported but unusable - {ssl_hash_md5, ssl_sign_ecdsa} // unsupported and ignored - }; - client_->SetSignatureAlgorithms(clientAlgorithms, - PR_ARRAY_SIZE(clientAlgorithms)); - ResetEcdsa(); - Connect(); -} - -// Here the server picks a single option, which should work in all versions. -// Defaults on the client include the provided option. -TEST_P(TlsConnectGeneric, SignatureAlgorithmServerOnly) { - server_->SetSignatureAlgorithms(SignatureEcdsaSha384, - PR_ARRAY_SIZE(SignatureEcdsaSha384)); - ResetEcdsa(); - Connect(); -} - -// There is no need for overlap on signatures; since we don't actually use the -// signatures for static RSA, this should still connect successfully. -// This should also work in TLS 1.0 and 1.1 where the algorithms aren't used. -TEST_P(TlsConnectGeneric, SignatureAlgorithmNoOverlapStaticRsa) { - client_->SetSignatureAlgorithms(SignatureRsaSha384, - PR_ARRAY_SIZE(SignatureRsaSha384)); - server_->SetSignatureAlgorithms(SignatureRsaSha256, - PR_ARRAY_SIZE(SignatureRsaSha256)); - DisableDheAndEcdheCiphers(); - Connect(); - client_->CheckKEAType(ssl_kea_rsa); - client_->CheckAuthType(ssl_auth_rsa); -} - -// Signature algorithms governs both verification and generation of signatures. -// With ECDSA, we need to at least have a common signature algorithm configured. -TEST_P(TlsConnectTls12, SignatureAlgorithmNoOverlapEcdsa) { - ResetEcdsa(); - client_->SetSignatureAlgorithms(SignatureEcdsaSha384, - PR_ARRAY_SIZE(SignatureEcdsaSha384)); - server_->SetSignatureAlgorithms(SignatureEcdsaSha256, - PR_ARRAY_SIZE(SignatureEcdsaSha256)); - ConnectExpectFail(); -} - -// Pre 1.2, a mismatch on signature algorithms shouldn't affect anything. -TEST_P(TlsConnectPre12, SignatureAlgorithmNoOverlapEcdsa) { - ResetEcdsa(); - client_->SetSignatureAlgorithms(SignatureEcdsaSha384, - PR_ARRAY_SIZE(SignatureEcdsaSha384)); - server_->SetSignatureAlgorithms(SignatureEcdsaSha256, - PR_ARRAY_SIZE(SignatureEcdsaSha256)); - Connect(); -} - -// The server requests client auth but doesn't offer a SHA-256 option. -// This fails because NSS only uses SHA-256 for handshake transcript hashes. -TEST_P(TlsConnectTls12, RequestClientAuthWithoutSha256) { - server_->SetSignatureAlgorithms(SignatureRsaSha384, - PR_ARRAY_SIZE(SignatureRsaSha384)); - server_->RequestClientAuth(false); - ConnectExpectFail(); -} - -TEST_P(TlsConnectGeneric, ConnectAlpn) { - EnableAlpn(); - Connect(); - client_->CheckAlpn(SSL_NEXT_PROTO_SELECTED, "a"); - server_->CheckAlpn(SSL_NEXT_PROTO_NEGOTIATED, "a"); -} - -TEST_P(TlsConnectDatagram, ConnectSrtp) { - EnableSrtp(); - Connect(); - CheckSrtp(); -} - -TEST_P(TlsConnectStream, ConnectAndClientRenegotiate) { - Connect(); - server_->PrepareForRenegotiate(); - client_->StartRenegotiate(); - Handshake(); - CheckConnected(); -} - -TEST_P(TlsConnectStream, ConnectAndServerRenegotiate) { - Connect(); - client_->PrepareForRenegotiate(); - server_->StartRenegotiate(); - Handshake(); - CheckConnected(); -} - -TEST_P(TlsConnectStream, ConnectStaticRSA) { - DisableDheAndEcdheCiphers(); - Connect(); - client_->CheckKEAType(ssl_kea_rsa); -} - -TEST_P(TlsConnectStream, ConnectDhe) { - DisableEcdheCiphers(); - Connect(); - client_->CheckKEAType(ssl_kea_dh); -} - -// Test that a totally bogus EPMS is handled correctly. -// This test is stream so we can catch the bad_record_mac alert. -TEST_P(TlsConnectStream, ConnectStaticRSABogusCKE) { - DisableDheAndEcdheCiphers(); - TlsInspectorReplaceHandshakeMessage* i1 = - new TlsInspectorReplaceHandshakeMessage(kTlsHandshakeClientKeyExchange, - DataBuffer( - kBogusClientKeyExchange, - sizeof(kBogusClientKeyExchange))); - client_->SetPacketFilter(i1); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); -} - -// Test that a PMS with a bogus version number is handled correctly. -// This test is stream so we can catch the bad_record_mac alert. -TEST_P(TlsConnectStream, ConnectStaticRSABogusPMSVersionDetect) { - DisableDheAndEcdheCiphers(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger( - server_)); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); -} - -// Test that a PMS with a bogus version number is ignored when -// rollback detection is disabled. This is a positive control for -// ConnectStaticRSABogusPMSVersionDetect. -TEST_P(TlsConnectGeneric, ConnectStaticRSABogusPMSVersionIgnore) { - DisableDheAndEcdheCiphers(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger( - server_)); - server_->DisableRollbackDetection(); - Connect(); -} - -TEST_P(TlsConnectStream, ConnectEcdhe) { - Connect(); - client_->CheckKEAType(ssl_kea_ecdh); -} - -TEST_P(TlsConnectStream, ConnectEcdheTwiceReuseKey) { - TlsInspectorRecordHandshakeMessage* i1 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i1); - Connect(); - client_->CheckKEAType(ssl_kea_ecdh); - TlsServerKeyExchangeEcdhe dhe1; - EXPECT_TRUE(dhe1.Parse(i1->buffer())); - - // Restart - ResetRsa(); - TlsInspectorRecordHandshakeMessage* i2 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i2); - ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - Connect(); - client_->CheckKEAType(ssl_kea_ecdh); - - TlsServerKeyExchangeEcdhe dhe2; - EXPECT_TRUE(dhe2.Parse(i2->buffer())); - - // Make sure they are the same. - EXPECT_EQ(dhe1.public_key_.len(), dhe2.public_key_.len()); - EXPECT_TRUE(!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), - dhe1.public_key_.len())); -} - -TEST_P(TlsConnectStream, ConnectEcdheTwiceNewKey) { - server_->EnsureTlsSetup(); - SECStatus rv = - SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); - TlsInspectorRecordHandshakeMessage* i1 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i1); - Connect(); - client_->CheckKEAType(ssl_kea_ecdh); - TlsServerKeyExchangeEcdhe dhe1; - EXPECT_TRUE(dhe1.Parse(i1->buffer())); - - // Restart - ResetRsa(); - server_->EnsureTlsSetup(); - rv = SSL_OptionSet(server_->ssl_fd(), SSL_REUSE_SERVER_ECDHE_KEY, PR_FALSE); - EXPECT_EQ(SECSuccess, rv); - TlsInspectorRecordHandshakeMessage* i2 = - new TlsInspectorRecordHandshakeMessage(kTlsHandshakeServerKeyExchange); - server_->SetPacketFilter(i2); - ConfigureSessionCache(RESUME_NONE, RESUME_NONE); - Connect(); - client_->CheckKEAType(ssl_kea_ecdh); - - TlsServerKeyExchangeEcdhe dhe2; - EXPECT_TRUE(dhe2.Parse(i2->buffer())); - - // Make sure they are different. - EXPECT_FALSE((dhe1.public_key_.len() == dhe2.public_key_.len()) && - (!memcmp(dhe1.public_key_.data(), dhe2.public_key_.data(), - dhe1.public_key_.len()))); -} - -TEST_P(TlsConnectGeneric, ConnectSendReceive) { - Connect(); - SendReceive(); -} - -// The next two tests takes advantage of the fact that we -// automatically read the first 1024 bytes, so if -// we provide 1200 bytes, they overrun the read buffer -// provided by the calling test. - -// DTLS should return an error. -TEST_P(TlsConnectDatagram, ShortRead) { - Connect(); - client_->SetExpectedReadError(true); - server_->SendData(1200, 1200); - WAIT_(client_->error_code() == SSL_ERROR_RX_SHORT_DTLS_READ, 2000); - // Don't call CheckErrorCode() because it requires us to being - // in state ERROR. - ASSERT_EQ(SSL_ERROR_RX_SHORT_DTLS_READ, client_->error_code()); - - // Now send and receive another packet. - client_->SetExpectedReadError(false); - server_->ResetSentBytes(); // Reset the counter. - SendReceive(); -} - -// TLS should get the write in two chunks. -TEST_P(TlsConnectStream, ShortRead) { - // This test behaves oddly with TLS 1.0 because of 1/n+1 splitting, - // so skip in that case. - if (version_ < SSL_LIBRARY_VERSION_TLS_1_1) - return; - - Connect(); - server_->SendData(1200, 1200); - // Read the first tranche. - WAIT_(client_->received_bytes() == 1024, 2000); - ASSERT_EQ(1024U, client_->received_bytes()); - // The second tranche should now immediately be available. - client_->ReadBytes(); - ASSERT_EQ(1200U, client_->received_bytes()); -} - -TEST_P(TlsConnectGeneric, ConnectExtendedMasterSecret) { - EnableExtendedMasterSecret(); - Connect(); - ResetRsa(); - ExpectResumption(RESUME_SESSIONID); - EnableExtendedMasterSecret(); - Connect(); -} - - -TEST_P(TlsConnectGeneric, ConnectExtendedMasterSecretStaticRSA) { - DisableDheAndEcdheCiphers(); - EnableExtendedMasterSecret(); - Connect(); -} - -// This test is stream so we can catch the bad_record_mac alert. -TEST_P(TlsConnectStream, ConnectExtendedMasterSecretStaticRSABogusCKE) { - DisableDheAndEcdheCiphers(); - EnableExtendedMasterSecret(); - TlsInspectorReplaceHandshakeMessage* inspect = - new TlsInspectorReplaceHandshakeMessage(kTlsHandshakeClientKeyExchange, - DataBuffer( - kBogusClientKeyExchange, - sizeof(kBogusClientKeyExchange))); - client_->SetPacketFilter(inspect); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); -} - -// This test is stream so we can catch the bad_record_mac alert. -TEST_P(TlsConnectStream, ConnectExtendedMasterSecretStaticRSABogusPMSVersionDetect) { - DisableDheAndEcdheCiphers(); - EnableExtendedMasterSecret(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger( - server_)); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertBadRecordMac, alert_recorder->description()); -} - -TEST_P(TlsConnectStream, ConnectExtendedMasterSecretStaticRSABogusPMSVersionIgnore) { - DisableDheAndEcdheCiphers(); - EnableExtendedMasterSecret(); - client_->SetPacketFilter(new TlsInspectorClientHelloVersionChanger( - server_)); - server_->DisableRollbackDetection(); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectExtendedMasterSecretECDHE) { - EnableExtendedMasterSecret(); - Connect(); - - ResetRsa(); - EnableExtendedMasterSecret(); - ExpectResumption(RESUME_SESSIONID); - Connect(); -} - -TEST_P(TlsConnectGeneric, ConnectExtendedMasterSecretTicket) { - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - EnableExtendedMasterSecret(); - Connect(); - - ResetRsa(); - ConfigureSessionCache(RESUME_BOTH, RESUME_TICKET); - - EnableExtendedMasterSecret(); - ExpectResumption(RESUME_TICKET); - Connect(); -} - -TEST_P(TlsConnectGeneric, - ConnectExtendedMasterSecretClientOnly) { - client_->EnableExtendedMasterSecret(); - ExpectExtendedMasterSecret(false); - Connect(); -} - -TEST_P(TlsConnectGeneric, - ConnectExtendedMasterSecretServerOnly) { - server_->EnableExtendedMasterSecret(); - ExpectExtendedMasterSecret(false); - Connect(); -} - -TEST_P(TlsConnectGeneric, - ConnectExtendedMasterSecretResumeWithout) { - EnableExtendedMasterSecret(); - Connect(); - - ResetRsa(); - server_->EnableExtendedMasterSecret(); - auto alert_recorder = new TlsAlertRecorder(); - server_->SetPacketFilter(alert_recorder); - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(kTlsAlertHandshakeFailure, alert_recorder->description()); -} - -TEST_P(TlsConnectGeneric, - ConnectNormalResumeWithExtendedMasterSecret) { - ConfigureSessionCache(RESUME_SESSIONID, RESUME_SESSIONID); - ExpectExtendedMasterSecret(false); - Connect(); - - ResetRsa(); - EnableExtendedMasterSecret(); - ExpectResumption(RESUME_NONE); - Connect(); -} - -INSTANTIATE_TEST_CASE_P(VariantsStream10, TlsConnectGeneric, - ::testing::Combine( - TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); -INSTANTIATE_TEST_CASE_P(VariantsAll, TlsConnectGeneric, - ::testing::Combine( - TlsConnectTestBase::kTlsModesAll, - TlsConnectTestBase::kTlsV11V12)); -INSTANTIATE_TEST_CASE_P(VersionsDatagram, TlsConnectDatagram, - TlsConnectTestBase::kTlsV11V12); -INSTANTIATE_TEST_CASE_P(Variants12, TlsConnectTls12, - TlsConnectTestBase::kTlsModesAll); -INSTANTIATE_TEST_CASE_P(Pre12Stream, TlsConnectPre12, - ::testing::Combine( - TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); -INSTANTIATE_TEST_CASE_P(Pre12All, TlsConnectPre12, - ::testing::Combine( - TlsConnectTestBase::kTlsModesAll, - TlsConnectTestBase::kTlsV11)); -INSTANTIATE_TEST_CASE_P(VersionsStream10, TlsConnectStream, - TlsConnectTestBase::kTlsV10); -INSTANTIATE_TEST_CASE_P(VersionsStream, TlsConnectStream, - TlsConnectTestBase::kTlsV11V12); - -} // namespace nspr_test diff --git a/nss/external_tests/ssl_gtest/ssl_prf_unittest.cc b/nss/external_tests/ssl_gtest/ssl_prf_unittest.cc deleted file mode 100644 index ea2478b..0000000 --- a/nss/external_tests/ssl_gtest/ssl_prf_unittest.cc +++ /dev/null @@ -1,253 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "nss.h" -#include "pk11pub.h" -#include <memory> - -#include "gtest_utils.h" - -namespace nss_test { - -#define CONST_UINT8_TO_UCHAR(a) const_cast<unsigned char*>( \ - static_cast<const unsigned char *>(a)) - -const size_t kPmsSize = 48; -const size_t kMasterSecretSize = 48; -const size_t kPrfSeedSizeSha256 = 32; -const size_t kPrfSeedSizeTlsPrf = 36; - -// This is not the right size for anything -const size_t kIncorrectSize = 17; - -const uint8_t kPmsData[] = { - 0x00,0x01,0x02,0x03,0x04,0x05,0x06,0x07, - 0x08,0x09,0x0a,0x0b,0x0c,0x0d,0x0e,0x0f, - 0x10,0x11,0x12,0x13,0x14,0x15,0x16,0x17, - 0x18,0x19,0x1a,0x1b,0x1c,0x1d,0x1e,0x1f, - 0x20,0x21,0x22,0x23,0x24,0x25,0x26,0x27, - 0x28,0x29,0x2a,0x2b,0x2c,0x2d,0x2e,0x2f -}; - -const uint8_t kPrfSeed[] = { - 0xf0,0xf1,0xf2,0xf3,0xf4,0xf5,0xf6,0xf7, - 0xf8,0xf9,0xfa,0xfb,0xfc,0xfd,0xfe,0xff, - 0xe0,0xe1,0xe2,0xe3,0xe4,0xe5,0xe6,0xe7, - 0xe8,0xe9,0xea,0xeb,0xec,0xed,0xee,0xef, - 0xd0,0xd1,0xd2,0xd3 -}; - -const uint8_t kExpectedOutputEmsSha256[] = { - 0x75,0xa7,0xa5,0x98,0xef,0xab,0x90,0xe7, - 0x7c,0x67,0x80,0xde,0xab,0x3a,0x11,0xf3, - 0x5d,0xb2,0xf8,0x47,0xff,0x09,0x01,0xec, - 0xf8,0x93,0x89,0xfc,0x98,0x2e,0x6e,0xf9, - 0x2c,0xf5,0x9b,0x04,0x04,0x6f,0xd7,0x28, - 0x6e,0xea,0xe3,0x83,0xc4,0x4a,0xff,0x03 -}; - -const uint8_t kExpectedOutputEmsTlsPrf[] = { - 0x06,0xbf,0x29,0x86,0x5d,0xf3,0x3e,0x38, - 0xfd,0xfa,0x91,0x10,0x2a,0x20,0xff,0xd6, - 0xb9,0xd5,0x72,0x5a,0x6d,0x42,0x20,0x16, - 0xde,0xa4,0xa0,0x51,0xe5,0x53,0xc1,0x28, - 0x04,0x99,0xbc,0xb1,0x2c,0x9d,0xe8,0x0b, - 0x18,0xa2,0x0e,0x48,0x52,0x8d,0x61,0x13 -}; - -static unsigned char* toUcharPtr(const uint8_t* v) { - return const_cast<unsigned char*>( - static_cast<const unsigned char *>(v)); -} - -class TlsPrfTest : public ::testing::Test { - public: - TlsPrfTest() - : params_({siBuffer, nullptr, 0}) - , pms_item_({siBuffer, toUcharPtr(kPmsData), kPmsSize}) - , key_mech_(0) - , slot_(nullptr) - , pms_(nullptr) - , ms_(nullptr) - , pms_version_({0, 0}) {} - - ~TlsPrfTest() { - if (slot_) { PK11_FreeSlot(slot_); } - ClearTempVars(); - } - - void ClearTempVars() { - if (pms_) { PK11_FreeSymKey(pms_); } - if (ms_) { PK11_FreeSymKey(ms_); } - } - - void Init() { - params_.type = siBuffer; - - pms_item_.type = siBuffer; - pms_item_.data = const_cast<unsigned char*>( - static_cast<const unsigned char *>(kPmsData)); - - slot_ = PK11_GetInternalSlot(); - ASSERT_NE(nullptr, slot_); - } - - void CheckForError(CK_MECHANISM_TYPE hash_mech, - size_t seed_len, - size_t pms_len, - size_t output_len) { - // Error tests don't depend on the derivation mechansim - Inner(CKM_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE, hash_mech, - seed_len, pms_len, output_len, nullptr, nullptr); - } - - void ComputeAndVerifyMs(CK_MECHANISM_TYPE derive_mech, - CK_MECHANISM_TYPE hash_mech, - CK_VERSION* version, - const uint8_t* expected) { - // Infer seed length from mechanism - int seed_len = 0; - switch (hash_mech) { - case CKM_TLS_PRF: seed_len = kPrfSeedSizeTlsPrf; break; - case CKM_SHA256: seed_len = kPrfSeedSizeSha256; break; - default: ASSERT_TRUE(false); - } - - Inner(derive_mech, hash_mech, seed_len, - kPmsSize, 0, version, expected); - } - - - // Set output == nullptr to test when errors occur - void Inner( - CK_MECHANISM_TYPE derive_mech, - CK_MECHANISM_TYPE hash_mech, - size_t seed_len, - size_t pms_len, - size_t output_len, - CK_VERSION* version, - const uint8_t* expected) { - ClearTempVars(); - - // Infer the key mechanism from the hash type - switch (hash_mech) { - case CKM_TLS_PRF: key_mech_ = CKM_TLS_KEY_AND_MAC_DERIVE; break; - case CKM_SHA256: key_mech_ = CKM_NSS_TLS_KEY_AND_MAC_DERIVE_SHA256; break; - default: ASSERT_TRUE(false); - } - - // Import the params - CK_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE_PARAMS master_params = { - hash_mech, - toUcharPtr(kPrfSeed), - seed_len, - version - }; - params_.data = reinterpret_cast<unsigned char*>(&master_params); - params_.len = sizeof(master_params); - - // Import the PMS - pms_item_.len = pms_len; - pms_ = PK11_ImportSymKey(slot_, derive_mech, PK11_OriginUnwrap, - CKA_DERIVE, &pms_item_, NULL); - ASSERT_NE(nullptr, pms_); - - - // Compute the EMS - ms_ = PK11_DeriveWithFlags(pms_, derive_mech, ¶ms_, key_mech_, - CKA_DERIVE, output_len, CKF_SIGN | CKF_VERIFY); - - // Verify the EMS has the expected value (null or otherwise) - if (!expected) { - EXPECT_EQ(nullptr, ms_); - } else { - ASSERT_NE(nullptr, ms_); - - SECStatus rv = PK11_ExtractKeyValue(ms_); - ASSERT_EQ(SECSuccess, rv); - - SECItem *msData = PK11_GetKeyData(ms_); - ASSERT_NE(nullptr, msData); - - ASSERT_EQ(kMasterSecretSize, msData->len); - EXPECT_EQ(0, - memcmp(msData->data, expected, kMasterSecretSize)); - } - } - - protected: - SECItem params_; - SECItem pms_item_; - CK_MECHANISM_TYPE key_mech_; - PK11SlotInfo *slot_; - PK11SymKey *pms_; - PK11SymKey *ms_; - CK_VERSION pms_version_; -}; - -TEST_F(TlsPrfTest, ExtendedMsParamErr) { - Init(); - - // This should fail; it's the correct set from which the below are derived - // CheckForError(CKM_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE, CKM_TLS_PRF, kPrfSeedSizeTlsPrf, kPmsSize, 0); - - // Output key size != 0, SSL3_MASTER_SECRET_LENGTH - CheckForError(CKM_TLS_PRF, kPrfSeedSizeTlsPrf, kPmsSize, kIncorrectSize); - - // not-DH && pms size != SSL3_PMS_LENGTH - CheckForError(CKM_TLS_PRF, kPrfSeedSizeTlsPrf, kIncorrectSize, 0); - - // CKM_TLS_PRF && seed length != MD5_LENGTH + SHA1_LENGTH - CheckForError(CKM_TLS_PRF, kIncorrectSize, kPmsSize, 0); - - // !CKM_TLS_PRF && seed length != hash output length - CheckForError(CKM_SHA256, kIncorrectSize, kPmsSize, 0); -} - -// Test matrix: -// -// DH RSA -// TLS_PRF 1 2 -// SHA256 3 4 -TEST_F(TlsPrfTest, ExtendedMsDhTlsPrf) { - Init(); - ComputeAndVerifyMs(CKM_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE_DH, - CKM_TLS_PRF, - nullptr, - kExpectedOutputEmsTlsPrf); -} - -TEST_F(TlsPrfTest, ExtendedMsRsaTlsPrf) { - Init(); - ComputeAndVerifyMs(CKM_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE, - CKM_TLS_PRF, - &pms_version_, - kExpectedOutputEmsTlsPrf); - EXPECT_EQ(0, pms_version_.major); - EXPECT_EQ(1, pms_version_.minor); -} - - -TEST_F(TlsPrfTest, ExtendedMsDhSha256) { - Init(); - ComputeAndVerifyMs(CKM_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE_DH, - CKM_SHA256, - nullptr, - kExpectedOutputEmsSha256); -} - -TEST_F(TlsPrfTest, ExtendedMsRsaSha256) { - Init(); - ComputeAndVerifyMs(CKM_NSS_TLS_EXTENDED_MASTER_KEY_DERIVE, - CKM_SHA256, - &pms_version_, - kExpectedOutputEmsSha256); - EXPECT_EQ(0, pms_version_.major); - EXPECT_EQ(1, pms_version_.minor); -} - -} // namespace nss_test - diff --git a/nss/external_tests/ssl_gtest/ssl_skip_unittest.cc b/nss/external_tests/ssl_gtest/ssl_skip_unittest.cc deleted file mode 100644 index 3a893cf..0000000 --- a/nss/external_tests/ssl_gtest/ssl_skip_unittest.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "sslerr.h" - -#include "tls_parser.h" -#include "tls_filter.h" -#include "tls_connect.h" - -/* - * The tests in this file test that the TLS state machine is robust against - * attacks that alter the order of handshake messages. - * - * See <https://www.smacktls.com/smack.pdf> for a description of the problems - * that this sort of attack can enable. - */ -namespace nss_test { - -class TlsHandshakeSkipFilter : public TlsRecordFilter { - public: - // A TLS record filter that skips handshake messages of the identified type. - TlsHandshakeSkipFilter(uint8_t handshake_type) - : handshake_type_(handshake_type), - skipped_(false) {} - - protected: - // Takes a record; if it is a handshake record, it removes the first handshake - // message that is of handshake_type_ type. - virtual bool FilterRecord(uint8_t content_type, uint16_t version, - const DataBuffer& input, DataBuffer* output) { - if (content_type != kTlsHandshakeType) { - return false; - } - - size_t output_offset = 0U; - output->Allocate(input.len()); - - TlsParser parser(input); - while (parser.remaining()) { - size_t start = parser.consumed(); - uint8_t handshake_type; - if (!parser.Read(&handshake_type)) { - return false; - } - uint32_t length; - if (!TlsHandshakeFilter::ReadLength(&parser, version, &length)) { - return false; - } - - if (!parser.Skip(length)) { - return false; - } - - if (skipped_ || handshake_type != handshake_type_) { - size_t entire_length = parser.consumed() - start; - output->Write(output_offset, input.data() + start, - entire_length); - // DTLS sequence numbers need to be rewritten - if (skipped_ && IsDtls(version)) { - output->data()[start + 5] -= 1; - } - output_offset += entire_length; - } else { - std::cerr << "Dropping handshake: " - << static_cast<unsigned>(handshake_type_) << std::endl; - // We only need to report that the output contains changed data if we - // drop a handshake message. But once we've skipped one message, we - // have to modify all subsequent handshake messages so that they include - // the correct DTLS sequence numbers. - skipped_ = true; - } - } - output->Truncate(output_offset); - return skipped_; - } - - private: - // The type of handshake message to drop. - uint8_t handshake_type_; - // Whether this filter has ever skipped a handshake message. Track this so - // that sequence numbers on DTLS handshake messages can be rewritten in - // subsequent calls. - bool skipped_; -}; - -class TlsSkipTest - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { - - protected: - TlsSkipTest() - : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())), - std::get<1>(GetParam())) {} - - void ServerSkipTest(PacketFilter* filter, - uint8_t alert = kTlsAlertUnexpectedMessage) { - auto alert_recorder = new TlsAlertRecorder(); - client_->SetPacketFilter(alert_recorder); - if (filter) { - server_->SetPacketFilter(filter); - } - ConnectExpectFail(); - EXPECT_EQ(kTlsAlertFatal, alert_recorder->level()); - EXPECT_EQ(alert, alert_recorder->description()); - } -}; - -TEST_P(TlsSkipTest, SkipCertificateRsa) { - DisableDheAndEcdheCiphers(); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); -} - -TEST_P(TlsSkipTest, SkipCertificateDhe) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); -} - -TEST_P(TlsSkipTest, SkipCertificateEcdhe) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); -} - -TEST_P(TlsSkipTest, SkipCertificateEcdsa) { - ResetEcdsa(); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_SERVER_KEY_EXCH); -} - -TEST_P(TlsSkipTest, SkipServerKeyExchange) { - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); -} - -TEST_P(TlsSkipTest, SkipServerKeyExchangeEcdsa) { - ResetEcdsa(); - ServerSkipTest(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); -} - -TEST_P(TlsSkipTest, SkipCertAndKeyExch) { - auto chain = new ChainedPacketFilter(); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); - ServerSkipTest(chain); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); -} - -TEST_P(TlsSkipTest, SkipCertAndKeyExchEcdsa) { - ResetEcdsa(); - auto chain = new ChainedPacketFilter(); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeCertificate)); - chain->Add(new TlsHandshakeSkipFilter(kTlsHandshakeServerKeyExchange)); - ServerSkipTest(chain); - client_->CheckErrorCode(SSL_ERROR_RX_UNEXPECTED_HELLO_DONE); -} - -INSTANTIATE_TEST_CASE_P(SkipTls10, TlsSkipTest, - ::testing::Combine( - TlsConnectTestBase::kTlsModesStream, - TlsConnectTestBase::kTlsV10)); -INSTANTIATE_TEST_CASE_P(SkipVariants, TlsSkipTest, - ::testing::Combine( - TlsConnectTestBase::kTlsModesAll, - TlsConnectTestBase::kTlsV11V12)); - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/test_io.cc b/nss/external_tests/ssl_gtest/test_io.cc deleted file mode 100644 index 9c28969..0000000 --- a/nss/external_tests/ssl_gtest/test_io.cc +++ /dev/null @@ -1,493 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "test_io.h" - -#include <algorithm> -#include <cassert> -#include <iostream> -#include <memory> - -#include "prerror.h" -#include "prlog.h" -#include "prthread.h" - -#include "databuffer.h" - -namespace nss_test { - -static PRDescIdentity test_fd_identity = PR_INVALID_IO_LAYER; - -#define UNIMPLEMENTED() \ - std::cerr << "Call to unimplemented function " \ - << __FUNCTION__ << std::endl; \ - PR_ASSERT(PR_FALSE); \ - PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0) - -#define LOG(a) std::cerr << name_ << ": " << a << std::endl; - -class Packet : public DataBuffer { - public: - Packet(const DataBuffer& buf) : DataBuffer(buf), offset_(0) {} - - void Advance(size_t delta) { - PR_ASSERT(offset_ + delta <= len()); - offset_ = std::min(len(), offset_ + delta); - } - - size_t offset() const { return offset_; } - size_t remaining() const { return len() - offset_; } - - private: - size_t offset_; -}; - -// Implementation of NSPR methods -static PRStatus DummyClose(PRFileDesc *f) { - f->secret = nullptr; - return PR_SUCCESS; -} - -static int32_t DummyRead(PRFileDesc *f, void *buf, int32_t length) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - return io->Read(buf, length); -} - -static int32_t DummyWrite(PRFileDesc *f, const void *buf, int32_t length) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - return io->Write(buf, length); -} - -static int32_t DummyAvailable(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -int64_t DummyAvailable64(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummySync(PRFileDesc *f) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummySeek(PRFileDesc *f, int32_t offset, PRSeekWhence how) { - UNIMPLEMENTED(); - return -1; -} - -static int64_t DummySeek64(PRFileDesc *f, int64_t offset, PRSeekWhence how) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyFileInfo(PRFileDesc *f, PRFileInfo *info) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyFileInfo64(PRFileDesc *f, PRFileInfo64 *info) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummyWritev(PRFileDesc *f, const PRIOVec *iov, int32_t iov_size, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyConnect(PRFileDesc *f, const PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRFileDesc *DummyAccept(PRFileDesc *sd, PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return nullptr; -} - -static PRStatus DummyBind(PRFileDesc *f, const PRNetAddr *addr) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyListen(PRFileDesc *f, int32_t depth) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyShutdown(PRFileDesc *f, int32_t how) { - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - io->Reset(); - return PR_SUCCESS; -} - -// This function does not support peek. -static int32_t DummyRecv(PRFileDesc *f, void *buf, int32_t buflen, - int32_t flags, PRIntervalTime to) { - PR_ASSERT(flags == 0); - if (flags != 0) { - PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0); - return -1; - } - - DummyPrSocket *io = reinterpret_cast<DummyPrSocket *>(f->secret); - - if (io->mode() == DGRAM) { - return io->Recv(buf, buflen); - } else { - return io->Read(buf, buflen); - } -} - -// Note: this is always nonblocking and assumes a zero timeout. -static int32_t DummySend(PRFileDesc *f, const void *buf, int32_t amount, - int32_t flags, PRIntervalTime to) { - int32_t written = DummyWrite(f, buf, amount); - return written; -} - -static int32_t DummyRecvfrom(PRFileDesc *f, void *buf, int32_t amount, - int32_t flags, PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummySendto(PRFileDesc *f, const void *buf, int32_t amount, - int32_t flags, const PRNetAddr *addr, - PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static int16_t DummyPoll(PRFileDesc *f, int16_t in_flags, int16_t *out_flags) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummyAcceptRead(PRFileDesc *sd, PRFileDesc **nd, - PRNetAddr **raddr, void *buf, int32_t amount, - PRIntervalTime t) { - UNIMPLEMENTED(); - return -1; -} - -static int32_t DummyTransmitFile(PRFileDesc *sd, PRFileDesc *f, - const void *headers, int32_t hlen, - PRTransmitFileFlags flags, PRIntervalTime t) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyGetpeername(PRFileDesc *f, PRNetAddr *addr) { - // TODO: Modify to return unique names for each channel - // somehow, as opposed to always the same static address. The current - // implementation messes up the session cache, which is why it's off - // elsewhere - addr->inet.family = PR_AF_INET; - addr->inet.port = 0; - addr->inet.ip = 0; - - return PR_SUCCESS; -} - -static PRStatus DummyGetsockname(PRFileDesc *f, PRNetAddr *addr) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static PRStatus DummyGetsockoption(PRFileDesc *f, PRSocketOptionData *opt) { - switch (opt->option) { - case PR_SockOpt_Nonblocking: - opt->value.non_blocking = PR_TRUE; - return PR_SUCCESS; - default: - UNIMPLEMENTED(); - break; - } - - return PR_FAILURE; -} - -// Imitate setting socket options. These are mostly noops. -static PRStatus DummySetsockoption(PRFileDesc *f, - const PRSocketOptionData *opt) { - switch (opt->option) { - case PR_SockOpt_Nonblocking: - return PR_SUCCESS; - case PR_SockOpt_NoDelay: - return PR_SUCCESS; - default: - UNIMPLEMENTED(); - break; - } - - return PR_FAILURE; -} - -static int32_t DummySendfile(PRFileDesc *out, PRSendFileData *in, - PRTransmitFileFlags flags, PRIntervalTime to) { - UNIMPLEMENTED(); - return -1; -} - -static PRStatus DummyConnectContinue(PRFileDesc *f, int16_t flags) { - UNIMPLEMENTED(); - return PR_FAILURE; -} - -static int32_t DummyReserved(PRFileDesc *f) { - UNIMPLEMENTED(); - return -1; -} - -DummyPrSocket::~DummyPrSocket() { - Reset(); -} - -void DummyPrSocket::Reset() { - delete filter_; - peer_ = nullptr; - while (!input_.empty()) - { - Packet* front = input_.front(); - input_.pop(); - delete front; - } -} - -static const struct PRIOMethods DummyMethods = { - PR_DESC_LAYERED, DummyClose, DummyRead, - DummyWrite, DummyAvailable, DummyAvailable64, - DummySync, DummySeek, DummySeek64, - DummyFileInfo, DummyFileInfo64, DummyWritev, - DummyConnect, DummyAccept, DummyBind, - DummyListen, DummyShutdown, DummyRecv, - DummySend, DummyRecvfrom, DummySendto, - DummyPoll, DummyAcceptRead, DummyTransmitFile, - DummyGetsockname, DummyGetpeername, DummyReserved, - DummyReserved, DummyGetsockoption, DummySetsockoption, - DummySendfile, DummyConnectContinue, DummyReserved, - DummyReserved, DummyReserved, DummyReserved}; - -PRFileDesc *DummyPrSocket::CreateFD(const std::string &name, Mode mode) { - if (test_fd_identity == PR_INVALID_IO_LAYER) { - test_fd_identity = PR_GetUniqueIdentity("testtransportadapter"); - } - - PRFileDesc *fd = (PR_CreateIOLayerStub(test_fd_identity, &DummyMethods)); - fd->secret = reinterpret_cast<PRFilePrivate *>(new DummyPrSocket(name, mode)); - - return fd; -} - -DummyPrSocket *DummyPrSocket::GetAdapter(PRFileDesc *fd) { - return reinterpret_cast<DummyPrSocket *>(fd->secret); -} - -void DummyPrSocket::PacketReceived(const DataBuffer& packet) { - input_.push(new Packet(packet)); -} - -int32_t DummyPrSocket::Read(void *data, int32_t len) { - PR_ASSERT(mode_ == STREAM); - - if (mode_ != STREAM) { - PR_SetError(PR_INVALID_METHOD_ERROR, 0); - return -1; - } - - if (input_.empty()) { - LOG("Read --> wouldblock " << len); - PR_SetError(PR_WOULD_BLOCK_ERROR, 0); - return -1; - } - - Packet *front = input_.front(); - size_t to_read = std::min(static_cast<size_t>(len), - front->len() - front->offset()); - memcpy(data, static_cast<const void*>(front->data() + front->offset()), - to_read); - front->Advance(to_read); - - if (!front->remaining()) { - input_.pop(); - delete front; - } - - return static_cast<int32_t>(to_read); -} - -int32_t DummyPrSocket::Recv(void *buf, int32_t buflen) { - if (input_.empty()) { - PR_SetError(PR_WOULD_BLOCK_ERROR, 0); - return -1; - } - - Packet *front = input_.front(); - if (static_cast<size_t>(buflen) < front->len()) { - PR_ASSERT(false); - PR_SetError(PR_BUFFER_OVERFLOW_ERROR, 0); - return -1; - } - - size_t count = front->len(); - memcpy(buf, front->data(), count); - - input_.pop(); - delete front; - - return static_cast<int32_t>(count); -} - -int32_t DummyPrSocket::Write(const void *buf, int32_t length) { - if (!peer_) { - PR_SetError(PR_IO_ERROR, 0); - return -1; - } - - DataBuffer packet(static_cast<const uint8_t*>(buf), - static_cast<size_t>(length)); - DataBuffer filtered; - if (filter_ && filter_->Filter(packet, &filtered)) { - LOG("Filtered packet: " << filtered); - peer_->PacketReceived(filtered); - } else { - peer_->PacketReceived(packet); - } - // libssl can't handle it if this reports something other than the length - // of what was passed in (or less, but we're not doing partial writes). - return static_cast<int32_t>(packet.len()); -} - -Poller *Poller::instance; - -Poller *Poller::Instance() { - if (!instance) instance = new Poller(); - - return instance; -} - -void Poller::Shutdown() { - delete instance; - instance = nullptr; -} - -void Poller::Wait(Event event, DummyPrSocket *adapter, PollTarget *target, - PollCallback cb) { - auto it = waiters_.find(adapter); - Waiter *waiter; - - if (it == waiters_.end()) { - waiter = new Waiter(adapter); - } else { - waiter = it->second; - } - - assert(event < TIMER_EVENT); - if (event >= TIMER_EVENT) return; - - waiter->targets_[event] = target; - waiter->callbacks_[event] = cb; - waiters_[adapter] = waiter; -} - -void Poller::Cancel(Event event, DummyPrSocket *adapter) { - auto it = waiters_.find(adapter); - Waiter *waiter; - - if (it == waiters_.end()) { - return; - } - - waiter = it->second; - - waiter->targets_[event] = nullptr; - waiter->callbacks_[event] = nullptr; - - // Clean up if there are no callbacks. - for (size_t i=0; i<TIMER_EVENT; ++i) { - if (waiter->callbacks_[i]) - return; - } - - delete waiter; - waiters_.erase(adapter); -} - -void Poller::SetTimer(uint32_t timer_ms, PollTarget *target, PollCallback cb, - Timer **timer) { - Timer *t = new Timer(PR_Now() + timer_ms * 1000, target, cb); - timers_.push(t); - if (timer) *timer = t; -} - -bool Poller::Poll() { - std::cerr << "Poll()\n"; - PRIntervalTime timeout = PR_INTERVAL_NO_TIMEOUT; - PRTime now = PR_Now(); - bool fired = false; - - // Figure out the timer for the select. - if (!timers_.empty()) { - Timer *first_timer = timers_.top(); - if (now >= first_timer->deadline_) { - // Timer expired. - timeout = PR_INTERVAL_NO_WAIT; - } else { - timeout = - PR_MillisecondsToInterval((first_timer->deadline_ - now) / 1000); - } - } - - for (auto it = waiters_.begin(); it != waiters_.end(); ++it) { - Waiter *waiter = it->second; - - if (waiter->callbacks_[READABLE_EVENT]) { - if (waiter->io_->readable()) { - PollCallback callback = waiter->callbacks_[READABLE_EVENT]; - PollTarget *target = waiter->targets_[READABLE_EVENT]; - waiter->callbacks_[READABLE_EVENT] = nullptr; - waiter->targets_[READABLE_EVENT] = nullptr; - callback(target, READABLE_EVENT); - fired = true; - } - } - } - - if (fired) timeout = PR_INTERVAL_NO_WAIT; - - // Can't wait forever and also have nothing readable now. - if (timeout == PR_INTERVAL_NO_TIMEOUT) return false; - - // Sleep. - if (timeout != PR_INTERVAL_NO_WAIT) { - PR_Sleep(timeout); - } - - // Now process anything that timed out. - now = PR_Now(); - while (!timers_.empty()) { - if (now < timers_.top()->deadline_) break; - - Timer *timer = timers_.top(); - timers_.pop(); - if (timer->callback_) { - timer->callback_(timer->target_, TIMER_EVENT); - } - delete timer; - } - - return true; -} - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/test_io.h b/nss/external_tests/ssl_gtest/test_io.h deleted file mode 100644 index f5910c2..0000000 --- a/nss/external_tests/ssl_gtest/test_io.h +++ /dev/null @@ -1,141 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef test_io_h_ -#define test_io_h_ - -#include <string.h> -#include <map> -#include <memory> -#include <queue> -#include <string> -#include <ostream> - -#include "prio.h" - -namespace nss_test { - -class DataBuffer; -class Packet; -class DummyPrSocket; // Fwd decl. - -// Allow us to inspect a packet before it is written. -class PacketFilter { - public: - virtual ~PacketFilter() {} - - // The packet filter takes input and has the option of mutating it. - // - // A filter that modifies the data places the modified data in *output and - // returns true. A filter that does not modify data returns false, in which - // case the value in *output is ignored. - virtual bool Filter(const DataBuffer& input, DataBuffer* output) = 0; -}; - -enum Mode { STREAM, DGRAM }; - -inline std::ostream& operator<<(std::ostream& os, Mode m) { - return os << ((m == STREAM) ? "TLS" : "DTLS"); -} - -class DummyPrSocket { - public: - ~DummyPrSocket(); - - static PRFileDesc* CreateFD(const std::string& name, - Mode mode); // Returns an FD. - static DummyPrSocket* GetAdapter(PRFileDesc* fd); - - void SetPeer(DummyPrSocket* peer) { peer_ = peer; } - void SetPacketFilter(PacketFilter* filter) { filter_ = filter; } - // Drops peer, packet filter and any outstanding packets. - void Reset(); - - void PacketReceived(const DataBuffer& data); - int32_t Read(void* data, int32_t len); - int32_t Recv(void* buf, int32_t buflen); - int32_t Write(const void* buf, int32_t length); - - Mode mode() const { return mode_; } - bool readable() const { return !input_.empty(); } - bool writable() { return true; } - - private: - DummyPrSocket(const std::string& name, Mode mode) - : name_(name), - mode_(mode), - peer_(nullptr), - input_(), - filter_(nullptr) {} - - const std::string name_; - Mode mode_; - DummyPrSocket* peer_; - std::queue<Packet*> input_; - PacketFilter* filter_; -}; - -// Marker interface. -class PollTarget {}; - -enum Event { READABLE_EVENT, TIMER_EVENT /* Must be last */ }; - -typedef void (*PollCallback)(PollTarget*, Event); - -class Poller { - public: - static Poller* Instance(); // Get a singleton. - static void Shutdown(); // Shut it down. - - class Timer { - public: - Timer(PRTime deadline, PollTarget* target, PollCallback callback) - : deadline_(deadline), target_(target), callback_(callback) {} - void Cancel() { callback_ = nullptr; } - - PRTime deadline_; - PollTarget* target_; - PollCallback callback_; - }; - - void Wait(Event event, DummyPrSocket* adapter, PollTarget* target, - PollCallback cb); - void Cancel(Event event, DummyPrSocket* adapter); - void SetTimer(uint32_t timer_ms, PollTarget* target, PollCallback cb, - Timer** handle); - bool Poll(); - - private: - Poller() : waiters_(), timers_() {} - - class Waiter { - public: - Waiter(DummyPrSocket* io) : io_(io) { - memset(&callbacks_[0], 0, sizeof(callbacks_)); - } - - void WaitFor(Event event, PollCallback callback); - - DummyPrSocket* io_; - PollTarget* targets_[TIMER_EVENT]; - PollCallback callbacks_[TIMER_EVENT]; - }; - - class TimerComparator { - public: - bool operator()(const Timer* lhs, const Timer* rhs) { - return lhs->deadline_ > rhs->deadline_; - } - }; - - static Poller* instance; - std::map<DummyPrSocket*, Waiter*> waiters_; - std::priority_queue<Timer*, std::vector<Timer*>, TimerComparator> timers_; -}; - -} // end of namespace - -#endif diff --git a/nss/external_tests/ssl_gtest/tls_agent.cc b/nss/external_tests/ssl_gtest/tls_agent.cc deleted file mode 100644 index 2a41ecb..0000000 --- a/nss/external_tests/ssl_gtest/tls_agent.cc +++ /dev/null @@ -1,572 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "tls_agent.h" - -#include "pk11func.h" -#include "ssl.h" -#include "sslerr.h" -#include "sslproto.h" -#include "keyhi.h" - -#define GTEST_HAS_RTTI 0 -#include "gtest/gtest.h" - -namespace nss_test { - - -const char* TlsAgent::states[] = {"INIT", "CONNECTING", "CONNECTED", "ERROR"}; - -TlsAgent::TlsAgent(const std::string& name, Role role, Mode mode, SSLKEAType kea) - : name_(name), - mode_(mode), - kea_(kea), - pr_fd_(nullptr), - adapter_(nullptr), - ssl_fd_(nullptr), - role_(role), - state_(STATE_INIT), - falsestart_enabled_(false), - expected_version_(0), - expected_cipher_suite_(0), - expect_resumption_(false), - can_falsestart_hook_called_(false), - sni_hook_called_(false), - auth_certificate_hook_called_(false), - handshake_callback_called_(false), - error_code_(0), - send_ctr_(0), - recv_ctr_(0), - expected_read_error_(false) { - - memset(&info_, 0, sizeof(info_)); - memset(&csinfo_, 0, sizeof(csinfo_)); - SECStatus rv = SSL_VersionRangeGetDefault(mode_ == STREAM ? - ssl_variant_stream : ssl_variant_datagram, - &vrange_); - EXPECT_EQ(SECSuccess, rv); -} - -TlsAgent::~TlsAgent() { - if (adapter_) { - Poller::Instance()->Cancel(READABLE_EVENT, adapter_); - } - - if (pr_fd_) { - PR_Close(pr_fd_); - } - - if (ssl_fd_) { - PR_Close(ssl_fd_); - } -} - -bool TlsAgent::EnsureTlsSetup() { - // Don't set up twice - if (ssl_fd_) return true; - - if (adapter_->mode() == STREAM) { - ssl_fd_ = SSL_ImportFD(nullptr, pr_fd_); - } else { - ssl_fd_ = DTLS_ImportFD(nullptr, pr_fd_); - } - - EXPECT_NE(nullptr, ssl_fd_); - if (!ssl_fd_) return false; - pr_fd_ = nullptr; - - if (role_ == SERVER) { - CERTCertificate* cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); - EXPECT_NE(nullptr, cert); - if (!cert) return false; - - SECKEYPrivateKey* priv = PK11_FindKeyByAnyCert(cert, nullptr); - EXPECT_NE(nullptr, priv); - if (!priv) return false; // Leak cert. - - SECStatus rv = SSL_ConfigSecureServer(ssl_fd_, cert, priv, kea_); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; // Leak cert and key. - - SECKEY_DestroyPrivateKey(priv); - CERT_DestroyCertificate(cert); - - rv = SSL_SNISocketConfigHook(ssl_fd_, SniHook, this); - EXPECT_EQ(SECSuccess, rv); // don't abort, just fail - } else { - SECStatus rv = SSL_SetURL(ssl_fd_, "server"); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; - } - - SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; - - rv = SSL_AuthCertificateHook(ssl_fd_, AuthCertificateHook, this); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; - - rv = SSL_HandshakeCallback(ssl_fd_, HandshakeCallback, this); - EXPECT_EQ(SECSuccess, rv); - if (rv != SECSuccess) return false; - - return true; -} - -void TlsAgent::SetupClientAuth() { - EXPECT_TRUE(EnsureTlsSetup()); - ASSERT_EQ(CLIENT, role_); - - EXPECT_EQ(SECSuccess, - SSL_GetClientAuthDataHook(ssl_fd_, GetClientAuthDataHook, - reinterpret_cast<void*>(this))); -} - -bool TlsAgent::GetClientAuthCredentials(CERTCertificate **cert, - SECKEYPrivateKey **priv) const { - *cert = PK11_FindCertFromNickname(name_.c_str(), nullptr); - EXPECT_NE(nullptr, *cert); - if (!*cert) return false; - - *priv = PK11_FindKeyByAnyCert(*cert, nullptr); - EXPECT_NE(nullptr, *priv); - if (!*priv) return false; // Leak cert. - - return true; -} - -SECStatus TlsAgent::GetClientAuthDataHook(void* self, PRFileDesc* fd, - CERTDistNames* caNames, - CERTCertificate** cert, - SECKEYPrivateKey** privKey) { - TlsAgent* agent = reinterpret_cast<TlsAgent*>(self); - if (agent->GetClientAuthCredentials(cert, privKey)) { - return SECSuccess; - } - return SECFailure; -} - - -void TlsAgent::RequestClientAuth(bool requireAuth) { - EXPECT_TRUE(EnsureTlsSetup()); - ASSERT_EQ(SERVER, role_); - - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_REQUIRE_CERTIFICATE, - requireAuth ? PR_TRUE : PR_FALSE)); - - EXPECT_EQ(SECSuccess, - SSL_AuthCertificateHook(ssl_fd_, &TlsAgent::ClientAuthenticated, - this)); - expect_client_auth_ = true; -} - -void TlsAgent::StartConnect() { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv; - rv = SSL_ResetHandshake(ssl_fd_, role_ == SERVER ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); - SetState(STATE_CONNECTING); -} - -void TlsAgent::DisableCiphersByKeyExchange(SSLKEAType kea) { - EXPECT_TRUE(EnsureTlsSetup()); - - for (size_t i = 0; i < SSL_NumImplementedCiphers; ++i) { - SSLCipherSuiteInfo csinfo; - - SECStatus rv = SSL_GetCipherSuiteInfo(SSL_ImplementedCiphers[i], - &csinfo, sizeof(csinfo)); - ASSERT_EQ(SECSuccess, rv); - - if (csinfo.keaType == kea) { - rv = SSL_CipherPrefSet(ssl_fd_, SSL_ImplementedCiphers[i], PR_FALSE); - EXPECT_EQ(SECSuccess, rv); - } - } -} - -void TlsAgent::SetSessionTicketsEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_ENABLE_SESSION_TICKETS, - en ? PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetSessionCacheEnabled(bool en) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd_, SSL_NO_CACHE, - en ? PR_FALSE : PR_TRUE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SetVersionRange(uint16_t minver, uint16_t maxver) { - vrange_.min = minver; - vrange_.max = maxver; - - if (ssl_fd_) { - SECStatus rv = SSL_VersionRangeSet(ssl_fd_, &vrange_); - EXPECT_EQ(SECSuccess, rv); - } -} - -void TlsAgent::SetExpectedVersion(uint16_t version) { - expected_version_ = version; -} - -void TlsAgent::SetExpectedReadError(bool err) { - expected_read_error_ = err; -} - -void TlsAgent::SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms, - size_t count) { - EXPECT_TRUE(EnsureTlsSetup()); - EXPECT_LE(count, SSL_SignatureMaxCount()); - EXPECT_EQ(SECSuccess, SSL_SignaturePrefSet(ssl_fd_, algorithms, - static_cast<unsigned int>(count))); - EXPECT_EQ(SECFailure, SSL_SignaturePrefSet(ssl_fd_, algorithms, 0)) - << "setting no algorithms should fail and do nothing"; - - std::vector<SSLSignatureAndHashAlg> configuredAlgorithms(count); - unsigned int configuredCount; - EXPECT_EQ(SECFailure, - SSL_SignaturePrefGet(ssl_fd_, nullptr, &configuredCount, 1)) - << "get algorithms, algorithms is nullptr"; - EXPECT_EQ(SECFailure, - SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0], - &configuredCount, 0)) - << "get algorithms, too little space"; - EXPECT_EQ(SECFailure, - SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0], - nullptr, configuredAlgorithms.size())) - << "get algorithms, algCountOut is nullptr"; - - EXPECT_EQ(SECSuccess, - SSL_SignaturePrefGet(ssl_fd_, &configuredAlgorithms[0], - &configuredCount, - configuredAlgorithms.size())); - // SignaturePrefSet drops unsupported algorithms silently, so the number that - // are configured might be fewer. - EXPECT_LE(configuredCount, count); - unsigned int i = 0; - for (unsigned int j = 0; j < count && i < configuredCount; ++j) { - if (i < configuredCount && - algorithms[j].hashAlg == configuredAlgorithms[i].hashAlg && - algorithms[j].sigAlg == configuredAlgorithms[i].sigAlg) { - ++i; - } - } - EXPECT_EQ(i, configuredCount) << "algorithms in use were all set"; -} - -void TlsAgent::CheckKEAType(SSLKEAType type) const { - EXPECT_EQ(STATE_CONNECTED, state_); - EXPECT_EQ(type, csinfo_.keaType); -} - -void TlsAgent::CheckAuthType(SSLAuthType type) const { - EXPECT_EQ(STATE_CONNECTED, state_); - EXPECT_EQ(type, csinfo_.authAlgorithm); -} - -void TlsAgent::EnableFalseStart() { - EXPECT_TRUE(EnsureTlsSetup()); - - falsestart_enabled_ = true; - EXPECT_EQ(SECSuccess, - SSL_SetCanFalseStartCallback(ssl_fd_, CanFalseStartCallback, this)); - EXPECT_EQ(SECSuccess, - SSL_OptionSet(ssl_fd_, SSL_ENABLE_FALSE_START, PR_TRUE)); -} - -void TlsAgent::ExpectResumption() { - expect_resumption_ = true; -} - -void TlsAgent::EnableAlpn(const uint8_t* val, size_t len) { - EXPECT_TRUE(EnsureTlsSetup()); - - EXPECT_EQ(SECSuccess, SSL_OptionSet(ssl_fd_, SSL_ENABLE_ALPN, PR_TRUE)); - EXPECT_EQ(SECSuccess, SSL_SetNextProtoNego(ssl_fd_, val, len)); -} - -void TlsAgent::CheckAlpn(SSLNextProtoState expected_state, - const std::string& expected) const { - SSLNextProtoState state; - char chosen[10]; - unsigned int chosen_len; - SECStatus rv = SSL_GetNextProto(ssl_fd_, &state, - reinterpret_cast<unsigned char*>(chosen), - &chosen_len, sizeof(chosen)); - EXPECT_EQ(SECSuccess, rv); - EXPECT_EQ(expected_state, state); - EXPECT_EQ(expected, std::string(chosen, chosen_len)); -} - -void TlsAgent::EnableSrtp() { - EXPECT_TRUE(EnsureTlsSetup()); - const uint16_t ciphers[] = { - SRTP_AES128_CM_HMAC_SHA1_80, SRTP_AES128_CM_HMAC_SHA1_32 - }; - EXPECT_EQ(SECSuccess, SSL_SetSRTPCiphers(ssl_fd_, ciphers, - PR_ARRAY_SIZE(ciphers))); -} - -void TlsAgent::CheckSrtp() const { - uint16_t actual; - EXPECT_EQ(SECSuccess, SSL_GetSRTPCipher(ssl_fd_, &actual)); - EXPECT_EQ(SRTP_AES128_CM_HMAC_SHA1_80, actual); -} - -void TlsAgent::CheckErrorCode(int32_t expected) const { - EXPECT_EQ(STATE_ERROR, state_); - EXPECT_EQ(expected, error_code_); -} - -void TlsAgent::CheckPreliminaryInfo() { - SSLPreliminaryChannelInfo info; - EXPECT_EQ(SECSuccess, - SSL_GetPreliminaryChannelInfo(ssl_fd_, &info, sizeof(info))); - EXPECT_TRUE(info.valuesSet & ssl_preinfo_version); - EXPECT_TRUE(info.valuesSet & ssl_preinfo_cipher_suite); - - // A version of 0 is invalid and indicates no expectation. This value is - // initialized to 0 so that tests that don't explicitly set an expected - // version can negotiate a version. - if (!expected_version_) { - expected_version_ = info.protocolVersion; - } - EXPECT_EQ(expected_version_, info.protocolVersion); - - // As with the version; 0 is the null cipher suite (and also invalid). - if (!expected_cipher_suite_) { - expected_cipher_suite_ = info.cipherSuite; - } - EXPECT_EQ(expected_cipher_suite_, info.cipherSuite); -} - -// Check that all the expected callbacks have been called. -void TlsAgent::CheckCallbacks() const { - // If false start happens, the handshake is reported as being complete at the - // point that false start happens. - if (expect_resumption_ || !falsestart_enabled_) { - EXPECT_TRUE(handshake_callback_called_); - } - - // These callbacks shouldn't fire if we are resuming. - if (role_ == SERVER) { - EXPECT_EQ(!expect_resumption_, sni_hook_called_); - } else { - EXPECT_EQ(!expect_resumption_, auth_certificate_hook_called_); - // Note that this isn't unconditionally called, even with false start on. - // But the callback is only skipped if a cipher that is ridiculously weak - // (80 bits) is chosen. Don't test that: plan to remove bad ciphers. - EXPECT_EQ(falsestart_enabled_ && !expect_resumption_, - can_falsestart_hook_called_); - } -} - -void TlsAgent::Connected() { - LOG("Handshake success"); - CheckCallbacks(); - - SECStatus rv = SSL_GetChannelInfo(ssl_fd_, &info_, sizeof(info_)); - EXPECT_EQ(SECSuccess, rv); - - // Preliminary values are exposed through callbacks during the handshake. - // If either expected values were set or the callbacks were called, check - // that the final values are correct. - EXPECT_EQ(expected_version_, info_.protocolVersion); - EXPECT_EQ(expected_cipher_suite_, info_.cipherSuite); - - rv = SSL_GetCipherSuiteInfo(info_.cipherSuite, &csinfo_, sizeof(csinfo_)); - EXPECT_EQ(SECSuccess, rv); - - SetState(STATE_CONNECTED); -} - -void TlsAgent::EnableExtendedMasterSecret() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd_, - SSL_ENABLE_EXTENDED_MASTER_SECRET, - PR_TRUE); - - ASSERT_EQ(SECSuccess, rv); -} - -void TlsAgent::CheckExtendedMasterSecret(bool expected) { - ASSERT_EQ(expected, info_.extendedMasterSecretUsed != PR_FALSE) - << "unexpected extended master secret state for " << name_; -} - -void TlsAgent::DisableRollbackDetection() { - ASSERT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd_, - SSL_ROLLBACK_DETECTION, - PR_FALSE); - - ASSERT_EQ(SECSuccess, rv); -} - -void TlsAgent::Handshake() { - SECStatus rv = SSL_ForceHandshake(ssl_fd_); - if (rv == SECSuccess) { - Connected(); - - Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, - &TlsAgent::ReadableCallback); - - return; - } - - int32_t err = PR_GetError(); - switch (err) { - case PR_WOULD_BLOCK_ERROR: - LOG("Would have blocked"); - // TODO(ekr@rtfm.com): set DTLS timeouts - Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, - &TlsAgent::ReadableCallback); - return; - break; - - // TODO(ekr@rtfm.com): needs special case for DTLS - case SSL_ERROR_RX_MALFORMED_HANDSHAKE: - default: - if (IS_SSL_ERROR(err)) { - LOG("Handshake failed with SSL error " << err - SSL_ERROR_BASE); - } else { - LOG("Handshake failed with error " << err); - } - error_code_ = err; - SetState(STATE_ERROR); - return; - } -} - -void TlsAgent::PrepareForRenegotiate() { - EXPECT_EQ(STATE_CONNECTED, state_); - - SetState(STATE_CONNECTING); -} - -void TlsAgent::StartRenegotiate() { - PrepareForRenegotiate(); - - SECStatus rv = SSL_ReHandshake(ssl_fd_, PR_TRUE); - EXPECT_EQ(SECSuccess, rv); -} - -void TlsAgent::SendData(size_t bytes, size_t blocksize) { - uint8_t block[4096]; - - ASSERT_LT(blocksize, sizeof(block)); - - while(bytes) { - size_t tosend = std::min(blocksize, bytes); - - for(size_t i = 0; i < tosend; ++i) { - block[i] = 0xff & send_ctr_; - ++send_ctr_; - } - - LOG("Writing " << tosend << " bytes"); - int32_t rv = PR_Write(ssl_fd_, block, tosend); - ASSERT_EQ(tosend, static_cast<size_t>(rv)); - - bytes -= tosend; - } -} - -void TlsAgent::ReadBytes() { - uint8_t block[1024]; - - LOG("Reading application data from socket"); - - int32_t rv = PR_Read(ssl_fd_, block, sizeof(block)); - - int32_t err = PR_GetError(); - if (err != PR_WOULD_BLOCK_ERROR) { - if (expected_read_error_) { - error_code_ = err; - } else { - ASSERT_LE(0, rv); - size_t count = static_cast<size_t>(rv); - LOG("Read " << count << " bytes"); - for (size_t i = 0; i < count; ++i) { - ASSERT_EQ(recv_ctr_ & 0xff, block[i]); - recv_ctr_++; - } - } - } - - Poller::Instance()->Wait(READABLE_EVENT, adapter_, this, - &TlsAgent::ReadableCallback); -} - -void TlsAgent::ResetSentBytes() { - send_ctr_ = 0; -} - -void TlsAgent::ConfigureSessionCache(SessionResumptionMode mode) { - EXPECT_TRUE(EnsureTlsSetup()); - - SECStatus rv = SSL_OptionSet(ssl_fd_, - SSL_NO_CACHE, - mode & RESUME_SESSIONID ? - PR_FALSE : PR_TRUE); - EXPECT_EQ(SECSuccess, rv); - - rv = SSL_OptionSet(ssl_fd_, - SSL_ENABLE_SESSION_TICKETS, - mode & RESUME_TICKET ? - PR_TRUE : PR_FALSE); - EXPECT_EQ(SECSuccess, rv); -} - -static const std::string kTlsRolesAllArr[] = {"CLIENT", "SERVER"}; -::testing::internal::ParamGenerator<std::string> - TlsAgentTestBase::kTlsRolesAll = ::testing::ValuesIn(kTlsRolesAllArr); - -void TlsAgentTestBase::Init() { - agent_ = new TlsAgent( - role_ == TlsAgent::CLIENT ? "client" : "server", - role_, mode_, kea_); - agent_->Init(); - fd_ = DummyPrSocket::CreateFD("dummy", mode_); - agent_->adapter()->SetPeer( - DummyPrSocket::GetAdapter(fd_)); - agent_->StartConnect(); -} - -void TlsAgentTestBase::EnsureInit() { - if (!agent_) { - Init(); - } -} - -void TlsAgentTestBase::ProcessMessage(const DataBuffer& buffer, - TlsAgent::State expected_state, - int32_t error_code) { - EnsureInit(); - agent_->adapter()->PacketReceived(buffer); - agent_->Handshake(); - - ASSERT_EQ(expected_state, agent_->state()); - - if (expected_state == TlsAgent::STATE_ERROR) { - ASSERT_EQ(error_code, agent_->error_code()); - } -} - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/tls_agent.h b/nss/external_tests/ssl_gtest/tls_agent.h deleted file mode 100644 index f15de13..0000000 --- a/nss/external_tests/ssl_gtest/tls_agent.h +++ /dev/null @@ -1,295 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef tls_agent_h_ -#define tls_agent_h_ - -#include "prio.h" -#include "ssl.h" - -#include <iostream> - -#include "test_io.h" - -#define GTEST_HAS_RTTI 0 -#include "gtest/gtest.h" - -namespace nss_test { - -#define LOG(msg) std::cerr << name_ << ": " << msg << std::endl - -enum SessionResumptionMode { - RESUME_NONE = 0, - RESUME_SESSIONID = 1, - RESUME_TICKET = 2, - RESUME_BOTH = RESUME_SESSIONID | RESUME_TICKET -}; - -class TlsAgent : public PollTarget { - public: - enum Role { CLIENT, SERVER }; - enum State { STATE_INIT, STATE_CONNECTING, STATE_CONNECTED, STATE_ERROR }; - - TlsAgent(const std::string& name, Role role, Mode mode, SSLKEAType kea); - virtual ~TlsAgent(); - - bool Init() { - pr_fd_ = DummyPrSocket::CreateFD(name_, mode_); - if (!pr_fd_) return false; - - adapter_ = DummyPrSocket::GetAdapter(pr_fd_); - if (!adapter_) return false; - - return true; - } - - void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); } - - void SetPacketFilter(PacketFilter* filter) { - adapter_->SetPacketFilter(filter); - } - - - void StartConnect(); - void CheckKEAType(SSLKEAType type) const; - void CheckAuthType(SSLAuthType type) const; - - void Handshake(); - // Marks the internal state as CONNECTING in anticipation of renegotiation. - void PrepareForRenegotiate(); - // Prepares for renegotiation, then actually triggers it. - void StartRenegotiate(); - void DisableCiphersByKeyExchange(SSLKEAType kea); - bool EnsureTlsSetup(); - - void SetupClientAuth(); - void RequestClientAuth(bool requireAuth); - bool GetClientAuthCredentials(CERTCertificate** cert, - SECKEYPrivateKey** priv) const; - - void ConfigureSessionCache(SessionResumptionMode mode); - void SetSessionTicketsEnabled(bool en); - void SetSessionCacheEnabled(bool en); - void SetVersionRange(uint16_t minver, uint16_t maxver); - void CheckPreliminaryInfo(); - void SetExpectedVersion(uint16_t version); - void SetExpectedReadError(bool err); - void EnableFalseStart(); - void ExpectResumption(); - void SetSignatureAlgorithms(const SSLSignatureAndHashAlg* algorithms, - size_t count); - void EnableAlpn(const uint8_t* val, size_t len); - void CheckAlpn(SSLNextProtoState expected_state, - const std::string& expected) const; - void EnableSrtp(); - void CheckSrtp() const; - void CheckErrorCode(int32_t expected) const; - void SendData(size_t bytes, size_t blocksize = 1024); - void ReadBytes(); - void ResetSentBytes(); // Hack to test drops. - void EnableExtendedMasterSecret(); - void CheckExtendedMasterSecret(bool expected); - void DisableRollbackDetection(); - - State state() const { return state_; } - - const char* state_str() const { return state_str(state()); } - - const char* state_str(State state) const { return states[state]; } - - PRFileDesc* ssl_fd() { return ssl_fd_; } - DummyPrSocket* adapter() { return adapter_; } - - uint16_t min_version() const { return vrange_.min; } - uint16_t max_version() const { return vrange_.max; } - uint16_t version() const { - EXPECT_EQ(STATE_CONNECTED, state_); - return info_.protocolVersion; - } - - bool cipher_suite(int16_t* cipher_suite) const { - if (state_ != STATE_CONNECTED) return false; - - *cipher_suite = info_.cipherSuite; - return true; - } - - std::string cipher_suite_name() const { - if (state_ != STATE_CONNECTED) return "UNKNOWN"; - - return csinfo_.cipherSuiteName; - } - - std::vector<uint8_t> session_id() const { - return std::vector<uint8_t>(info_.sessionID, - info_.sessionID + info_.sessionIDLength); - } - - size_t received_bytes() const { return recv_ctr_; } - int32_t error_code() const { return error_code_; } - - private: - const static char* states[]; - - void SetState(State state) { - if (state_ == state) return; - - LOG("Changing state from " << state_str(state_) << " to " - << state_str(state)); - state_ = state; - } - - // Dummy auth certificate hook. - static SECStatus AuthCertificateHook(void* arg, PRFileDesc* fd, - PRBool checksig, PRBool isServer) { - TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); - agent->CheckPreliminaryInfo(); - agent->auth_certificate_hook_called_ = true; - return SECSuccess; - } - - // Client auth certificate hook. - static SECStatus ClientAuthenticated(void* arg, PRFileDesc* fd, - PRBool checksig, PRBool isServer) { - TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); - EXPECT_TRUE(agent->expect_client_auth_); - EXPECT_TRUE(isServer); - return SECSuccess; - } - - static SECStatus GetClientAuthDataHook(void* self, PRFileDesc* fd, - CERTDistNames* caNames, - CERTCertificate** cert, - SECKEYPrivateKey** privKey); - - static void ReadableCallback(PollTarget* self, Event event) { - TlsAgent* agent = static_cast<TlsAgent*>(self); - agent->ReadableCallback_int(); - } - - - void ReadableCallback_int() { - LOG("Readable"); - switch (state_) { - case STATE_CONNECTING: - Handshake(); - break; - case STATE_CONNECTED: - ReadBytes(); - break; - default: - break; - } - } - - static PRInt32 SniHook(PRFileDesc *fd, const SECItem *srvNameArr, - PRUint32 srvNameArrSize, - void *arg) { - TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); - agent->CheckPreliminaryInfo(); - agent->sni_hook_called_ = true; - return SSL_SNI_CURRENT_CONFIG_IS_USED; - } - - static SECStatus CanFalseStartCallback(PRFileDesc *fd, void *arg, - PRBool *canFalseStart) { - TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); - agent->CheckPreliminaryInfo(); - EXPECT_TRUE(agent->falsestart_enabled_); - agent->can_falsestart_hook_called_ = true; - *canFalseStart = true; - return SECSuccess; - } - - static void HandshakeCallback(PRFileDesc *fd, void *arg) { - TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg); - agent->CheckPreliminaryInfo(); - agent->handshake_callback_called_ = true; - } - - void CheckCallbacks() const; - void Connected(); - - const std::string name_; - Mode mode_; - SSLKEAType kea_; - PRFileDesc* pr_fd_; - DummyPrSocket* adapter_; - PRFileDesc* ssl_fd_; - Role role_; - State state_; - bool falsestart_enabled_; - uint16_t expected_version_; - uint16_t expected_cipher_suite_; - bool expect_resumption_; - bool expect_client_auth_; - bool can_falsestart_hook_called_; - bool sni_hook_called_; - bool auth_certificate_hook_called_; - bool handshake_callback_called_; - SSLChannelInfo info_; - SSLCipherSuiteInfo csinfo_; - SSLVersionRange vrange_; - int32_t error_code_; - size_t send_ctr_; - size_t recv_ctr_; - bool expected_read_error_; -}; - -class TlsAgentTestBase : public ::testing::Test { - public: - static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll; - - TlsAgentTestBase(TlsAgent::Role role, - Mode mode) : agent_(nullptr), - fd_(nullptr), - role_(role), - mode_(mode), - kea_(ssl_kea_rsa) {} - ~TlsAgentTestBase() { - delete agent_; - if (fd_) { - PR_Close(fd_); - } - } - - static inline TlsAgent::Role ToRole(const std::string& str) { - return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER; - } - - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - - void Init(); - - protected: - void EnsureInit(); - void ProcessMessage(const DataBuffer& buffer, - TlsAgent::State expected_state, - int32_t error_code = 0); - - - TlsAgent* agent_; - PRFileDesc* fd_; - TlsAgent::Role role_; - Mode mode_; - SSLKEAType kea_; -}; - -class TlsAgentTest : - public TlsAgentTestBase, - public ::testing::WithParamInterface - <std::tuple<std::string,std::string>> { - public: - TlsAgentTest() : - TlsAgentTestBase(ToRole(std::get<0>(GetParam())), - ToMode(std::get<1>(GetParam()))) {} -}; - -} // namespace nss_test - -#endif diff --git a/nss/external_tests/ssl_gtest/tls_connect.cc b/nss/external_tests/ssl_gtest/tls_connect.cc deleted file mode 100644 index 34c6d12..0000000 --- a/nss/external_tests/ssl_gtest/tls_connect.cc +++ /dev/null @@ -1,295 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "tls_connect.h" - -#include <iostream> - -#include "sslproto.h" -#include "gtest_utils.h" - -extern std::string g_working_dir_path; - -namespace nss_test { - -static const std::string kTlsModesStreamArr[] = {"TLS"}; -::testing::internal::ParamGenerator<std::string> - TlsConnectTestBase::kTlsModesStream = ::testing::ValuesIn(kTlsModesStreamArr); -static const std::string kTlsModesAllArr[] = {"TLS", "DTLS"}; -::testing::internal::ParamGenerator<std::string> - TlsConnectTestBase::kTlsModesAll = ::testing::ValuesIn(kTlsModesAllArr); -static const uint16_t kTlsV10Arr[] = {SSL_LIBRARY_VERSION_TLS_1_0}; -::testing::internal::ParamGenerator<uint16_t> - TlsConnectTestBase::kTlsV10 = ::testing::ValuesIn(kTlsV10Arr); -static const uint16_t kTlsV11Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1}; -::testing::internal::ParamGenerator<uint16_t> - TlsConnectTestBase::kTlsV11 = ::testing::ValuesIn(kTlsV11Arr); -static const uint16_t kTlsV11V12Arr[] = {SSL_LIBRARY_VERSION_TLS_1_1, - SSL_LIBRARY_VERSION_TLS_1_2}; -::testing::internal::ParamGenerator<uint16_t> - TlsConnectTestBase::kTlsV11V12 = ::testing::ValuesIn(kTlsV11V12Arr); -// TODO: add TLS 1.3 -static const uint16_t kTlsV12PlusArr[] = {SSL_LIBRARY_VERSION_TLS_1_2}; -::testing::internal::ParamGenerator<uint16_t> - TlsConnectTestBase::kTlsV12Plus = ::testing::ValuesIn(kTlsV12PlusArr); - -static std::string VersionString(uint16_t version) { - switch(version) { - case 0: - return "(no version)"; - case SSL_LIBRARY_VERSION_TLS_1_0: - return "1.0"; - case SSL_LIBRARY_VERSION_TLS_1_1: - return "1.1"; - case SSL_LIBRARY_VERSION_TLS_1_2: - return "1.2"; - default: - std::cerr << "Invalid version: " << version << std::endl; - EXPECT_TRUE(false); - return ""; - } -} - -TlsConnectTestBase::TlsConnectTestBase(Mode mode, uint16_t version) - : mode_(mode), - client_(new TlsAgent("client", TlsAgent::CLIENT, mode_, ssl_kea_rsa)), - server_(new TlsAgent("server", TlsAgent::SERVER, mode_, ssl_kea_rsa)), - version_(version), - expected_resumption_mode_(RESUME_NONE), - session_ids_(), - expect_extended_master_secret_(false) { - std::cerr << "Version: " << mode_ << " " << VersionString(version_) << std::endl; -} - -TlsConnectTestBase::~TlsConnectTestBase() { -} - -void TlsConnectTestBase::SetUp() { - // Configure a fresh session cache. - SSL_ConfigServerSessionIDCache(1024, 0, 0, g_working_dir_path.c_str()); - - // Clear statistics. - SSL3Statistics* stats = SSL_GetStatistics(); - memset(stats, 0, sizeof(*stats)); - - Init(); -} - -void TlsConnectTestBase::TearDown() { - delete client_; - delete server_; - - SSL_ClearSessionCache(); - SSL_ShutdownServerSessionIDCache(); -} - -void TlsConnectTestBase::Init() { - EXPECT_TRUE(client_->Init()); - EXPECT_TRUE(server_->Init()); - - client_->SetPeer(server_); - server_->SetPeer(client_); - - if (version_) { - client_->SetVersionRange(version_, version_); - server_->SetVersionRange(version_, version_); - } -} - -void TlsConnectTestBase::Reset(const std::string& server_name, SSLKEAType kea) { - delete client_; - delete server_; - - client_ = new TlsAgent("client", TlsAgent::CLIENT, mode_, kea); - server_ = new TlsAgent(server_name, TlsAgent::SERVER, mode_, kea); - - Init(); -} - -void TlsConnectTestBase::ResetRsa() { - Reset("server", ssl_kea_rsa); -} - -void TlsConnectTestBase::ResetEcdsa() { - Reset("ecdsa", ssl_kea_ecdh); -} - -void TlsConnectTestBase::ExpectResumption(SessionResumptionMode expected) { - expected_resumption_mode_ = expected; - if (expected != RESUME_NONE) { - client_->ExpectResumption(); - server_->ExpectResumption(); - } -} - -void TlsConnectTestBase::EnsureTlsSetup() { - EXPECT_TRUE(client_->EnsureTlsSetup()); - EXPECT_TRUE(server_->EnsureTlsSetup()); -} - -void TlsConnectTestBase::Handshake() { - client_->Handshake(); - server_->Handshake(); - - ASSERT_TRUE_WAIT((client_->state() != TlsAgent::STATE_CONNECTING) && - (server_->state() != TlsAgent::STATE_CONNECTING), - 5000); -} - -void TlsConnectTestBase::EnableExtendedMasterSecret() { - client_->EnableExtendedMasterSecret(); - server_->EnableExtendedMasterSecret(); - ExpectExtendedMasterSecret(true); -} - -void TlsConnectTestBase::Connect() { - server_->StartConnect(); - client_->StartConnect(); - Handshake(); - CheckConnected(); -} - -void TlsConnectTestBase::CheckConnected() { - // Check the version is as expected - EXPECT_EQ(client_->version(), server_->version()); - EXPECT_EQ(std::min(client_->max_version(), - server_->max_version()), - client_->version()); - - EXPECT_EQ(TlsAgent::STATE_CONNECTED, client_->state()); - EXPECT_EQ(TlsAgent::STATE_CONNECTED, server_->state()); - - int16_t cipher_suite1, cipher_suite2; - bool ret = client_->cipher_suite(&cipher_suite1); - EXPECT_TRUE(ret); - ret = server_->cipher_suite(&cipher_suite2); - EXPECT_TRUE(ret); - EXPECT_EQ(cipher_suite1, cipher_suite2); - - std::cerr << "Connected with version " << client_->version() - << " cipher suite " << client_->cipher_suite_name() - << std::endl; - - // Check and store session ids. - std::vector<uint8_t> sid_c1 = client_->session_id(); - EXPECT_EQ(32U, sid_c1.size()); - std::vector<uint8_t> sid_s1 = server_->session_id(); - EXPECT_EQ(32U, sid_s1.size()); - EXPECT_EQ(sid_c1, sid_s1); - session_ids_.push_back(sid_c1); - - CheckResumption(expected_resumption_mode_); - // Check whether the extended master secret extension was negotiated. - CheckExtendedMasterSecret(); -} - -void TlsConnectTestBase::ConnectExpectFail() { - server_->StartConnect(); - client_->StartConnect(); - Handshake(); - - ASSERT_EQ(TlsAgent::STATE_ERROR, client_->state()); - ASSERT_EQ(TlsAgent::STATE_ERROR, server_->state()); -} - -void TlsConnectTestBase::SetExpectedVersion(uint16_t version) { - client_->SetExpectedVersion(version); - server_->SetExpectedVersion(version); -} - -void TlsConnectTestBase::DisableDheCiphers() { - client_->DisableCiphersByKeyExchange(ssl_kea_dh); - server_->DisableCiphersByKeyExchange(ssl_kea_dh); -} - -void TlsConnectTestBase::DisableEcdheCiphers() { - client_->DisableCiphersByKeyExchange(ssl_kea_ecdh); - server_->DisableCiphersByKeyExchange(ssl_kea_ecdh); -} - -void TlsConnectTestBase::DisableDheAndEcdheCiphers() { - DisableDheCiphers(); - DisableEcdheCiphers(); -} - -void TlsConnectTestBase::ConfigureSessionCache(SessionResumptionMode client, - SessionResumptionMode server) { - client_->ConfigureSessionCache(client); - server_->ConfigureSessionCache(server); -} - -void TlsConnectTestBase::CheckResumption(SessionResumptionMode expected) { - EXPECT_NE(RESUME_BOTH, expected); - - int resume_ct = expected ? 1 : 0; - int stateless_ct = (expected & RESUME_TICKET) ? 1 : 0; - - SSL3Statistics* stats = SSL_GetStatistics(); - EXPECT_EQ(resume_ct, stats->hch_sid_cache_hits); - EXPECT_EQ(resume_ct, stats->hsh_sid_cache_hits); - - EXPECT_EQ(stateless_ct, stats->hch_sid_stateless_resumes); - EXPECT_EQ(stateless_ct, stats->hsh_sid_stateless_resumes); - - if (resume_ct) { - // Check that the last two session ids match. - EXPECT_GE(2U, session_ids_.size()); - EXPECT_EQ(session_ids_[session_ids_.size()-1], - session_ids_[session_ids_.size()-2]); - } -} - -void TlsConnectTestBase::EnableAlpn() { - // A simple value of "a", "b". Note that the preferred value of "a" is placed - // at the end, because the NSS API follows the now defunct NPN specification, - // which places the preferred (and default) entry at the end of the list. - // NSS will move this final entry to the front when used with ALPN. - static const uint8_t val[] = { 0x01, 0x62, 0x01, 0x61 }; - client_->EnableAlpn(val, sizeof(val)); - server_->EnableAlpn(val, sizeof(val)); -} - -void TlsConnectTestBase::EnableSrtp() { - client_->EnableSrtp(); - server_->EnableSrtp(); -} - -void TlsConnectTestBase::CheckSrtp() const { - client_->CheckSrtp(); - server_->CheckSrtp(); -} - -void TlsConnectTestBase::SendReceive() { - client_->SendData(50); - server_->SendData(50); - WAIT_(client_->received_bytes() == 50U && - server_->received_bytes() == 50U, 2000); - ASSERT_EQ(50U, client_->received_bytes()); - ASSERT_EQ(50U, server_->received_bytes()); -} - -void TlsConnectTestBase::ExpectExtendedMasterSecret(bool expected) { - expect_extended_master_secret_ = expected; -} - -void TlsConnectTestBase::CheckExtendedMasterSecret() { - client_->CheckExtendedMasterSecret(expect_extended_master_secret_); - server_->CheckExtendedMasterSecret(expect_extended_master_secret_); -} - -TlsConnectGeneric::TlsConnectGeneric() - : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())), - std::get<1>(GetParam())) {} - -TlsConnectPre12::TlsConnectPre12() - : TlsConnectTestBase(TlsConnectTestBase::ToMode(std::get<0>(GetParam())), - std::get<1>(GetParam())) {} - -TlsConnectTls12::TlsConnectTls12() - : TlsConnectTestBase(TlsConnectTestBase::ToMode(GetParam()), - SSL_LIBRARY_VERSION_TLS_1_2) {} - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/tls_connect.h b/nss/external_tests/ssl_gtest/tls_connect.h deleted file mode 100644 index 88caf59..0000000 --- a/nss/external_tests/ssl_gtest/tls_connect.h +++ /dev/null @@ -1,133 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef tls_connect_h_ -#define tls_connect_h_ - -#include <tuple> - -#include "sslt.h" - -#include "tls_agent.h" - -#define GTEST_HAS_RTTI 0 -#include "gtest/gtest.h" - -namespace nss_test { - -// A generic TLS connection test base. -class TlsConnectTestBase : public ::testing::Test { - public: - static ::testing::internal::ParamGenerator<std::string> kTlsModesStream; - static ::testing::internal::ParamGenerator<std::string> kTlsModesAll; - static ::testing::internal::ParamGenerator<uint16_t> kTlsV10; - static ::testing::internal::ParamGenerator<uint16_t> kTlsV11; - static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12; - static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus; - - static inline Mode ToMode(const std::string& str) { - return str == "TLS" ? STREAM : DGRAM; - } - - TlsConnectTestBase(Mode mode, uint16_t version); - virtual ~TlsConnectTestBase(); - - void SetUp(); - void TearDown(); - - // Initialize client and server. - void Init(); - // Re-initialize client and server with the default RSA cert. - void ResetRsa(); - // Re-initialize client and server with an ECDSA cert on the server - // and some ECDHE suites. - void ResetEcdsa(); - // Make sure TLS is configured for a connection. - void EnsureTlsSetup(); - - // Run the handshake. - void Handshake(); - // Connect and check that it works. - void Connect(); - // Check that the connection was successfully established. - void CheckConnected(); - // Connect and expect it to fail. - void ConnectExpectFail(); - - void SetExpectedVersion(uint16_t version); - // Expect resumption of a particular type. - void ExpectResumption(SessionResumptionMode expected); - void DisableDheAndEcdheCiphers(); - void DisableDheCiphers(); - void DisableEcdheCiphers(); - void EnableExtendedMasterSecret(); - void ConfigureSessionCache(SessionResumptionMode client, - SessionResumptionMode server); - void EnableAlpn(); - void EnableSrtp(); - void CheckSrtp() const; - void SendReceive(); - void ExpectExtendedMasterSecret(bool expected); - - protected: - Mode mode_; - TlsAgent* client_; - TlsAgent* server_; - uint16_t version_; - SessionResumptionMode expected_resumption_mode_; - std::vector<std::vector<uint8_t>> session_ids_; - - private: - void Reset(const std::string& server_name, SSLKEAType kea); - void CheckResumption(SessionResumptionMode expected); - void CheckExtendedMasterSecret(); - - bool expect_extended_master_secret_; -}; - -// A TLS-only test base. -class TlsConnectStream : public TlsConnectTestBase, - public ::testing::WithParamInterface<uint16_t> { - public: - TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {} -}; - -// A DTLS-only test base. -class TlsConnectDatagram : public TlsConnectTestBase, - public ::testing::WithParamInterface<uint16_t> { - public: - TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {} -}; - -// A generic test class that can be either STREAM or DGRAM and a single version -// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this -// should use TEST_P(). -class TlsConnectGeneric - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { - public: - TlsConnectGeneric(); -}; - -// A Pre TLS 1.2 generic test. -class TlsConnectPre12 - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> { - public: - TlsConnectPre12(); -}; - -// A TLS 1.2 only generic test. -class TlsConnectTls12 - : public TlsConnectTestBase, - public ::testing::WithParamInterface<std::string> { - public: - TlsConnectTls12(); -}; - -} // namespace nss_test - -#endif diff --git a/nss/external_tests/ssl_gtest/tls_filter.cc b/nss/external_tests/ssl_gtest/tls_filter.cc deleted file mode 100644 index 07654ee..0000000 --- a/nss/external_tests/ssl_gtest/tls_filter.cc +++ /dev/null @@ -1,244 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "tls_filter.h" - -#include <iostream> - -namespace nss_test { - -bool TlsRecordFilter::Filter(const DataBuffer& input, DataBuffer* output) { - bool changed = false; - size_t output_offset = 0U; - output->Allocate(input.len()); - - TlsParser parser(input); - while (parser.remaining()) { - size_t start = parser.consumed(); - uint8_t content_type; - if (!parser.Read(&content_type)) { - return false; - } - uint32_t version; - if (!parser.Read(&version, 2)) { - return false; - } - - if (IsDtls(version)) { - if (!parser.Skip(8)) { - return false; - } - } - size_t header_len = parser.consumed() - start; - output->Write(output_offset, input.data() + start, header_len); - - DataBuffer record; - if (!parser.ReadVariable(&record, 2)) { - return false; - } - - // Move the offset in the output forward. ApplyFilter() returns the index - // of the end of the record it wrote to the output, so we need to skip - // over the content type and version for the value passed to it. - output_offset = ApplyFilter(content_type, version, record, output, - output_offset + header_len, - &changed); - } - output->Truncate(output_offset); - - // Record how many packets we actually touched. - if (changed) { - ++count_; - } - - return changed; -} - -size_t TlsRecordFilter::ApplyFilter(uint8_t content_type, uint16_t version, - const DataBuffer& record, - DataBuffer* output, - size_t offset, bool* changed) { - const DataBuffer* source = &record; - DataBuffer filtered; - if (FilterRecord(content_type, version, record, &filtered) && - filtered.len() < 0x10000) { - *changed = true; - std::cerr << "record old: " << record << std::endl; - std::cerr << "record new: " << filtered << std::endl; - source = &filtered; - } - - output->Write(offset, source->len(), 2); - output->Write(offset + 2, *source); - return offset + 2 + source->len(); -} - -bool TlsHandshakeFilter::FilterRecord(uint8_t content_type, uint16_t version, - const DataBuffer& input, - DataBuffer* output) { - // Check that the first byte is as requested. - if (content_type != kTlsHandshakeType) { - return false; - } - - bool changed = false; - size_t output_offset = 0U; - output->Allocate(input.len()); // Preallocate a little. - - TlsParser parser(input); - while (parser.remaining()) { - size_t start = parser.consumed(); - uint8_t handshake_type; - if (!parser.Read(&handshake_type)) { - return false; // malformed - } - uint32_t length; - if (!ReadLength(&parser, version, &length)) { - return false; - } - - size_t header_len = parser.consumed() - start; - output->Write(output_offset, input.data() + start, header_len); - - DataBuffer handshake; - if (!parser.Read(&handshake, length)) { - return false; - } - - // Move the offset in the output forward. ApplyFilter() returns the index - // of the end of the message it wrote to the output, so we need to identify - // offsets from the start of the message for length and the handshake - // message. - output_offset = ApplyFilter(version, handshake_type, handshake, - output, output_offset + 1, - output_offset + header_len, - &changed); - } - output->Truncate(output_offset); - return changed; -} - -bool TlsHandshakeFilter::ReadLength(TlsParser* parser, uint16_t version, uint32_t *length) { - if (!parser->Read(length, 3)) { - return false; // malformed - } - - if (!IsDtls(version)) { - return true; // nothing left to do - } - - // Read and check DTLS parameters - if (!parser->Skip(2)) { // sequence number - return false; - } - - uint32_t fragment_offset; - if (!parser->Read(&fragment_offset, 3)) { - return false; - } - - uint32_t fragment_length; - if (!parser->Read(&fragment_length, 3)) { - return false; - } - - // All current tests where we are using this code don't fragment. - return (fragment_offset == 0 && fragment_length == *length); -} - -size_t TlsHandshakeFilter::ApplyFilter( - uint16_t version, uint8_t handshake_type, const DataBuffer& handshake, - DataBuffer* output, size_t length_offset, size_t value_offset, - bool* changed) { - const DataBuffer* source = &handshake; - DataBuffer filtered; - if (FilterHandshake(version, handshake_type, handshake, &filtered) && - filtered.len() < 0x1000000) { - *changed = true; - std::cerr << "handshake old: " << handshake << std::endl; - std::cerr << "handshake new: " << filtered << std::endl; - source = &filtered; - } - - // Back up and overwrite the (two) length field(s): the handshake message - // length and the DTLS fragment length. - output->Write(length_offset, source->len(), 3); - if (IsDtls(version)) { - output->Write(length_offset + 8, source->len(), 3); - } - output->Write(value_offset, *source); - return value_offset + source->len(); -} - -bool TlsInspectorRecordHandshakeMessage::FilterHandshake( - uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output) { - // Only do this once. - if (buffer_.len()) { - return false; - } - - if (handshake_type == handshake_type_) { - buffer_ = input; - } - return false; -} - - -bool TlsInspectorReplaceHandshakeMessage::FilterHandshake( - uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output) { - if (handshake_type == handshake_type_) { - *output = buffer_; - return true; - } - - return false; -} - -bool TlsAlertRecorder::FilterRecord(uint8_t content_type, uint16_t version, - const DataBuffer& input, DataBuffer* output) { - if (level_ == kTlsAlertFatal) { // already fatal - return false; - } - if (content_type != kTlsAlertType) { - return false; - } - - std::cerr << "Alert: " << input << std::endl; - - TlsParser parser(input); - uint8_t lvl; - if (!parser.Read(&lvl)) { - return false; - } - if (lvl == kTlsAlertWarning) { // not strong enough - return false; - } - level_ = lvl; - (void)parser.Read(&description_); - return false; -} - -ChainedPacketFilter::~ChainedPacketFilter() { - for (auto it = filters_.begin(); it != filters_.end(); ++it) { - delete *it; - } -} - -bool ChainedPacketFilter::Filter(const DataBuffer& input, DataBuffer* output) { - DataBuffer in(input); - bool changed = false; - for (auto it = filters_.begin(); it != filters_.end(); ++it) { - if ((*it)->Filter(in, output)) { - in = *output; - changed = true; - } - } - return changed; -} - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/tls_filter.h b/nss/external_tests/ssl_gtest/tls_filter.h deleted file mode 100644 index 1eec64b..0000000 --- a/nss/external_tests/ssl_gtest/tls_filter.h +++ /dev/null @@ -1,131 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef tls_filter_h_ -#define tls_filter_h_ - -#include <memory> -#include <vector> - -#include "test_io.h" -#include "tls_parser.h" - -namespace nss_test { - -// Abstract filter that operates on entire (D)TLS records. -class TlsRecordFilter : public PacketFilter { - public: - TlsRecordFilter() : count_(0) {} - - virtual bool Filter(const DataBuffer& input, DataBuffer* output); - - // Report how many packets were altered by the filter. - size_t filtered_packets() const { return count_; } - - protected: - virtual bool FilterRecord(uint8_t content_type, uint16_t version, - const DataBuffer& data, DataBuffer* changed) = 0; - private: - size_t ApplyFilter(uint8_t content_type, uint16_t version, - const DataBuffer& record, DataBuffer* output, - size_t offset, bool* changed); - - size_t count_; -}; - -// Abstract filter that operates on handshake messages rather than records. -// This assumes that the handshake messages are written in a block as entire -// records and that they don't span records or anything crazy like that. -class TlsHandshakeFilter : public TlsRecordFilter { - public: - TlsHandshakeFilter() {} - - // Reads the length from the record header. - // This also reads the DTLS fragment information and checks it. - static bool ReadLength(TlsParser* parser, uint16_t version, uint32_t *length); - - protected: - virtual bool FilterRecord(uint8_t content_type, uint16_t version, - const DataBuffer& input, DataBuffer* output); - virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output) = 0; - - private: - size_t ApplyFilter(uint16_t version, uint8_t handshake_type, - const DataBuffer& record, DataBuffer* output, - size_t length_offset, size_t value_offset, bool* changed); -}; - -// Make a copy of the first instance of a handshake message. -class TlsInspectorRecordHandshakeMessage : public TlsHandshakeFilter { - public: - TlsInspectorRecordHandshakeMessage(uint8_t handshake_type) - : handshake_type_(handshake_type), buffer_() {} - - virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output); - - const DataBuffer& buffer() const { return buffer_; } - - private: - uint8_t handshake_type_; - DataBuffer buffer_; -}; - -// Replace all instances of a handshake message. -class TlsInspectorReplaceHandshakeMessage : public TlsHandshakeFilter { - public: - TlsInspectorReplaceHandshakeMessage(uint8_t handshake_type, - const DataBuffer& replacement) - : handshake_type_(handshake_type), buffer_(replacement) {} - - virtual bool FilterHandshake(uint16_t version, uint8_t handshake_type, - const DataBuffer& input, DataBuffer* output); - - private: - uint8_t handshake_type_; - DataBuffer buffer_; -}; - -// Records an alert. If an alert has already been recorded, it won't save the -// new alert unless the old alert is a warning and the new one is fatal. -class TlsAlertRecorder : public TlsRecordFilter { - public: - TlsAlertRecorder() : level_(255), description_(255) {} - - virtual bool FilterRecord(uint8_t content_type, uint16_t version, - const DataBuffer& input, DataBuffer* output); - - uint8_t level() const { return level_; } - uint8_t description() const { return description_; } - - private: - uint8_t level_; - uint8_t description_; -}; - -// Runs multiple packet filters in series. -class ChainedPacketFilter : public PacketFilter { - public: - ChainedPacketFilter() {} - ChainedPacketFilter(const std::vector<PacketFilter*> filters) - : filters_(filters.begin(), filters.end()) {} - virtual ~ChainedPacketFilter(); - - virtual bool Filter(const DataBuffer& input, DataBuffer* output); - - // Takes ownership of the filter. - void Add(PacketFilter* filter) { - filters_.push_back(filter); - } - - private: - std::vector<PacketFilter*> filters_; -}; - -} // namespace nss_test - -#endif diff --git a/nss/external_tests/ssl_gtest/tls_parser.cc b/nss/external_tests/ssl_gtest/tls_parser.cc deleted file mode 100644 index 1d56fff..0000000 --- a/nss/external_tests/ssl_gtest/tls_parser.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#include "tls_parser.h" - -namespace nss_test { - -bool TlsParser::Read(uint8_t* val) { - if (remaining() < 1) { - return false; - } - *val = *ptr(); - consume(1); - return true; -} - -bool TlsParser::Read(uint32_t* val, size_t size) { - if (size > sizeof(uint32_t)) { - return false; - } - - uint32_t v = 0; - for (size_t i = 0; i < size; ++i) { - uint8_t tmp; - if (!Read(&tmp)) { - return false; - } - - v = (v << 8) | tmp; - } - - *val = v; - return true; -} - -bool TlsParser::Read(DataBuffer* val, size_t len) { - if (remaining() < len) { - return false; - } - - val->Assign(ptr(), len); - consume(len); - return true; -} - -bool TlsParser::ReadVariable(DataBuffer* val, size_t len_size) { - uint32_t len; - if (!Read(&len, len_size)) { - return false; - } - return Read(val, len); -} - -bool TlsParser::Skip(size_t len) { - if (len > remaining()) { return false; } - consume(len); - return true; -} - -bool TlsParser::SkipVariable(size_t len_size) { - uint32_t len; - if (!Read(&len, len_size)) { - return false; - } - return Skip(len); -} - -} // namespace nss_test diff --git a/nss/external_tests/ssl_gtest/tls_parser.h b/nss/external_tests/ssl_gtest/tls_parser.h deleted file mode 100644 index da3f3a7..0000000 --- a/nss/external_tests/ssl_gtest/tls_parser.h +++ /dev/null @@ -1,106 +0,0 @@ -/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */ -/* vim: set ts=2 et sw=2 tw=80: */ -/* This Source Code Form is subject to the terms of the Mozilla Public - * License, v. 2.0. If a copy of the MPL was not distributed with this file, - * You can obtain one at http://mozilla.org/MPL/2.0/. */ - -#ifndef tls_parser_h_ -#define tls_parser_h_ - -#include <memory> -#include <cstdint> -#include <cstring> -#if defined(WIN32) || defined(WIN64) -#include <winsock2.h> -#else -#include <arpa/inet.h> -#endif -#include "databuffer.h" - -namespace nss_test { - -const uint8_t kTlsChangeCipherSpecType = 20; -const uint8_t kTlsAlertType = 21; -const uint8_t kTlsHandshakeType = 22; - -const uint8_t kTlsHandshakeClientHello = 1; -const uint8_t kTlsHandshakeServerHello = 2; -const uint8_t kTlsHandshakeCertificate = 11; -const uint8_t kTlsHandshakeServerKeyExchange = 12; -const uint8_t kTlsHandshakeCertificateVerify = 15; -const uint8_t kTlsHandshakeClientKeyExchange = 16; -const uint8_t kTlsHandshakeFinished = 20; - -const uint8_t kTlsAlertWarning = 1; -const uint8_t kTlsAlertFatal = 2; - -const uint8_t kTlsAlertUnexpectedMessage = 10; -const uint8_t kTlsAlertBadRecordMac = 20; -const uint8_t kTlsAlertHandshakeFailure = 40; -const uint8_t kTlsAlertIllegalParameter = 47; -const uint8_t kTlsAlertDecodeError = 50; -const uint8_t kTlsAlertUnsupportedExtension = 110; -const uint8_t kTlsAlertNoApplicationProtocol = 120; - -const uint8_t kTlsFakeChangeCipherSpec[] = { - kTlsChangeCipherSpecType, // Type - 0xfe, 0xff, // Version - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x10, // Fictitious sequence # - 0x00, 0x01, // Length - 0x01 // Value -}; - -inline bool IsDtls(uint16_t version) { - return (version & 0x8000) == 0x8000; -} - -inline uint16_t NormalizeTlsVersion(uint16_t version) { - if (version == 0xfeff) { - return 0x0302; // special: DTLS 1.0 == TLS 1.1 - } - if (IsDtls(version)) { - return (version ^ 0xffff) + 0x0201; - } - return version; -} - -inline void WriteVariable(DataBuffer* target, size_t index, - const DataBuffer& buf, size_t len_size) { - target->Write(index, static_cast<uint32_t>(buf.len()), len_size); - target->Write(index + len_size, buf.data(), buf.len()); -} - -class TlsParser { - public: - TlsParser(const uint8_t* data, size_t len) - : buffer_(data, len), offset_(0) {} - explicit TlsParser(const DataBuffer& buf) - : buffer_(buf), offset_(0) {} - - bool Read(uint8_t* val); - // Read an integral type of specified width. - bool Read(uint32_t* val, size_t size); - // Reads len bytes into dest buffer, overwriting it. - bool Read(DataBuffer* dest, size_t len); - // Reads bytes into dest buffer, overwriting it. The number of bytes is - // determined by reading from len_size bytes from the stream first. - bool ReadVariable(DataBuffer* dest, size_t len_size); - - bool Skip(size_t len); - bool SkipVariable(size_t len_size); - - size_t consumed() const { return offset_; } - size_t remaining() const { return buffer_.len() - offset_; } - - private: - void consume(size_t len) { offset_ += len; } - const uint8_t* ptr() const { return buffer_.data() + offset_; } - - DataBuffer buffer_; - size_t offset_; -}; - -} // namespace nss_test - -#endif |