summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/tls_connect.h
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/tls_connect.h')
-rw-r--r--nss/gtests/ssl_gtest/tls_connect.h86
1 files changed, 45 insertions, 41 deletions
diff --git a/nss/gtests/ssl_gtest/tls_connect.h b/nss/gtests/ssl_gtest/tls_connect.h
index aa4a32d..73e8dc8 100644
--- a/nss/gtests/ssl_gtest/tls_connect.h
+++ b/nss/gtests/ssl_gtest/tls_connect.h
@@ -25,9 +25,12 @@ extern std::string VersionString(uint16_t version);
// A generic TLS connection test base.
class TlsConnectTestBase : public ::testing::Test {
public:
- static ::testing::internal::ParamGenerator<std::string> kTlsModesStream;
- static ::testing::internal::ParamGenerator<std::string> kTlsModesDatagram;
- static ::testing::internal::ParamGenerator<std::string> kTlsModesAll;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsStream;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsDatagram;
+ static ::testing::internal::ParamGenerator<SSLProtocolVariant>
+ kTlsVariantsAll;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
@@ -39,8 +42,7 @@ class TlsConnectTestBase : public ::testing::Test {
static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;
- TlsConnectTestBase(Mode mode, uint16_t version);
- TlsConnectTestBase(const std::string& mode, uint16_t version);
+ TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
virtual ~TlsConnectTestBase();
void SetUp();
@@ -68,6 +70,9 @@ class TlsConnectTestBase : public ::testing::Test {
void CheckConnected();
// Connect and expect it to fail.
void ConnectExpectFail();
+ void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
+ void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
void ConnectWithCipherSuite(uint16_t cipher_suite);
// Check that the keys used in the handshake match expectations.
void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
@@ -108,13 +113,14 @@ class TlsConnectTestBase : public ::testing::Test {
void ExpectExtendedMasterSecret(bool expected);
void ExpectEarlyDataAccepted(bool expected);
void DisableECDHEServerKeyReuse();
+ void SkipVersionChecks();
protected:
- Mode mode_;
- TlsAgent* client_;
- TlsAgent* server_;
- TlsAgent* client_model_;
- TlsAgent* server_model_;
+ SSLProtocolVariant variant_;
+ std::shared_ptr<TlsAgent> client_;
+ std::shared_ptr<TlsAgent> server_;
+ std::unique_ptr<TlsAgent> client_model_;
+ std::unique_ptr<TlsAgent> server_model_;
uint16_t version_;
SessionResumptionMode expected_resumption_mode_;
std::vector<std::vector<uint8_t>> session_ids_;
@@ -126,16 +132,13 @@ class TlsConnectTestBase : public ::testing::Test {
const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};
private:
- static inline Mode ToMode(const std::string& str) {
- return str == "TLS" ? STREAM : DGRAM;
- }
-
void CheckResumption(SessionResumptionMode expected);
void CheckExtendedMasterSecret();
void CheckEarlyDataAccepted();
bool expect_extended_master_secret_;
bool expect_early_data_accepted_;
+ bool skip_version_checks_;
// Track groups and make sure that there are no duplicates.
class DuplicateGroupChecker {
@@ -154,20 +157,20 @@ class TlsConnectTestBase : public ::testing::Test {
// A non-parametrized TLS test base.
class TlsConnectTest : public TlsConnectTestBase {
public:
- TlsConnectTest() : TlsConnectTestBase(STREAM, 0) {}
+ TlsConnectTest() : TlsConnectTestBase(ssl_variant_stream, 0) {}
};
// A non-parametrized DTLS-only test base.
class DtlsConnectTest : public TlsConnectTestBase {
public:
- DtlsConnectTest() : TlsConnectTestBase(DGRAM, 0) {}
+ DtlsConnectTest() : TlsConnectTestBase(ssl_variant_datagram, 0) {}
};
// A TLS-only test base.
class TlsConnectStream : public TlsConnectTestBase,
public ::testing::WithParamInterface<uint16_t> {
public:
- TlsConnectStream() : TlsConnectTestBase(STREAM, GetParam()) {}
+ TlsConnectStream() : TlsConnectTestBase(ssl_variant_stream, GetParam()) {}
};
// A TLS-only test base for tests before 1.3
@@ -177,30 +180,30 @@ class TlsConnectStreamPre13 : public TlsConnectStream {};
class TlsConnectDatagram : public TlsConnectTestBase,
public ::testing::WithParamInterface<uint16_t> {
public:
- TlsConnectDatagram() : TlsConnectTestBase(DGRAM, GetParam()) {}
+ TlsConnectDatagram() : TlsConnectTestBase(ssl_variant_datagram, GetParam()) {}
};
-// A generic test class that can be either STREAM or DGRAM and a single version
-// of TLS. This is configured in ssl_loopback_unittest.cc. All uses of this
-// should use TEST_P().
-class TlsConnectGeneric
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+// A generic test class that can be either stream or datagram and a single
+// version of TLS. This is configured in ssl_loopback_unittest.cc.
+class TlsConnectGeneric : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsConnectGeneric();
};
// A Pre TLS 1.2 generic test.
-class TlsConnectPre12
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsConnectPre12 : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsConnectPre12();
};
// A TLS 1.2 only generic test.
-class TlsConnectTls12 : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::string> {
+class TlsConnectTls12
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
public:
TlsConnectTls12();
};
@@ -209,20 +212,21 @@ class TlsConnectTls12 : public TlsConnectTestBase,
class TlsConnectStreamTls12 : public TlsConnectTestBase {
public:
TlsConnectStreamTls12()
- : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_2) {}
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2) {}
};
// A TLS 1.2+ generic test.
-class TlsConnectTls12Plus
- : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::tuple<std::string, uint16_t>> {
+class TlsConnectTls12Plus : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsConnectTls12Plus();
};
// A TLS 1.3 only generic test.
-class TlsConnectTls13 : public TlsConnectTestBase,
- public ::testing::WithParamInterface<std::string> {
+class TlsConnectTls13
+ : public TlsConnectTestBase,
+ public ::testing::WithParamInterface<SSLProtocolVariant> {
public:
TlsConnectTls13();
};
@@ -231,13 +235,13 @@ class TlsConnectTls13 : public TlsConnectTestBase,
class TlsConnectStreamTls13 : public TlsConnectTestBase {
public:
TlsConnectStreamTls13()
- : TlsConnectTestBase(STREAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+ : TlsConnectTestBase(ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
class TlsConnectDatagram13 : public TlsConnectTestBase {
public:
TlsConnectDatagram13()
- : TlsConnectTestBase(DGRAM, SSL_LIBRARY_VERSION_TLS_1_3) {}
+ : TlsConnectTestBase(ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3) {}
};
// A variant that is used only with Pre13.
@@ -245,10 +249,10 @@ class TlsConnectGenericPre13 : public TlsConnectGeneric {};
class TlsKeyExchangeTest : public TlsConnectGeneric {
protected:
- TlsExtensionCapture* groups_capture_;
- TlsExtensionCapture* shares_capture_;
- TlsExtensionCapture* shares_capture2_;
- TlsInspectorRecordHandshakeMessage* capture_hrr_;
+ std::shared_ptr<TlsExtensionCapture> groups_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture_;
+ std::shared_ptr<TlsExtensionCapture> shares_capture2_;
+ std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;
void EnsureKeyShareSetup();
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);