summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/tls_agent.h
diff options
context:
space:
mode:
Diffstat (limited to 'nss/gtests/ssl_gtest/tls_agent.h')
-rw-r--r--nss/gtests/ssl_gtest/tls_agent.h135
1 files changed, 87 insertions, 48 deletions
diff --git a/nss/gtests/ssl_gtest/tls_agent.h b/nss/gtests/ssl_gtest/tls_agent.h
index 78923c9..32f6175 100644
--- a/nss/gtests/ssl_gtest/tls_agent.h
+++ b/nss/gtests/ssl_gtest/tls_agent.h
@@ -14,9 +14,11 @@
#include <iostream>
#include "test_io.h"
+#include "tls_filter.h"
#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"
+#include "scoped_ptrs.h"
extern bool g_ssl_gtest_verbose;
@@ -42,6 +44,8 @@ const extern std::vector<SSLNamedGroup> kECDHEGroups;
const extern std::vector<SSLNamedGroup> kFFDHEGroups;
const extern std::vector<SSLNamedGroup> kFasterDHEGroups;
+// These functions are called from callbacks. They use bare pointers because
+// TlsAgent sets up the callback and it doesn't know who owns it.
typedef std::function<SECStatus(TlsAgent* agent, bool checksig, bool isServer)>
AuthCertificateCallbackFunction;
@@ -70,25 +74,24 @@ class TlsAgent : public PollTarget {
static const std::string kServerEcdhRsa;
static const std::string kServerDsa;
- TlsAgent(const std::string& name, Role role, Mode mode);
+ TlsAgent(const std::string& name, Role role, SSLProtocolVariant variant);
virtual ~TlsAgent();
- bool Init() {
- pr_fd_ = DummyPrSocket::CreateFD(role_str(), mode_);
- if (!pr_fd_) return false;
-
- adapter_ = DummyPrSocket::GetAdapter(pr_fd_);
- if (!adapter_) return false;
-
- return true;
+ void SetPeer(std::shared_ptr<TlsAgent>& peer) {
+ adapter_->SetPeer(peer->adapter_);
}
- void SetPeer(TlsAgent* peer) { adapter_->SetPeer(peer->adapter_); }
+ void SetTlsRecordFilter(std::shared_ptr<TlsRecordFilter> filter) {
+ filter->SetAgent(this);
+ adapter_->SetPacketFilter(filter);
+ }
- void SetPacketFilter(PacketFilter* filter) {
+ void SetPacketFilter(std::shared_ptr<PacketFilter> filter) {
adapter_->SetPacketFilter(filter);
}
+ void DeletePacketFilter() { adapter_->SetPacketFilter(nullptr); }
+
void StartConnect(PRFileDesc* model = nullptr);
void CheckKEA(SSLKEAType kea_type, SSLNamedGroup group,
size_t kea_size = 0) const;
@@ -107,6 +110,9 @@ class TlsAgent : public PollTarget {
void PrepareForRenegotiate();
// Prepares for renegotiation, then actually triggers it.
void StartRenegotiate();
+ static bool LoadCertificate(const std::string& name,
+ ScopedCERTCertificate* cert,
+ ScopedSECKEYPrivateKey* priv);
bool ConfigServerCert(const std::string& name, bool updateKeyBits = false,
const SSLExtraServerCertData* serverCertData = nullptr);
bool ConfigServerCertWithChain(const std::string& name);
@@ -114,13 +120,12 @@ class TlsAgent : public PollTarget {
void SetupClientAuth();
void RequestClientAuth(bool requireAuth);
- bool GetClientAuthCredentials(CERTCertificate** cert,
- SECKEYPrivateKey** priv) const;
void ConfigureSessionCache(SessionResumptionMode mode);
void SetSessionTicketsEnabled(bool en);
void SetSessionCacheEnabled(bool en);
void Set0RttEnabled(bool en);
+ void SetFallbackSCSVEnabled(bool en);
void SetShortHeadersEnabled();
void SetVersionRange(uint16_t minver, uint16_t maxver);
void GetVersionRange(uint16_t* minver, uint16_t* maxver);
@@ -132,6 +137,7 @@ class TlsAgent : public PollTarget {
void EnableFalseStart();
void ExpectResumption();
void ExpectShortHeaders();
+ void SkipVersionChecks();
void SetSignatureSchemes(const SSLSignatureScheme* schemes, size_t count);
void EnableAlpn(const uint8_t* val, size_t len);
void CheckAlpn(SSLNextProtoState expected_state,
@@ -157,6 +163,7 @@ class TlsAgent : public PollTarget {
void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
void DisableECDHEServerKeyReuse();
bool GetPeerChainLength(size_t* count);
+ void CheckCipherSuite(uint16_t cipher_suite);
const std::string& name() const { return name_; }
@@ -166,15 +173,15 @@ class TlsAgent : public PollTarget {
State state() const { return state_; }
const CERTCertificate* peer_cert() const {
- return SSL_PeerCertificate(ssl_fd_);
+ return SSL_PeerCertificate(ssl_fd_.get());
}
const char* state_str() const { return state_str(state()); }
static const char* state_str(State state) { return states[state]; }
- PRFileDesc* ssl_fd() { return ssl_fd_; }
- DummyPrSocket* adapter() { return adapter_; }
+ PRFileDesc* ssl_fd() const { return ssl_fd_.get(); }
+ std::shared_ptr<DummyPrSocket>& adapter() { return adapter_; }
bool is_compressed() const {
return info_.compressionMethod != ssl_compression_null;
@@ -239,6 +246,9 @@ class TlsAgent : public PollTarget {
sni_callback_ = sni_callback;
}
+ void ExpectReceiveAlert(uint8_t alert, uint8_t level = 0);
+ void ExpectSendAlert(uint8_t alert, uint8_t level = 0);
+
private:
const static char* states[];
@@ -320,6 +330,18 @@ class TlsAgent : public PollTarget {
return SECSuccess;
}
+ void CheckAlert(bool sent, const SSLAlert* alert);
+
+ static void AlertReceivedCallback(const PRFileDesc* fd, void* arg,
+ const SSLAlert* alert) {
+ reinterpret_cast<TlsAgent*>(arg)->CheckAlert(false, alert);
+ }
+
+ static void AlertSentCallback(const PRFileDesc* fd, void* arg,
+ const SSLAlert* alert) {
+ reinterpret_cast<TlsAgent*>(arg)->CheckAlert(true, alert);
+ }
+
static void HandshakeCallback(PRFileDesc* fd, void* arg) {
TlsAgent* agent = reinterpret_cast<TlsAgent*>(arg);
agent->handshake_callback_called_ = true;
@@ -336,14 +358,13 @@ class TlsAgent : public PollTarget {
void Connected();
const std::string name_;
- Mode mode_;
- uint16_t server_key_bits_;
- PRFileDesc* pr_fd_;
- DummyPrSocket* adapter_;
- PRFileDesc* ssl_fd_;
+ SSLProtocolVariant variant_;
Role role_;
+ uint16_t server_key_bits_;
+ std::shared_ptr<DummyPrSocket> adapter_;
+ ScopedPRFileDesc ssl_fd_;
State state_;
- Poller::Timer* timer_handle_;
+ std::shared_ptr<Poller::Timer> timer_handle_;
bool falsestart_enabled_;
uint16_t expected_version_;
uint16_t expected_cipher_suite_;
@@ -352,6 +373,10 @@ class TlsAgent : public PollTarget {
bool can_falsestart_hook_called_;
bool sni_hook_called_;
bool auth_certificate_hook_called_;
+ uint8_t expected_received_alert_;
+ uint8_t expected_received_alert_level_;
+ uint8_t expected_sent_alert_;
+ uint8_t expected_sent_alert_level_;
bool handshake_callback_called_;
SSLChannelInfo info_;
SSLCipherSuiteInfo csinfo_;
@@ -364,6 +389,7 @@ class TlsAgent : public PollTarget {
AuthCertificateCallbackFunction auth_certificate_callback_;
SniCallbackFunction sni_callback_;
bool expect_short_headers_;
+ bool skip_version_checks_;
};
inline std::ostream& operator<<(std::ostream& stream,
@@ -375,20 +401,23 @@ class TlsAgentTestBase : public ::testing::Test {
public:
static ::testing::internal::ParamGenerator<std::string> kTlsRolesAll;
- TlsAgentTestBase(TlsAgent::Role role, Mode mode)
- : agent_(nullptr), fd_(nullptr), role_(role), mode_(mode) {}
- ~TlsAgentTestBase() {
- if (fd_) {
- PR_Close(fd_);
- }
- }
+ TlsAgentTestBase(TlsAgent::Role role, SSLProtocolVariant variant,
+ uint16_t version = 0)
+ : agent_(nullptr),
+ role_(role),
+ variant_(variant),
+ version_(version),
+ sink_adapter_(new DummyPrSocket("sink", variant)) {}
+ virtual ~TlsAgentTestBase() {}
void SetUp();
void TearDown();
- static void MakeRecord(Mode mode, uint8_t type, uint16_t version,
- const uint8_t* buf, size_t len, DataBuffer* out,
- uint64_t seq_num = 0);
+ void ExpectAlert(uint8_t alert);
+
+ static void MakeRecord(SSLProtocolVariant variant, uint8_t type,
+ uint16_t version, const uint8_t* buf, size_t len,
+ DataBuffer* out, uint64_t seq_num = 0);
void MakeRecord(uint8_t type, uint16_t version, const uint8_t* buf,
size_t len, DataBuffer* out, uint64_t seq_num = 0) const;
void MakeHandshakeMessage(uint8_t hs_type, const uint8_t* data, size_t hs_len,
@@ -403,10 +432,6 @@ class TlsAgentTestBase : public ::testing::Test {
return str == "CLIENT" ? TlsAgent::CLIENT : TlsAgent::SERVER;
}
- static inline Mode ToMode(const std::string& str) {
- return str == "TLS" ? STREAM : DGRAM;
- }
-
void Init(const std::string& server_name = TlsAgent::kServerRsa);
void Reset(const std::string& server_name = TlsAgent::kServerRsa);
@@ -415,43 +440,57 @@ class TlsAgentTestBase : public ::testing::Test {
void ProcessMessage(const DataBuffer& buffer, TlsAgent::State expected_state,
int32_t error_code = 0);
- TlsAgent* agent_;
- PRFileDesc* fd_;
+ std::unique_ptr<TlsAgent> agent_;
TlsAgent::Role role_;
- Mode mode_;
+ SSLProtocolVariant variant_;
+ uint16_t version_;
+ // This adapter is here just to accept packets from this agent.
+ std::shared_ptr<DummyPrSocket> sink_adapter_;
};
-class TlsAgentTest : public TlsAgentTestBase,
- public ::testing::WithParamInterface<
- std::tuple<std::string, std::string>> {
+class TlsAgentTest
+ : public TlsAgentTestBase,
+ public ::testing::WithParamInterface<
+ std::tuple<std::string, SSLProtocolVariant, uint16_t>> {
public:
TlsAgentTest()
: TlsAgentTestBase(ToRole(std::get<0>(GetParam())),
- ToMode(std::get<1>(GetParam()))) {}
+ std::get<1>(GetParam()), std::get<2>(GetParam())) {}
};
class TlsAgentTestClient : public TlsAgentTestBase,
- public ::testing::WithParamInterface<std::string> {
+ public ::testing::WithParamInterface<
+ std::tuple<SSLProtocolVariant, uint16_t>> {
public:
TlsAgentTestClient()
- : TlsAgentTestBase(TlsAgent::CLIENT, ToMode(GetParam())) {}
+ : TlsAgentTestBase(TlsAgent::CLIENT, std::get<0>(GetParam()),
+ std::get<1>(GetParam())) {}
};
+class TlsAgentTestClient13 : public TlsAgentTestClient {};
+
class TlsAgentStreamTestClient : public TlsAgentTestBase {
public:
- TlsAgentStreamTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, STREAM) {}
+ TlsAgentStreamTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_stream) {}
};
class TlsAgentStreamTestServer : public TlsAgentTestBase {
public:
- TlsAgentStreamTestServer() : TlsAgentTestBase(TlsAgent::SERVER, STREAM) {}
+ TlsAgentStreamTestServer()
+ : TlsAgentTestBase(TlsAgent::SERVER, ssl_variant_stream) {}
};
class TlsAgentDgramTestClient : public TlsAgentTestBase {
public:
- TlsAgentDgramTestClient() : TlsAgentTestBase(TlsAgent::CLIENT, DGRAM) {}
+ TlsAgentDgramTestClient()
+ : TlsAgentTestBase(TlsAgent::CLIENT, ssl_variant_datagram) {}
};
+inline bool operator==(const SSLVersionRange& vr1, const SSLVersionRange& vr2) {
+ return vr1.min == vr2.min && vr1.max == vr2.max;
+}
+
} // namespace nss_test
#endif