diff options
author | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2021-05-20 09:47:09 +0200 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2021-06-07 11:15:42 +0000 |
commit | 189d4fd8fad9e3c776873be51938cd31a42b6177 (patch) | |
tree | 6497caeff5e383937996768766ab3bb2081a40b2 /chromium/net/websockets | |
parent | 8bc75099d364490b22f43a7ce366b366c08f4164 (diff) | |
download | qtwebengine-chromium-189d4fd8fad9e3c776873be51938cd31a42b6177.tar.gz |
BASELINE: Update Chromium to 90.0.4430.221
Change-Id: Iff4d9d18d2fcf1a576f3b1f453010f744a232920
Reviewed-by: Allan Sandfeld Jensen <allan.jensen@qt.io>
Diffstat (limited to 'chromium/net/websockets')
9 files changed, 267 insertions, 90 deletions
diff --git a/chromium/net/websockets/websocket_basic_handshake_stream.cc b/chromium/net/websockets/websocket_basic_handshake_stream.cc index 8fd1f570a9b..abf25b43f3c 100644 --- a/chromium/net/websockets/websocket_basic_handshake_stream.cc +++ b/chromium/net/websockets/websocket_basic_handshake_stream.cc @@ -390,6 +390,11 @@ HttpStream* WebSocketBasicHandshakeStream::RenewStreamForAuth() { return handshake_stream.release(); } +const std::vector<std::string>& WebSocketBasicHandshakeStream::GetDnsAliases() + const { + return state_.GetDnsAliases(); +} + std::unique_ptr<WebSocketStream> WebSocketBasicHandshakeStream::Upgrade() { // The HttpStreamParser object has a pointer to our ClientSocketHandle. Make // sure it does not touch it again before it is destroyed. diff --git a/chromium/net/websockets/websocket_basic_handshake_stream.h b/chromium/net/websockets/websocket_basic_handshake_stream.h index a3576c22347..1b92ded49bc 100644 --- a/chromium/net/websockets/websocket_basic_handshake_stream.h +++ b/chromium/net/websockets/websocket_basic_handshake_stream.h @@ -74,7 +74,7 @@ class NET_EXPORT_PRIVATE WebSocketBasicHandshakeStream final void SetPriority(RequestPriority priority) override; void PopulateNetErrorDetails(NetErrorDetails* details) override; HttpStream* RenewStreamForAuth() override; - + const std::vector<std::string>& GetDnsAliases() const override; // This is called from the top level once correct handshake response headers // have been received. It creates an appropriate subclass of WebSocketStream diff --git a/chromium/net/websockets/websocket_basic_handshake_stream_test.cc b/chromium/net/websockets/websocket_basic_handshake_stream_test.cc index 1df0a1658f7..0653a473ffe 100644 --- a/chromium/net/websockets/websocket_basic_handshake_stream_test.cc +++ b/chromium/net/websockets/websocket_basic_handshake_stream_test.cc @@ -22,6 +22,7 @@ #include "net/traffic_annotation/network_traffic_annotation.h" #include "net/traffic_annotation/network_traffic_annotation_test_helper.h" #include "net/websockets/websocket_test_util.h" +#include "testing/gmock/include/gmock/gmock.h" #include "url/gurl.h" #include "url/origin.h" @@ -85,5 +86,72 @@ TEST(WebSocketBasicHandshakeStreamTest, ConnectionClosedOnFailure) { EXPECT_FALSE(socket_ptr->IsConnected()); } +TEST(WebSocketBasicHandshakeStreamTest, DnsAliasesCanBeAccessed) { + std::string request = WebSocketStandardRequest( + "/", "www.example.org", + url::Origin::Create(GURL("http://origin.example.org")), "", ""); + std::string response = WebSocketStandardResponse(""); + MockWrite writes[] = {MockWrite(SYNCHRONOUS, 0, request.c_str())}; + MockRead reads[] = {MockRead(SYNCHRONOUS, 1, response.c_str()), + MockRead(SYNCHRONOUS, ERR_IO_PENDING, 2)}; + + IPEndPoint end_point(IPAddress(127, 0, 0, 1), 80); + SequencedSocketData sequenced_socket_data( + MockConnect(SYNCHRONOUS, OK, end_point), reads, writes); + auto socket = std::make_unique<MockTCPClientSocket>( + AddressList(end_point), nullptr, &sequenced_socket_data); + const int connect_result = socket->Connect(CompletionOnceCallback()); + EXPECT_EQ(connect_result, OK); + + std::vector<std::string> aliases({"alias1", "alias2", "www.example.org"}); + socket->SetDnsAliases(aliases); + EXPECT_THAT(socket->GetDnsAliases(), + testing::ElementsAre("alias1", "alias2", "www.example.org")); + + const MockTCPClientSocket* const socket_ptr = socket.get(); + auto handle = std::make_unique<ClientSocketHandle>(); + handle->SetSocket(std::move(socket)); + EXPECT_THAT(handle->socket()->GetDnsAliases(), + testing::ElementsAre("alias1", "alias2", "www.example.org")); + + DummyConnectDelegate delegate; + WebSocketEndpointLockManager endpoint_lock_manager; + TestWebSocketStreamRequestAPI stream_request_api; + std::vector<std::string> extensions = { + "permessage-deflate; client_max_window_bits"}; + WebSocketBasicHandshakeStream basic_handshake_stream( + std::move(handle), &delegate, false, {}, extensions, &stream_request_api, + &endpoint_lock_manager); + basic_handshake_stream.SetWebSocketKeyForTesting("dGhlIHNhbXBsZSBub25jZQ=="); + HttpRequestInfo request_info; + request_info.url = GURL("ws://www.example.com/"); + request_info.method = "GET"; + request_info.traffic_annotation = + MutableNetworkTrafficAnnotationTag(TRAFFIC_ANNOTATION_FOR_TESTS); + TestCompletionCallback callback1; + NetLogWithSource net_log; + const int result1 = + callback1.GetResult(basic_handshake_stream.InitializeStream( + &request_info, true, LOWEST, net_log, callback1.callback())); + EXPECT_EQ(result1, OK); + + auto request_headers = WebSocketCommonTestHeaders(); + HttpResponseInfo response_info; + TestCompletionCallback callback2; + const int result2 = callback2.GetResult(basic_handshake_stream.SendRequest( + request_headers, &response_info, callback2.callback())); + EXPECT_EQ(result2, OK); + + TestCompletionCallback callback3; + const int result3 = callback3.GetResult( + basic_handshake_stream.ReadResponseHeaders(callback2.callback())); + EXPECT_EQ(result3, OK); + + EXPECT_TRUE(socket_ptr->IsConnected()); + + EXPECT_THAT(basic_handshake_stream.GetDnsAliases(), + testing::ElementsAre("alias1", "alias2", "www.example.org")); +} + } // namespace } // namespace net diff --git a/chromium/net/websockets/websocket_channel_test.cc b/chromium/net/websockets/websocket_channel_test.cc index 90df01a8e53..f3499de3703 100644 --- a/chromium/net/websockets/websocket_channel_test.cc +++ b/chromium/net/websockets/websocket_channel_test.cc @@ -106,15 +106,16 @@ namespace { using ::base::TimeDelta; +using ::testing::_; using ::testing::AnyNumber; using ::testing::DefaultValue; +using ::testing::DoAll; using ::testing::InSequence; using ::testing::MockFunction; using ::testing::NotNull; using ::testing::Return; using ::testing::SaveArg; using ::testing::StrictMock; -using ::testing::_; // A selection of characters that have traditionally been mangled in some // environment or other, for testing 8-bit cleanliness. diff --git a/chromium/net/websockets/websocket_end_to_end_test.cc b/chromium/net/websockets/websocket_end_to_end_test.cc index d12fedd1115..79b58de0367 100644 --- a/chromium/net/websockets/websocket_end_to_end_test.cc +++ b/chromium/net/websockets/websocket_end_to_end_test.cc @@ -32,8 +32,11 @@ #include "net/base/host_port_pair.h" #include "net/base/ip_endpoint.h" #include "net/base/isolation_info.h" +#include "net/base/load_flags.h" +#include "net/base/net_errors.h" #include "net/base/proxy_delegate.h" #include "net/base/url_util.h" +#include "net/cert/ct_policy_status.h" #include "net/http/http_request_headers.h" #include "net/log/net_log.h" #include "net/proxy_resolution/configured_proxy_resolution_service.h" @@ -42,6 +45,8 @@ #include "net/proxy_resolution/proxy_config_service_fixed.h" #include "net/proxy_resolution/proxy_config_with_annotation.h" #include "net/proxy_resolution/proxy_info.h" +#include "net/socket/socket_test_util.h" +#include "net/test/cert_test_util.h" #include "net/test/embedded_test_server/embedded_test_server.h" #include "net/test/embedded_test_server/http_request.h" #include "net/test/embedded_test_server/http_response.h" @@ -52,8 +57,10 @@ #include "net/url_request/url_request.h" #include "net/url_request/url_request_context.h" #include "net/url_request/url_request_test_util.h" +#include "net/url_request/websocket_handshake_userdata_key.h" #include "net/websockets/websocket_channel.h" #include "net/websockets/websocket_event_interface.h" +#include "net/websockets/websocket_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "url/gurl.h" #include "url/origin.h" @@ -70,13 +77,6 @@ using test_server::HttpResponse; static const char kEchoServer[] = "echo-with-no-extension"; -// Simplify changing URL schemes. -GURL ReplaceUrlScheme(const GURL& in_url, const base::StringPiece& scheme) { - GURL::Replacements replacements; - replacements.SetSchemeStr(scheme); - return in_url.ReplaceComponents(replacements); -} - // An implementation of WebSocketEventInterface that waits for and records the // results of the connect. class ConnectTestingEventInterface : public WebSocketEventInterface { @@ -504,81 +504,6 @@ TEST_F(WebSocketEndToEndTest, TruncatedResponse) { EXPECT_FALSE(ConnectAndWait(ws_url)); } -// Regression test for crbug.com/455215 "HSTS not applied to WebSocket" -TEST_F(WebSocketEndToEndTest, HstsHttpsToWebSocket) { - EmbeddedTestServer https_server(net::EmbeddedTestServer::Type::TYPE_HTTPS); - https_server.SetSSLConfig( - net::EmbeddedTestServer::CERT_COMMON_NAME_IS_DOMAIN); - https_server.ServeFilesFromSourceDirectory("net/data/url_request_unittest"); - - SpawnedTestServer::SSLOptions ssl_options( - SpawnedTestServer::SSLOptions::CERT_COMMON_NAME_IS_DOMAIN); - SpawnedTestServer wss_server(SpawnedTestServer::TYPE_WSS, ssl_options, - GetWebSocketTestDataDirectory()); - - ASSERT_TRUE(https_server.Start()); - ASSERT_TRUE(wss_server.Start()); - InitialiseContext(); - // Set HSTS via https: - TestDelegate delegate; - GURL https_page = https_server.GetURL("/hsts-headers.html"); - std::unique_ptr<URLRequest> request(context_.CreateRequest( - https_page, DEFAULT_PRIORITY, &delegate, TRAFFIC_ANNOTATION_FOR_TESTS)); - request->Start(); - delegate.RunUntilComplete(); - EXPECT_EQ(OK, delegate.request_status()); - - // Check HSTS with ws: - // Change the scheme from wss: to ws: to verify that it is switched back. - GURL ws_url = ReplaceUrlScheme(wss_server.GetURL(kEchoServer), "ws"); - EXPECT_TRUE(ConnectAndWait(ws_url)); -} - -TEST_F(WebSocketEndToEndTest, HstsWebSocketToHttps) { - EmbeddedTestServer https_server(net::EmbeddedTestServer::Type::TYPE_HTTPS); - https_server.SetSSLConfig( - net::EmbeddedTestServer::CERT_COMMON_NAME_IS_DOMAIN); - https_server.ServeFilesFromSourceDirectory("net/data/url_request_unittest"); - - SpawnedTestServer::SSLOptions ssl_options( - SpawnedTestServer::SSLOptions::CERT_COMMON_NAME_IS_DOMAIN); - SpawnedTestServer wss_server(SpawnedTestServer::TYPE_WSS, ssl_options, - GetWebSocketTestDataDirectory()); - ASSERT_TRUE(https_server.Start()); - ASSERT_TRUE(wss_server.Start()); - InitialiseContext(); - // Set HSTS via wss: - GURL wss_url = wss_server.GetURL("set-hsts"); - EXPECT_TRUE(ConnectAndWait(wss_url)); - - // Verify via http: - TestDelegate delegate; - GURL http_page = - ReplaceUrlScheme(https_server.GetURL("/simple.html"), "http"); - std::unique_ptr<URLRequest> request(context_.CreateRequest( - http_page, DEFAULT_PRIORITY, &delegate, TRAFFIC_ANNOTATION_FOR_TESTS)); - request->Start(); - delegate.RunUntilComplete(); - EXPECT_EQ(OK, delegate.request_status()); - EXPECT_TRUE(request->url().SchemeIs("https")); -} - -TEST_F(WebSocketEndToEndTest, HstsWebSocketToWebSocket) { - SpawnedTestServer::SSLOptions ssl_options( - SpawnedTestServer::SSLOptions::CERT_COMMON_NAME_IS_DOMAIN); - SpawnedTestServer wss_server(SpawnedTestServer::TYPE_WSS, ssl_options, - GetWebSocketTestDataDirectory()); - ASSERT_TRUE(wss_server.Start()); - InitialiseContext(); - // Set HSTS via wss: - GURL wss_url = wss_server.GetURL("set-hsts"); - EXPECT_TRUE(ConnectAndWait(wss_url)); - - // Verify via wss: - GURL ws_url = ReplaceUrlScheme(wss_server.GetURL(kEchoServer), "ws"); - EXPECT_TRUE(ConnectAndWait(ws_url)); -} - // Regression test for crbug.com/180504 "WebSocket handshake fails when HTTP // headers have trailing LWS". TEST_F(WebSocketEndToEndTest, TrailingWhitespace) { @@ -609,6 +534,161 @@ TEST_F(WebSocketEndToEndTest, HeaderContinuations) { event_interface_->extensions()); } +// These are not true end-to-end tests as the SpawnedTestServer doesn't +// support TLS 1.2. +// TODO(ricea): Make these be true end-to-end tests again when +// SpawnedTestServer supports TLS 1.2 or EmbeddedTestServer supports +// WebSockets. +class WebSocketHstsTest : public TestWithTaskEnvironment { + protected: + WebSocketHstsTest() : context_(true) { + context_.set_client_socket_factory(&socket_factory_); + context_.Init(); + } + + void MakeHttpConnection(const GURL& url) { + // Set up SSL details, because otherwise HSTS headers aren't processed. + SSLSocketDataProvider ssl_socket_data(net::ASYNC, net::OK); + ssl_socket_data.ssl_info.cert = + ImportCertFromFile(GetTestCertsDirectory(), "ok_cert.pem"); + ssl_socket_data.ssl_info.is_issued_by_known_root = true; + ssl_socket_data.ssl_info.ct_policy_compliance = + ct::CTPolicyCompliance::CT_POLICY_COMPLIES_VIA_SCTS; + ssl_socket_data.ssl_info.cert_status = 0; + socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data); + + req_ = context_.CreateRequest(url, DEFAULT_PRIORITY, &delegate_, + TRAFFIC_ANNOTATION_FOR_TESTS); + + MockWrite writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.example.org\r\n" + "Connection: keep-alive\r\n" + "User-Agent: \r\n" + "Accept-Encoding: gzip, deflate\r\n" + "Accept-Language: en-us,fr\r\n\r\n")}; + MockRead reads[] = {MockRead("HTTP/1.1 200 OK\r\n" + "Strict-Transport-Security: max-age=123; " + "includeSubdomains\r\n\r\n"), + MockRead(ASYNC, 0)}; + + StaticSocketDataProvider data(reads, writes); + socket_factory_.AddSocketDataProvider(&data); + + req_->Start(); + base::RunLoop().RunUntilIdle(); + } + + void MakeWebsocketConnection(const GURL& url) { + // Set up SSL details, because otherwise HSTS headers aren't processed. + SSLSocketDataProvider ssl_socket_data(net::ASYNC, net::OK); + ssl_socket_data.ssl_info.cert = + ImportCertFromFile(GetTestCertsDirectory(), "ok_cert.pem"); + ssl_socket_data.ssl_info.is_issued_by_known_root = true; + ssl_socket_data.ssl_info.ct_policy_compliance = + ct::CTPolicyCompliance::CT_POLICY_COMPLIES_VIA_SCTS; + ssl_socket_data.ssl_info.cert_status = 0; + socket_factory_.AddSSLSocketDataProvider(&ssl_socket_data); + + req_ = context_.CreateRequest(url, DEFAULT_PRIORITY, &delegate_, + TRAFFIC_ANNOTATION_FOR_TESTS); + + HttpRequestHeaders headers; + headers.SetHeader("Connection", "Upgrade"); + headers.SetHeader("Upgrade", "websocket"); + headers.SetHeader("Origin", "null"); + headers.SetHeader("Sec-WebSocket-Version", "13"); + req_->SetExtraRequestHeaders(headers); + + MockWrite writes[] = { + MockWrite("GET / HTTP/1.1\r\n" + "Host: www.example.org\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Origin: null\r\n" + "Sec-WebSocket-Version: 13\r\n" + "User-Agent: \r\n" + "Accept-Encoding: gzip, deflate\r\n" + "Accept-Language: en-us,fr\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "Sec-WebSocket-Extensions: permessage-deflate; " + "client_max_window_bits\r\n\r\n")}; + MockRead reads[] = { + MockRead("HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=\r\n" + "Strict-Transport-Security: max-age=123; " + "includeSubdomains\r\n\r\n"), + MockRead(ASYNC, 0)}; + + StaticSocketDataProvider data(reads, writes); + socket_factory_.AddSocketDataProvider(&data); + + req_->SetUserData( + kWebSocketHandshakeUserDataKey, + std::make_unique<TestWebSocketHandshakeStreamCreateHelper>()); + req_->SetLoadFlags(LOAD_DISABLE_CACHE); + req_->Start(); + base::RunLoop().RunUntilIdle(); + } + + TestURLRequestContext context_; + MockClientSocketFactory socket_factory_; + TestDelegate delegate_; + std::unique_ptr<URLRequest> req_; +}; + +// Regression test for crbug.com/455215 "HSTS not applied to WebSocket" +TEST_F(WebSocketHstsTest, HTTPSToWebSocket) { + // Set HSTS via https: + MakeHttpConnection(GURL("https://www.example.org")); + EXPECT_EQ(OK, delegate_.request_status()); + + ASSERT_TRUE(context_.transport_security_state()->ShouldUpgradeToSSL( + "www.example.org")); + + // Check HSTS by starting a request over ws: and verifying that it gets + // ugpraded to wss:. + MakeWebsocketConnection(GURL("ws://www.example.org")); + EXPECT_EQ(OK, delegate_.request_status()); + EXPECT_TRUE(delegate_.response_completed()); + EXPECT_TRUE(req_->url().SchemeIs("wss")); +} + +TEST_F(WebSocketHstsTest, WebSocketToHTTP) { + // Set HSTS via wss: + MakeWebsocketConnection(GURL("wss://www.example.org")); + EXPECT_EQ(OK, delegate_.request_status()); + EXPECT_TRUE(delegate_.response_completed()); + + ASSERT_TRUE(context_.transport_security_state()->ShouldUpgradeToSSL( + "www.example.org")); + + // Check HSTS by starting a request over http: and verifying that it gets + // ugpraded to https:. + MakeHttpConnection(GURL("http://www.example.org")); + EXPECT_EQ(OK, delegate_.request_status()); + EXPECT_TRUE(req_->url().SchemeIs("https")); +} + +TEST_F(WebSocketHstsTest, WebSocketToWebSocket) { + // Set HSTS via wss: + MakeWebsocketConnection(GURL("wss://www.example.org")); + EXPECT_EQ(OK, delegate_.request_status()); + EXPECT_TRUE(delegate_.response_completed()); + + ASSERT_TRUE(context_.transport_security_state()->ShouldUpgradeToSSL( + "www.example.org")); + + // Check HSTS by starting a request over ws: and verifying that it gets + // ugpraded to wss:. + MakeWebsocketConnection(GURL("ws://www.example.org")); + EXPECT_EQ(OK, delegate_.request_status()); + EXPECT_TRUE(delegate_.response_completed()); + EXPECT_TRUE(req_->url().SchemeIs("wss")); +} + } // namespace } // namespace net diff --git a/chromium/net/websockets/websocket_extension_parser.cc b/chromium/net/websockets/websocket_extension_parser.cc index 8d673c1ea1f..c37362fff4c 100644 --- a/chromium/net/websockets/websocket_extension_parser.cc +++ b/chromium/net/websockets/websocket_extension_parser.cc @@ -5,6 +5,7 @@ #include "net/websockets/websocket_extension_parser.h" #include "base/check_op.h" +#include "base/strings/string_piece.h" #include "net/http/http_util.h" namespace net { @@ -50,7 +51,7 @@ bool WebSocketExtensionParser::ConsumeExtension(WebSocketExtension* extension) { base::StringPiece name; if (!ConsumeToken(&name)) return false; - *extension = WebSocketExtension(name.as_string()); + *extension = WebSocketExtension(std::string(name)); while (ConsumeIfMatch(';')) { WebSocketExtension::Parameter parameter((std::string())); @@ -71,7 +72,7 @@ bool WebSocketExtensionParser::ConsumeExtensionParameter( return false; if (!ConsumeIfMatch('=')) { - *parameter = WebSocketExtension::Parameter(name.as_string()); + *parameter = WebSocketExtension::Parameter(std::string(name)); return true; } @@ -81,9 +82,9 @@ bool WebSocketExtensionParser::ConsumeExtensionParameter( } else { if (!ConsumeToken(&value)) return false; - value_string = value.as_string(); + value_string = std::string(value); } - *parameter = WebSocketExtension::Parameter(name.as_string(), value_string); + *parameter = WebSocketExtension::Parameter(std::string(name), value_string); return true; } diff --git a/chromium/net/websockets/websocket_http2_handshake_stream.cc b/chromium/net/websockets/websocket_http2_handshake_stream.cc index ce60d0c7d6a..ee56bd42550 100644 --- a/chromium/net/websockets/websocket_http2_handshake_stream.cc +++ b/chromium/net/websockets/websocket_http2_handshake_stream.cc @@ -9,6 +9,7 @@ #include "base/bind.h" #include "base/check_op.h" +#include "base/no_destructor.h" #include "base/notreached.h" #include "base/strings/stringprintf.h" #include "base/time/time.h" @@ -236,6 +237,12 @@ HttpStream* WebSocketHttp2HandshakeStream::RenewStreamForAuth() { return nullptr; } +const std::vector<std::string>& WebSocketHttp2HandshakeStream::GetDnsAliases() + const { + static const base::NoDestructor<std::vector<std::string>> emptyvector_result; + return *emptyvector_result; +} + std::unique_ptr<WebSocketStream> WebSocketHttp2HandshakeStream::Upgrade() { DCHECK(extension_params_.get()); diff --git a/chromium/net/websockets/websocket_http2_handshake_stream.h b/chromium/net/websockets/websocket_http2_handshake_stream.h index 634806f732a..f6d0b0285cf 100644 --- a/chromium/net/websockets/websocket_http2_handshake_stream.h +++ b/chromium/net/websockets/websocket_http2_handshake_stream.h @@ -86,6 +86,7 @@ class NET_EXPORT_PRIVATE WebSocketHttp2HandshakeStream void SetPriority(RequestPriority priority) override; void PopulateNetErrorDetails(NetErrorDetails* details) override; HttpStream* RenewStreamForAuth() override; + const std::vector<std::string>& GetDnsAliases() const override; // WebSocketHandshakeStreamBase methods. diff --git a/chromium/net/websockets/websocket_stream_test.cc b/chromium/net/websockets/websocket_stream_test.cc index e9554557849..c35bd5bd08a 100644 --- a/chromium/net/websockets/websocket_stream_test.cc +++ b/chromium/net/websockets/websocket_stream_test.cc @@ -17,10 +17,13 @@ #include "base/metrics/statistics_recorder.h" #include "base/run_loop.h" #include "base/stl_util.h" +#include "base/strings/string_piece.h" #include "base/strings/stringprintf.h" #include "base/test/metrics/histogram_tester.h" +#include "base/test/scoped_feature_list.h" #include "base/timer/mock_timer.h" #include "base/timer/timer.h" +#include "net/base/features.h" #include "net/base/isolation_info.h" #include "net/base/net_errors.h" #include "net/base/url_util.h" @@ -103,7 +106,14 @@ class WebSocketStreamCreateTest : public TestWithParam<HandshakeStreamType>, : stream_type_(GetParam()), http2_response_status_("200"), reset_websocket_http2_stream_(false), - sequence_number_(0) {} + sequence_number_(0) { + // Make sure these tests all pass with connection partitioning enabled. The + // disabled case is less interesting, and is tested more directly at lower + // layers. + feature_list_.InitAndEnableFeature( + features::kPartitionConnectionsByNetworkIsolationKey); + } + ~WebSocketStreamCreateTest() override { // Permit any endpoint locks to be released. stream_request_.reset(); @@ -294,6 +304,8 @@ class WebSocketStreamCreateTest : public TestWithParam<HandshakeStreamType>, std::unique_ptr<URLRequest> request = context->CreateRequest( GURL("https://www.example.org/"), DEFAULT_PRIORITY, &delegate, TRAFFIC_ANNOTATION_FOR_TESTS); + // The IsolationInfo has to match for a socket to be reused. + request->set_isolation_info(CreateIsolationInfo()); request->Start(); EXPECT_TRUE(request->is_pending()); delegate.RunUntilComplete(); @@ -383,6 +395,8 @@ class WebSocketStreamCreateTest : public TestWithParam<HandshakeStreamType>, const HandshakeStreamType stream_type_; private: + base::test::ScopedFeatureList feature_list_; + std::unique_ptr<base::OneShotTimer> timer_; std::string additional_data_; const char* http2_response_status_; @@ -482,7 +496,7 @@ class WebSocketStreamCreateBasicAuthTest : public WebSocketStreamCreateTest { url, NoSubProtocols(), HttpRequestHeaders(), helper_.BuildAuthSocketData(kUnauthorizedResponse, RequestExpectation(base64_user_pass), - response2.as_string())); + std::string(response2))); } static std::string RequestExpectation(base::StringPiece base64_user_pass) { |