summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/ssl_extension_unittest.cc')
-rw-r--r--nss/gtests/ssl_gtest/ssl_extension_unittest.cc985
1 files changed, 985 insertions, 0 deletions
diff --git a/nss/gtests/ssl_gtest/ssl_extension_unittest.cc b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
new file mode 100644
index 0000000..9200e72
--- /dev/null
+++ b/nss/gtests/ssl_gtest/ssl_extension_unittest.cc
@@ -0,0 +1,985 @@
+/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
+/* vim: set ts=2 et sw=2 tw=80: */
+/* This Source Code Form is subject to the terms of the Mozilla Public
+ * License, v. 2.0. If a copy of the MPL was not distributed with this file,
+ * You can obtain one at http://mozilla.org/MPL/2.0/. */
+
+#include "ssl.h"
+#include "ssl3prot.h"
+#include "sslerr.h"
+#include "sslproto.h"
+
+#include <memory>
+
+#include "tls_connect.h"
+#include "tls_filter.h"
+#include "tls_parser.h"
+
+namespace nss_test {
+
+class TlsExtensionTruncator : public TlsExtensionFilter {
+ public:
+ TlsExtensionTruncator(uint16_t extension, size_t length)
+ : extension_(extension), length_(length) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+ if (input.len() <= length_) {
+ return KEEP;
+ }
+
+ output->Assign(input.data(), length_);
+ return CHANGE;
+ }
+
+ private:
+ uint16_t extension_;
+ size_t length_;
+};
+
+class TlsExtensionDamager : public TlsExtensionFilter {
+ public:
+ TlsExtensionDamager(uint16_t extension, size_t index)
+ : extension_(extension), index_(index) {}
+ virtual PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != extension_) {
+ return KEEP;
+ }
+
+ *output = input;
+ output->data()[index_] += 73; // Increment selected for maximum damage
+ return CHANGE;
+ }
+
+ private:
+ uint16_t extension_;
+ size_t index_;
+};
+
+class TlsExtensionInjector : public TlsHandshakeFilter {
+ public:
+ TlsExtensionInjector(uint16_t ext, DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ size_t offset;
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
+ }
+ offset = parser.consumed();
+ } else if (header.handshake_type() == kTlsHandshakeServerHello) {
+ TlsParser parser(input);
+ if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
+ return KEEP;
+ }
+ offset = parser.consumed();
+ } else {
+ return KEEP;
+ }
+
+ *output = input;
+
+ // Increase the size of the extensions.
+ uint16_t ext_len;
+ memcpy(&ext_len, output->data() + offset, sizeof(ext_len));
+ ext_len = htons(ntohs(ext_len) + data_.len() + 4);
+ memcpy(output->data() + offset, &ext_len, sizeof(ext_len));
+
+ // Insert the extension type and length.
+ DataBuffer type_length;
+ type_length.Allocate(4);
+ type_length.Write(0, extension_, 2);
+ type_length.Write(2, data_.len(), 2);
+ output->Splice(type_length, offset + 2);
+
+ // Insert the payload.
+ if (data_.len() > 0) {
+ output->Splice(data_, offset + 6);
+ }
+
+ return CHANGE;
+ }
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionAppender : public TlsHandshakeFilter {
+ public:
+ TlsExtensionAppender(uint16_t ext, DataBuffer& data)
+ : extension_(ext), data_(data) {}
+
+ virtual PacketFilter::Action FilterHandshake(const HandshakeHeader& header,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ size_t offset;
+ TlsParser parser(input);
+ if (header.handshake_type() == kTlsHandshakeClientHello) {
+ if (!TlsExtensionFilter::FindClientHelloExtensions(&parser, header)) {
+ return KEEP;
+ }
+ } else if (header.handshake_type() == kTlsHandshakeServerHello) {
+ if (!TlsExtensionFilter::FindServerHelloExtensions(&parser)) {
+ return KEEP;
+ }
+ } else {
+ return KEEP;
+ }
+ offset = parser.consumed();
+ *output = input;
+
+ uint32_t ext_len;
+ if (!parser.Read(&ext_len, 2)) {
+ ADD_FAILURE();
+ return KEEP;
+ }
+
+ ext_len += 4 + data_.len();
+ output->Write(offset, ext_len, 2);
+
+ offset = output->len();
+ offset = output->Write(offset, extension_, 2);
+ WriteVariable(output, offset, data_, 2);
+
+ return CHANGE;
+ }
+
+ private:
+ const uint16_t extension_;
+ const DataBuffer data_;
+};
+
+class TlsExtensionTestBase : public TlsConnectTestBase {
+ protected:
+ TlsExtensionTestBase(Mode mode, uint16_t version)
+ : TlsConnectTestBase(mode, version) {}
+ TlsExtensionTestBase(const std::string& mode, uint16_t version)
+ : TlsConnectTestBase(mode, version) {}
+
+ void ClientHelloErrorTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertDecodeError) {
+ auto alert_recorder = new TlsAlertRecorder();
+ server_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ client_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(alert, alert_recorder->description());
+ }
+
+ void ServerHelloErrorTest(PacketFilter* filter,
+ uint8_t alert = kTlsAlertDecodeError) {
+ auto alert_recorder = new TlsAlertRecorder();
+ client_->SetPacketFilter(alert_recorder);
+ if (filter) {
+ server_->SetPacketFilter(filter);
+ }
+ ConnectExpectFail();
+ EXPECT_EQ(kTlsAlertFatal, alert_recorder->level());
+ EXPECT_EQ(alert, alert_recorder->description());
+ }
+
+ static void InitSimpleSni(DataBuffer* extension) {
+ const char* name = "host.name";
+ const size_t namelen = PL_strlen(name);
+ extension->Allocate(namelen + 5);
+ extension->Write(0, namelen + 3, 2);
+ extension->Write(2, static_cast<uint32_t>(0), 1); // 0 == hostname
+ extension->Write(3, namelen, 2);
+ extension->Write(5, reinterpret_cast<const uint8_t*>(name), namelen);
+ }
+
+ void HrrThenRemoveExtensionsTest(SSLExtensionType type, PRInt32 client_error,
+ PRInt32 server_error) {
+ static const std::vector<SSLNamedGroup> client_groups = {
+ ssl_grp_ec_secp384r1, ssl_grp_ec_curve25519};
+ static const std::vector<SSLNamedGroup> server_groups = {
+ ssl_grp_ec_curve25519, ssl_grp_ec_secp384r1};
+ client_->ConfigNamedGroups(client_groups);
+ server_->ConfigNamedGroups(server_groups);
+ EnsureTlsSetup();
+ client_->StartConnect();
+ server_->StartConnect();
+ client_->Handshake(); // Send ClientHello
+ server_->Handshake(); // Send HRR.
+ client_->SetPacketFilter(new TlsExtensionDropper(type));
+ Handshake();
+ client_->CheckErrorCode(client_error);
+ server_->CheckErrorCode(server_error);
+ }
+};
+
+class TlsExtensionTestDtls : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<uint16_t> {
+ public:
+ TlsExtensionTestDtls() : TlsExtensionTestBase(DGRAM, GetParam()) {}
+};
+
+class TlsExtensionTest12Plus
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTest12Plus()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTest12
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTest12()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTest13 : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::string> {
+ public:
+ TlsExtensionTest13()
+ : TlsExtensionTestBase(GetParam(), SSL_LIBRARY_VERSION_TLS_1_3) {}
+
+ void ConnectWithBogusVersionList(const uint8_t* buf, size_t len) {
+ DataBuffer versions_buf(buf, len);
+ client_->SetPacketFilter(new TlsExtensionReplacer(
+ ssl_tls13_supported_versions_xtn, versions_buf));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+ }
+
+ void ConnectWithReplacementVersionList(uint16_t version) {
+ DataBuffer versions_buf;
+
+ size_t index = versions_buf.Write(0, 2, 1);
+ versions_buf.Write(index, version, 2);
+ client_->SetPacketFilter(new TlsExtensionReplacer(
+ ssl_tls13_supported_versions_xtn, versions_buf));
+ ConnectExpectFail();
+ }
+};
+
+class TlsExtensionTest13Stream : public TlsExtensionTestBase {
+ public:
+ TlsExtensionTest13Stream()
+ : TlsExtensionTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+};
+
+class TlsExtensionTestGeneric
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTestGeneric()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+class TlsExtensionTestPre13
+ : public TlsExtensionTestBase,
+ public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+ public:
+ TlsExtensionTestPre13()
+ : TlsExtensionTestBase(std::get<0>(GetParam()), std::get<1>(GetParam())) {
+ }
+};
+
+TEST_P(TlsExtensionTestGeneric, DamageSniLength) {
+ ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, DamageSniHostLength) {
+ ClientHelloErrorTest(new TlsExtensionDamager(ssl_server_name_xtn, 4));
+}
+
+TEST_P(TlsExtensionTestGeneric, TruncateSni) {
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_server_name_xtn, 7));
+}
+
+// A valid extension that appears twice will be reported as unsupported.
+TEST_P(TlsExtensionTestGeneric, RepeatSni) {
+ DataBuffer extension;
+ InitSimpleSni(&extension);
+ ClientHelloErrorTest(new TlsExtensionInjector(ssl_server_name_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An SNI entry with zero length is considered invalid (strangely, not if it is
+// the last entry, which is probably a bug).
+TEST_P(TlsExtensionTestGeneric, BadSni) {
+ DataBuffer simple;
+ InitSimpleSni(&simple);
+ DataBuffer extension;
+ extension.Allocate(simple.len() + 3);
+ extension.Write(0, static_cast<uint32_t>(0), 3);
+ extension.Write(3, simple);
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptySni) {
+ DataBuffer extension;
+ extension.Allocate(2);
+ extension.Write(0, static_cast<uint32_t>(0), 2);
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_server_name_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnExtension) {
+ EnableAlpn();
+ DataBuffer extension;
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertIllegalParameter);
+}
+
+// An empty ALPN isn't considered bad, though it does lead to there being no
+// protocol for the server to select.
+TEST_P(TlsExtensionTestGeneric, EmptyAlpnList) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension),
+ kTlsAlertNoApplicationProtocol);
+}
+
+TEST_P(TlsExtensionTestGeneric, OneByteAlpn) {
+ EnableAlpn();
+ ClientHelloErrorTest(
+ new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 1));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMissingValue) {
+ EnableAlpn();
+ // This will leave the length of the second entry, but no value.
+ ClientHelloErrorTest(
+ new TlsExtensionTruncator(ssl_app_layer_protocol_xtn, 5));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnZeroLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, AlpnMismatch) {
+ const uint8_t client_alpn[] = {0x01, 0x61};
+ client_->EnableAlpn(client_alpn, sizeof(client_alpn));
+ const uint8_t server_alpn[] = {0x02, 0x61, 0x62};
+ server_->EnableAlpn(server_alpn, sizeof(server_alpn));
+
+ ClientHelloErrorTest(nullptr, kTlsAlertNoApplicationProtocol);
+}
+
+// Many of these tests fail in TLS 1.3 because the extension is encrypted, which
+// prevents modification of the value from the ServerHello.
+TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyList) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedEmptyName) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedListTrailingData) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedExtraEntry) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x04, 0x01, 0x61, 0x01, 0x62};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedBadListLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x99, 0x01, 0x61, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, AlpnReturnedBadNameLength) {
+ EnableAlpn();
+ const uint8_t val[] = {0x00, 0x02, 0x99, 0x61};
+ DataBuffer extension(val, sizeof(val));
+ ServerHelloErrorTest(
+ new TlsExtensionReplacer(ssl_app_layer_protocol_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpShort) {
+ EnableSrtp();
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_use_srtp_xtn, 3));
+}
+
+TEST_P(TlsExtensionTestDtls, SrtpOdd) {
+ EnableSrtp();
+ const uint8_t val[] = {0x00, 0x01, 0xff, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(new TlsExtensionReplacer(ssl_use_srtp_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsBadLength) {
+ const uint8_t val[] = {0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsTrailingData) {
+ const uint8_t val[] = {0x00, 0x02, 0x04, 0x01, 0x00}; // sha-256, rsa
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsEmpty) {
+ const uint8_t val[] = {0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTest12Plus, SignatureAlgorithmsOddLength) {
+ const uint8_t val[] = {0x00, 0x01, 0x04};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_signature_algorithms_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, NoSupportedGroups) {
+ ClientHelloErrorTest(new TlsExtensionDropper(ssl_supported_groups_xtn),
+ version_ < SSL_LIBRARY_VERSION_TLS_1_3
+ ? kTlsAlertDecryptError
+ : kTlsAlertMissingExtension);
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesShort) {
+ const uint8_t val[] = {0x00, 0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesBadLength) {
+ const uint8_t val[] = {0x09, 0x99, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestGeneric, SupportedCurvesTrailingData) {
+ const uint8_t val[] = {0x00, 0x02, 0x00, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_elliptic_curves_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsEmpty) {
+ const uint8_t val[] = {0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsBadLength) {
+ const uint8_t val[] = {0x99, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, SupportedPointsTrailingData) {
+ const uint8_t val[] = {0x01, 0x00, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_ec_point_formats_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoBadLength) {
+ const uint8_t val[] = {0x99};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+}
+
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoMismatch) {
+ const uint8_t val[] = {0x01, 0x00};
+ DataBuffer extension(val, sizeof(val));
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+}
+
+// The extension has to contain a length.
+TEST_P(TlsExtensionTestPre13, RenegotiationInfoExtensionEmpty) {
+ DataBuffer extension;
+ ClientHelloErrorTest(
+ new TlsExtensionReplacer(ssl_renegotiation_info_xtn, extension));
+}
+
+// This only works on TLS 1.2, since it relies on static RSA; otherwise libssl
+// picks the wrong cipher suite.
+TEST_P(TlsExtensionTest12, SignatureAlgorithmConfiguration) {
+ const SSLSignatureScheme schemes[] = {ssl_sig_rsa_pss_sha512,
+ ssl_sig_rsa_pss_sha384};
+
+ TlsExtensionCapture* capture =
+ new TlsExtensionCapture(ssl_signature_algorithms_xtn);
+ client_->SetSignatureSchemes(schemes, PR_ARRAY_SIZE(schemes));
+ client_->SetPacketFilter(capture);
+ EnableOnlyStaticRsaCiphers();
+ Connect();
+
+ const DataBuffer& ext = capture->extension();
+ EXPECT_EQ(2 + PR_ARRAY_SIZE(schemes) * 2, ext.len());
+ for (size_t i = 0, cursor = 2;
+ i < PR_ARRAY_SIZE(schemes) && cursor < ext.len(); ++i) {
+ uint32_t v = 0;
+ EXPECT_TRUE(ext.Read(cursor, 2, &v));
+ cursor += 2;
+ EXPECT_EQ(schemes[i], static_cast<SSLSignatureScheme>(v));
+ }
+}
+
+// Temporary test to verify that we choke on an empty ClientKeyShare.
+// This test will fail when we implement HelloRetryRequest.
+TEST_P(TlsExtensionTest13, EmptyClientKeyShare) {
+ ClientHelloErrorTest(new TlsExtensionTruncator(ssl_tls13_key_share_xtn, 2),
+ kTlsAlertHandshakeFailure);
+}
+
+// These tests only work in stream mode because the client sends a
+// cleartext alert which causes a MAC error on the server. With
+// stream this causes handshake failure but with datagram, the
+// packet gets dropped.
+TEST_F(TlsExtensionTest13Stream, DropServerKeyShare) {
+ EnsureTlsSetup();
+ server_->SetPacketFilter(new TlsExtensionDropper(ssl_tls13_key_share_xtn));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, WrongServerKeyShare) {
+ const uint16_t wrong_group = ssl_grp_ec_secp384r1;
+
+ static const uint8_t key_share[] = {
+ wrong_group >> 8,
+ wrong_group & 0xff, // Group we didn't offer.
+ 0x00,
+ 0x02, // length = 2
+ 0x01,
+ 0x02};
+ DataBuffer buf(key_share, sizeof(key_share));
+ EnsureTlsSetup();
+ server_->SetPacketFilter(
+ new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_RX_MALFORMED_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+// TODO(ekr@rtfm.com): This is the wrong error code. See bug 1307269.
+TEST_F(TlsExtensionTest13Stream, UnknownServerKeyShare) {
+ const uint16_t wrong_group = 0xffff;
+
+ static const uint8_t key_share[] = {
+ wrong_group >> 8,
+ wrong_group & 0xff, // Group we didn't offer.
+ 0x00,
+ 0x02, // length = 2
+ 0x01,
+ 0x02};
+ DataBuffer buf(key_share, sizeof(key_share));
+ EnsureTlsSetup();
+ server_->SetPacketFilter(
+ new TlsExtensionReplacer(ssl_tls13_key_share_xtn, buf));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_MISSING_KEY_SHARE, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+TEST_F(TlsExtensionTest13Stream, AddServerSignatureAlgorithmsOnResumption) {
+ SetupForResume();
+ DataBuffer empty;
+ server_->SetPacketFilter(
+ new TlsExtensionInjector(ssl_signature_algorithms_xtn, empty));
+ ConnectExpectFail();
+ EXPECT_EQ(SSL_ERROR_EXTENSION_DISALLOWED_FOR_VERSION, client_->error_code());
+ EXPECT_EQ(SSL_ERROR_BAD_MAC_READ, server_->error_code());
+}
+
+struct PskIdentity {
+ DataBuffer identity;
+ uint32_t obfuscated_ticket_age;
+};
+
+class TlsPreSharedKeyReplacer;
+
+typedef std::function<void(TlsPreSharedKeyReplacer*)>
+ TlsPreSharedKeyReplacerFunc;
+
+class TlsPreSharedKeyReplacer : public TlsExtensionFilter {
+ public:
+ TlsPreSharedKeyReplacer(TlsPreSharedKeyReplacerFunc function)
+ : identities_(), binders_(), function_(function) {}
+
+ static size_t CopyAndMaybeReplace(TlsParser* parser, size_t size,
+ const std::unique_ptr<DataBuffer>& replace,
+ size_t index, DataBuffer* output) {
+ DataBuffer tmp;
+ bool ret = parser->ReadVariable(&tmp, size);
+ EXPECT_EQ(true, ret);
+ if (!ret) return 0;
+ if (replace) {
+ tmp = *replace;
+ }
+
+ return WriteVariable(output, index, tmp, size);
+ }
+
+ PacketFilter::Action FilterExtension(uint16_t extension_type,
+ const DataBuffer& input,
+ DataBuffer* output) {
+ if (extension_type != ssl_tls13_pre_shared_key_xtn) {
+ return KEEP;
+ }
+
+ if (!Decode(input)) {
+ return KEEP;
+ }
+
+ // Call the function.
+ function_(this);
+
+ Encode(output);
+
+ return CHANGE;
+ }
+
+ std::vector<PskIdentity> identities_;
+ std::vector<DataBuffer> binders_;
+
+ private:
+ bool Decode(const DataBuffer& input) {
+ std::unique_ptr<TlsParser> parser(new TlsParser(input));
+ DataBuffer identities;
+
+ if (!parser->ReadVariable(&identities, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ DataBuffer binders;
+ if (!parser->ReadVariable(&binders, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+ EXPECT_EQ(0UL, parser->remaining());
+
+ // Now parse the inner sections.
+ parser.reset(new TlsParser(identities));
+ while (parser->remaining()) {
+ PskIdentity identity;
+
+ if (!parser->ReadVariable(&identity.identity, 2)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ if (!parser->Read(&identity.obfuscated_ticket_age, 4)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ identities_.push_back(identity);
+ }
+
+ parser.reset(new TlsParser(binders));
+ while (parser->remaining()) {
+ DataBuffer binder;
+
+ if (!parser->ReadVariable(&binder, 1)) {
+ ADD_FAILURE();
+ return false;
+ }
+
+ binders_.push_back(binder);
+ }
+
+ return true;
+ }
+
+ void Encode(DataBuffer* output) {
+ DataBuffer identities;
+ size_t index = 0;
+ for (auto id : identities_) {
+ index = WriteVariable(&identities, index, id.identity, 2);
+ index = identities.Write(index, id.obfuscated_ticket_age, 4);
+ }
+
+ DataBuffer binders;
+ index = 0;
+ for (auto binder : binders_) {
+ index = WriteVariable(&binders, index, binder, 1);
+ }
+
+ output->Truncate(0);
+ index = 0;
+ index = WriteVariable(output, index, identities, 2);
+ index = WriteVariable(output, index, binders, 2);
+ }
+
+ TlsPreSharedKeyReplacerFunc function_;
+};
+
+TEST_F(TlsExtensionTest13Stream, ResumeEmptyPskLabel) {
+ SetupForResume();
+
+ client_->SetPacketFilter(new TlsPreSharedKeyReplacer([](
+ TlsPreSharedKeyReplacer* r) { r->identities_[0].identity.Truncate(0); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Flip the first byte of the binder.
+TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderValue) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(0, r->binders_[0].data()[0] ^ 0xff, 1);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// Extend the binder by one.
+TEST_F(TlsExtensionTest13Stream, ResumeIncorrectBinderLength) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->binders_[0].Write(r->binders_[0].len(), 0xff, 1);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Binders must be at least 32 bytes.
+TEST_F(TlsExtensionTest13Stream, ResumeBinderTooShort) {
+ SetupForResume();
+
+ client_->SetPacketFilter(new TlsPreSharedKeyReplacer(
+ [](TlsPreSharedKeyReplacer* r) { r->binders_[0].Truncate(31); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+// Duplicate the identity and binder. This will fail with an error
+// processing the binder (because we extended the identity list.)
+TEST_F(TlsExtensionTest13Stream, ResumeTwoPsks) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->identities_.push_back(r->identities_[0]);
+ r->binders_.push_back(r->binders_[0]);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+}
+
+// The next two tests have mismatches in the number of identities
+// and binders. This generates an illegal parameter alert.
+TEST_F(TlsExtensionTest13Stream, ResumeTwoIdentitiesOneBinder) {
+ SetupForResume();
+
+ client_->SetPacketFilter(
+ new TlsPreSharedKeyReplacer([](TlsPreSharedKeyReplacer* r) {
+ r->identities_.push_back(r->identities_[0]);
+ }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumeOneIdentityTwoBinders) {
+ SetupForResume();
+
+ client_->SetPacketFilter(new TlsPreSharedKeyReplacer([](
+ TlsPreSharedKeyReplacer* r) { r->binders_.push_back(r->binders_[0]); }));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumePskExtensionNotLast) {
+ SetupForResume();
+
+ const uint8_t empty_buf[] = {0};
+ DataBuffer empty(empty_buf, 0);
+ client_->SetPacketFilter(
+ // Inject an unused extension.
+ new TlsExtensionAppender(0xffff, empty));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_CLIENT_HELLO);
+}
+
+TEST_F(TlsExtensionTest13Stream, ResumeNoKeModes) {
+ SetupForResume();
+
+ DataBuffer empty;
+ client_->SetPacketFilter(
+ new TlsExtensionDropper(ssl_tls13_psk_key_exchange_modes_xtn));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_MISSING_EXTENSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_MISSING_PSK_KEY_EXCHANGE_MODES);
+}
+
+// The following test contains valid but unacceptable PreSharedKey
+// modes and therefore produces non-resumption followed by MAC
+// errors.
+TEST_F(TlsExtensionTest13Stream, ResumeBogusKeModes) {
+ SetupForResume();
+ const static uint8_t ke_modes[] = {1, // Length
+ kTls13PskKe};
+
+ DataBuffer modes(ke_modes, sizeof(ke_modes));
+ client_->SetPacketFilter(
+ new TlsExtensionReplacer(ssl_tls13_psk_key_exchange_modes_xtn, modes));
+ ConnectExpectFail();
+ client_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+ server_->CheckErrorCode(SSL_ERROR_BAD_MAC_READ);
+}
+
+TEST_P(TlsExtensionTest13, NoKeModesIfResumptionOff) {
+ ConfigureSessionCache(RESUME_NONE, RESUME_NONE);
+ auto capture = new TlsExtensionCapture(ssl_tls13_psk_key_exchange_modes_xtn);
+ client_->SetPacketFilter(capture);
+ Connect();
+ EXPECT_FALSE(capture->captured());
+}
+
+// In these tests, we downgrade to TLS 1.2, causing the
+// server to negotiate TLS 1.2.
+// 1. Both sides only support TLS 1.3, so we get a cipher version
+// error.
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionList) {
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->CheckErrorCode(SSL_ERROR_PROTOCOL_VERSION_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_UNSUPPORTED_VERSION);
+}
+
+// 2. Server supports 1.2 and 1.3, client supports 1.2, so we
+// can't negotiate any ciphers.
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListServerV12) {
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+ client_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+ server_->CheckErrorCode(SSL_ERROR_NO_CYPHER_OVERLAP);
+}
+
+// 3. Server supports 1.2 and 1.3, client supports 1.2 and 1.3
+// but advertises 1.2 (because we changed things).
+TEST_P(TlsExtensionTest13, RemoveTls13FromVersionListBothV12) {
+ client_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ server_->SetVersionRange(SSL_LIBRARY_VERSION_TLS_1_2,
+ SSL_LIBRARY_VERSION_TLS_1_3);
+ ConnectWithReplacementVersionList(SSL_LIBRARY_VERSION_TLS_1_2);
+#ifndef TLS_1_3_DRAFT_VERSION
+ client_->CheckErrorCode(SSL_ERROR_RX_MALFORMED_SERVER_HELLO);
+ server_->CheckErrorCode(SSL_ERROR_ILLEGAL_PARAMETER_ALERT);
+#else
+ client_->CheckErrorCode(SSL_ERROR_DECRYPT_ERROR_ALERT);
+ server_->CheckErrorCode(SSL_ERROR_BAD_HANDSHAKE_HASH_VALUE);
+#endif
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveSignatureAlgorithms) {
+ HrrThenRemoveExtensionsTest(ssl_signature_algorithms_xtn,
+ SSL_ERROR_MISSING_EXTENSION_ALERT,
+ SSL_ERROR_MISSING_SIGNATURE_ALGORITHMS_EXTENSION);
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveKeyShare) {
+ HrrThenRemoveExtensionsTest(ssl_tls13_key_share_xtn,
+ SSL_ERROR_ILLEGAL_PARAMETER_ALERT,
+ SSL_ERROR_BAD_2ND_CLIENT_HELLO);
+}
+
+TEST_P(TlsExtensionTest13, HrrThenRemoveSupportedGroups) {
+ HrrThenRemoveExtensionsTest(ssl_supported_groups_xtn,
+ SSL_ERROR_MISSING_EXTENSION_ALERT,
+ SSL_ERROR_MISSING_SUPPORTED_GROUPS_EXTENSION);
+}
+
+TEST_P(TlsExtensionTest13, EmptyVersionList) {
+ static const uint8_t ext[] = {0x00, 0x00};
+ ConnectWithBogusVersionList(ext, sizeof(ext));
+}
+
+TEST_P(TlsExtensionTest13, OddVersionList) {
+ static const uint8_t ext[] = {0x00, 0x01, 0x00};
+ ConnectWithBogusVersionList(ext, sizeof(ext));
+}
+
+INSTANTIATE_TEST_CASE_P(ExtensionStream, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsVAll));
+INSTANTIATE_TEST_CASE_P(ExtensionDatagram, TlsExtensionTestGeneric,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11Plus));
+INSTANTIATE_TEST_CASE_P(ExtensionDatagramOnly, TlsExtensionTestDtls,
+ TlsConnectTestBase::kTlsV11Plus);
+
+INSTANTIATE_TEST_CASE_P(ExtensionTls12Plus, TlsExtensionTest12Plus,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV12Plus));
+
+INSTANTIATE_TEST_CASE_P(ExtensionPre13Stream, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesStream,
+ TlsConnectTestBase::kTlsV10ToV12));
+INSTANTIATE_TEST_CASE_P(ExtensionPre13Datagram, TlsExtensionTestPre13,
+ ::testing::Combine(TlsConnectTestBase::kTlsModesAll,
+ TlsConnectTestBase::kTlsV11V12));
+
+INSTANTIATE_TEST_CASE_P(ExtensionTls13, TlsExtensionTest13,
+ TlsConnectTestBase::kTlsModesAll);
+
+} // namespace nspr_test