summaryrefslogtreecommitdiff
path: root/nss/gtests/ssl_gtest/tls_connect.h
blob: 73e8dc81a928d76e006901bd618c84c9ebdde5a4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
/* -*- Mode: C++; tab-width: 8; indent-tabs-mode: nil; c-basic-offset: 2 -*- */
/* vim: set ts=2 et sw=2 tw=80: */
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this file,
 * You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef tls_connect_h_
#define tls_connect_h_

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <vector>

#include "sslproto.h"
#include "sslt.h"

#include "tls_agent.h"
#include "tls_filter.h"

#define GTEST_HAS_RTTI 0
#include "gtest/gtest.h"

namespace nss_test {

// Returns a std::string for the given protocol version number.  Only declared
// here; the definition lives in another translation unit.
// NOTE(review): presumably a human-readable version name for test output --
// confirm against the definition.
extern std::string VersionString(uint16_t version);

// A generic TLS connection test base.
class TlsConnectTestBase : public ::testing::Test {
 public:
  static ::testing::internal::ParamGenerator<SSLProtocolVariant>
      kTlsVariantsStream;
  static ::testing::internal::ParamGenerator<SSLProtocolVariant>
      kTlsVariantsDatagram;
  static ::testing::internal::ParamGenerator<SSLProtocolVariant>
      kTlsVariantsAll;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV10;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV11;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV12;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV10V11;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV11V12;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV10ToV12;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV13;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV11Plus;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsV12Plus;
  static ::testing::internal::ParamGenerator<uint16_t> kTlsVAll;

  TlsConnectTestBase(SSLProtocolVariant variant, uint16_t version);
  virtual ~TlsConnectTestBase();

  void SetUp();
  void TearDown();

  // Initialize client and server.
  void Init();
  // Clear the statistics.
  void ClearStats();
  // Clear the server session cache.
  void ClearServerCache();
  // Make sure TLS is configured for a connection.
  void EnsureTlsSetup();
  // Reset and keep the same certificate names
  void Reset();
  // Reset, and update the certificate names on both peers
  void Reset(const std::string& server_name,
             const std::string& client_name = "client");

  // Run the handshake.
  void Handshake();
  // Connect and check that it works.
  void Connect();
  // Check that the connection was successfully established.
  void CheckConnected();
  // Connect and expect it to fail.
  void ConnectExpectFail();
  void ExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
  void ConnectExpectAlert(std::shared_ptr<TlsAgent>& sender, uint8_t alert);
  void ConnectExpectFailOneSide(TlsAgent::Role failingSide);
  void ConnectWithCipherSuite(uint16_t cipher_suite);
  // Check that the keys used in the handshake match expectations.
  void CheckKeys(SSLKEAType kea_type, SSLNamedGroup kea_group,
                 SSLAuthType auth_type, SSLSignatureScheme sig_scheme) const;
  // This version guesses some of the values.
  void CheckKeys(SSLKEAType kea_type, SSLAuthType auth_type) const;
  // This version assumes defaults.
  void CheckKeys() const;
  void CheckGroups(const DataBuffer& groups,
                   std::function<void(SSLNamedGroup)> check_group);
  void CheckShares(const DataBuffer& shares,
                   std::function<void(SSLNamedGroup)> check_group);

  void ConfigureVersion(uint16_t version);
  void SetExpectedVersion(uint16_t version);
  // Expect resumption of a particular type.
  void ExpectResumption(SessionResumptionMode expected);
  void DisableAllCiphers();
  void EnableOnlyStaticRsaCiphers();
  void EnableOnlyDheCiphers();
  void EnableSomeEcdhCiphers();
  void EnableExtendedMasterSecret();
  void ConfigureSessionCache(SessionResumptionMode client,
                             SessionResumptionMode server);
  void EnableAlpn();
  void EnableAlpn(const uint8_t* val, size_t len);
  void EnsureModelSockets();
  void CheckAlpn(const std::string& val);
  void EnableSrtp();
  void CheckSrtp() const;
  void SendReceive();
  void SetupForZeroRtt();
  void SetupForResume();
  void ZeroRttSendReceive(
      bool expect_writable, bool expect_readable,
      std::function<bool()> post_clienthello_check = nullptr);
  void Receive(size_t amount);
  void ExpectExtendedMasterSecret(bool expected);
  void ExpectEarlyDataAccepted(bool expected);
  void DisableECDHEServerKeyReuse();
  void SkipVersionChecks();

 protected:
  SSLProtocolVariant variant_;
  std::shared_ptr<TlsAgent> client_;
  std::shared_ptr<TlsAgent> server_;
  std::unique_ptr<TlsAgent> client_model_;
  std::unique_ptr<TlsAgent> server_model_;
  uint16_t version_;
  SessionResumptionMode expected_resumption_mode_;
  std::vector<std::vector<uint8_t>> session_ids_;

  // A simple value of "a", "b".  Note that the preferred value of "a" is placed
  // at the end, because the NSS API follows the now defunct NPN specification,
  // which places the preferred (and default) entry at the end of the list.
  // NSS will move this final entry to the front when used with ALPN.
  const uint8_t alpn_dummy_val_[4] = {0x01, 0x62, 0x01, 0x61};

 private:
  void CheckResumption(SessionResumptionMode expected);
  void CheckExtendedMasterSecret();
  void CheckEarlyDataAccepted();

  bool expect_extended_master_secret_;
  bool expect_early_data_accepted_;
  bool skip_version_checks_;

  // Track groups and make sure that there are no duplicates.
  class DuplicateGroupChecker {
   public:
    void AddAndCheckGroup(SSLNamedGroup group) {
      EXPECT_EQ(groups_.end(), groups_.find(group))
          << "Group " << group << " should not be duplicated";
      groups_.insert(group);
    }

   private:
    std::set<SSLNamedGroup> groups_;
  };
};

// Fixture for tests that take no gtest parameter and run over the stream
// variant.  The base class is given a version value of 0.
class TlsConnectTest : public TlsConnectTestBase {
 public:
  TlsConnectTest()
      : TlsConnectTestBase{ssl_variant_stream, 0} {}
};

// Fixture for tests that take no gtest parameter and run over the datagram
// variant only.  The base class is given a version value of 0.
class DtlsConnectTest : public TlsConnectTestBase {
 public:
  DtlsConnectTest()
      : TlsConnectTestBase{ssl_variant_datagram, 0} {}
};

// Stream-only fixture parametrized on the protocol version: the gtest
// parameter (a uint16_t) is forwarded to the base class as the version.
class TlsConnectStream : public TlsConnectTestBase,
                         public ::testing::WithParamInterface<uint16_t> {
 public:
  TlsConnectStream()
      : TlsConnectTestBase{ssl_variant_stream, GetParam()} {}
};

// A TLS-only test base for tests before 1.3.  It adds no members to
// TlsConnectStream; it only provides a distinct fixture type.
// NOTE(review): presumably instantiated with pre-1.3 versions only -- confirm
// at the instantiation site.
class TlsConnectStreamPre13 : public TlsConnectStream {};

// Datagram-only fixture parametrized on the protocol version: the gtest
// parameter (a uint16_t) is forwarded to the base class as the version.
class TlsConnectDatagram : public TlsConnectTestBase,
                           public ::testing::WithParamInterface<uint16_t> {
 public:
  TlsConnectDatagram()
      : TlsConnectTestBase{ssl_variant_datagram, GetParam()} {}
};

// A generic test class that can be either stream or datagram and a single
// version of TLS.  This is configured in ssl_loopback_unittest.cc.
//
// The gtest parameter is a (protocol variant, version) tuple.  The constructor
// is defined out of line.
class TlsConnectGeneric : public TlsConnectTestBase,
                          public ::testing::WithParamInterface<
                              std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectGeneric();
};

// A Pre TLS 1.2 generic test.
//
// The gtest parameter is a (protocol variant, version) tuple.  The constructor
// is defined out of line.
class TlsConnectPre12 : public TlsConnectTestBase,
                        public ::testing::WithParamInterface<
                            std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectPre12();
};

// A TLS 1.2 only generic test.
//
// The gtest parameter is the protocol variant alone; no version is part of
// the parameter.  The constructor is defined out of line.
class TlsConnectTls12
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<SSLProtocolVariant> {
 public:
  TlsConnectTls12();
};

// Non-parametrized fixture pinned to the stream variant and to
// SSL_LIBRARY_VERSION_TLS_1_2.
class TlsConnectStreamTls12 : public TlsConnectTestBase {
 public:
  TlsConnectStreamTls12()
      : TlsConnectTestBase{ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_2} {}
};

// A TLS 1.2+ generic test.
//
// The gtest parameter is a (protocol variant, version) tuple.  The constructor
// is defined out of line.
class TlsConnectTls12Plus : public TlsConnectTestBase,
                            public ::testing::WithParamInterface<
                                std::tuple<SSLProtocolVariant, uint16_t>> {
 public:
  TlsConnectTls12Plus();
};

// A TLS 1.3 only generic test.
//
// The gtest parameter is the protocol variant alone; no version is part of
// the parameter.  The constructor is defined out of line.
class TlsConnectTls13
    : public TlsConnectTestBase,
      public ::testing::WithParamInterface<SSLProtocolVariant> {
 public:
  TlsConnectTls13();
};

// Non-parametrized fixture pinned to the stream variant and to
// SSL_LIBRARY_VERSION_TLS_1_3.
class TlsConnectStreamTls13 : public TlsConnectTestBase {
 public:
  TlsConnectStreamTls13()
      : TlsConnectTestBase{ssl_variant_stream, SSL_LIBRARY_VERSION_TLS_1_3} {}
};

// Non-parametrized fixture pinned to the datagram variant and to
// SSL_LIBRARY_VERSION_TLS_1_3.
class TlsConnectDatagram13 : public TlsConnectTestBase {
 public:
  TlsConnectDatagram13()
      : TlsConnectTestBase{ssl_variant_datagram, SSL_LIBRARY_VERSION_TLS_1_3} {}
};

// A variant that is used only with Pre13.  It adds no members to
// TlsConnectGeneric; it only provides a distinct fixture type.
class TlsConnectGenericPre13 : public TlsConnectGeneric {};

// A TlsConnectGeneric fixture extended with capture filters and helpers for
// inspecting named groups and key shares.  All member functions are defined
// out of line.
class TlsKeyExchangeTest : public TlsConnectGeneric {
 protected:
  // Capture filters shared with derived fixtures.
  // NOTE(review): presumably these are installed by EnsureKeyShareSetup() and
  // capture the groups extension, the key shares (initial and second flight)
  // and a HelloRetryRequest -- confirm against the implementation.
  std::shared_ptr<TlsExtensionCapture> groups_capture_;
  std::shared_ptr<TlsExtensionCapture> shares_capture_;
  std::shared_ptr<TlsExtensionCapture> shares_capture2_;
  std::shared_ptr<TlsInspectorRecordHandshakeMessage> capture_hrr_;

  void EnsureKeyShareSetup();
  void ConfigNamedGroups(const std::vector<SSLNamedGroup>& groups);
  // Each takes a raw extension buffer and returns a list of SSLNamedGroup
  // values by value.
  std::vector<SSLNamedGroup> GetGroupDetails(const DataBuffer& ext);
  std::vector<SSLNamedGroup> GetShareDetails(const DataBuffer& ext);
  // Overloads available to derived fixtures: one without and one with an
  // additional expected second share.
  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                       const std::vector<SSLNamedGroup>& expectedShares);
  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                       const std::vector<SSLNamedGroup>& expectedShares,
                       SSLNamedGroup expectedShare2);

 private:
  // Private overload taking an explicit |expect_hrr| flag.
  // NOTE(review): presumably the shared implementation behind the protected
  // overloads -- confirm against the implementation.
  void CheckKEXDetails(const std::vector<SSLNamedGroup>& expectedGroups,
                       const std::vector<SSLNamedGroup>& expectedShares,
                       bool expect_hrr);
};

// Adds no members to TlsKeyExchangeTest; it only provides a distinct fixture
// type.
// NOTE(review): presumably instantiated for TLS 1.3 only -- confirm at the
// instantiation site.
class TlsKeyExchangeTest13 : public TlsKeyExchangeTest {};
// Adds no members to TlsKeyExchangeTest; it only provides a distinct fixture
// type.
// NOTE(review): presumably instantiated for versions before TLS 1.3 only --
// confirm at the instantiation site.
class TlsKeyExchangeTestPre13 : public TlsKeyExchangeTest {};

}  // namespace nss_test

#endif