diff options
author | Allan Sandfeld Jensen <allan.jensen@theqtcompany.com> | 2016-05-09 14:22:11 +0200 |
---|---|---|
committer | Allan Sandfeld Jensen <allan.jensen@qt.io> | 2016-05-09 15:11:45 +0000 |
commit | 2ddb2d3e14eef3de7dbd0cef553d669b9ac2361c (patch) | |
tree | e75f511546c5fd1a173e87c1f9fb11d7ac8d1af3 /chromium/net/socket | |
parent | a4f3d46271c57e8155ba912df46a05559d14726e (diff) | |
download | qtwebengine-chromium-2ddb2d3e14eef3de7dbd0cef553d669b9ac2361c.tar.gz |
BASELINE: Update Chromium to 51.0.2704.41
Also adds in all smaller components by reversing logic for exclusion.
Change-Id: Ibf90b506e7da088ea2f65dcf23f2b0992c504422
Reviewed-by: Joerg Bornemann <joerg.bornemann@theqtcompany.com>
Diffstat (limited to 'chromium/net/socket')
70 files changed, 2613 insertions, 2379 deletions
diff --git a/chromium/net/socket/client_socket_handle.cc b/chromium/net/socket/client_socket_handle.cc index b177fb6f3b2..97b1b89d844 100644 --- a/chromium/net/socket/client_socket_handle.cc +++ b/chromium/net/socket/client_socket_handle.cc @@ -10,6 +10,7 @@ #include "base/bind_helpers.h" #include "base/compiler_specific.h" #include "base/logging.h" +#include "base/trace_event/trace_event.h" #include "net/base/net_errors.h" #include "net/socket/client_socket_pool.h" @@ -140,6 +141,7 @@ void ClientSocketHandle::SetSocket(scoped_ptr<StreamSocket> s) { } void ClientSocketHandle::OnIOComplete(int result) { + TRACE_EVENT0("net", "ClientSocketHandle::OnIOComplete"); CompletionCallback callback = user_callback_; user_callback_.Reset(); HandleInitCompletion(result); diff --git a/chromium/net/socket/client_socket_handle.h b/chromium/net/socket/client_socket_handle.h index c5b2720f7d3..a8af5c078be 100644 --- a/chromium/net/socket/client_socket_handle.h +++ b/chromium/net/socket/client_socket_handle.h @@ -50,6 +50,8 @@ class NET_EXPORT ClientSocketHandle { // ClientSocketPool to obtain a connected socket, possibly reusing one. This // method returns either OK or ERR_IO_PENDING. On ERR_IO_PENDING, |priority| // is used to determine the placement in ClientSocketPool's wait list. + // If |respect_limits| is DISABLED, will bypass the wait list, but |priority| + // must also be HIGHEST, if set. // // If this method succeeds, then the socket member will be set to an existing // connected socket if an existing connected socket was available to reuse, @@ -78,6 +80,7 @@ class NET_EXPORT ClientSocketHandle { int Init(const std::string& group_name, const scoped_refptr<typename PoolType::SocketParams>& socket_params, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const CompletionCallback& callback, PoolType* pool, const BoundNetLog& net_log); @@ -236,6 +239,7 @@ int ClientSocketHandle::Init( const std::string& group_name, const scoped_refptr<typename PoolType::SocketParams>& socket_params, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const CompletionCallback& callback, PoolType* pool, const BoundNetLog& net_log) { @@ -247,8 +251,8 @@ int ClientSocketHandle::Init( pool_ = pool; group_name_ = group_name; init_time_ = base::TimeTicks::Now(); - int rv = pool_->RequestSocket( - group_name, &socket_params, priority, this, callback_, net_log); + int rv = pool_->RequestSocket(group_name, &socket_params, priority, + respect_limits, this, callback_, net_log); if (rv == ERR_IO_PENDING) { user_callback_ = callback; } else { diff --git a/chromium/net/socket/client_socket_pool.h b/chromium/net/socket/client_socket_pool.h index e1785a3a7c7..73d90dd485c 100644 --- a/chromium/net/socket/client_socket_pool.h +++ b/chromium/net/socket/client_socket_pool.h @@ -11,7 +11,6 @@ #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" -#include "base/template_util.h" #include "base/time/time.h" #include "net/base/completion_callback.h" #include "net/base/load_states.h" @@ -61,11 +60,14 @@ class NET_EXPORT LowerLayeredPool { // A ClientSocketPool is used to restrict the number of sockets open at a time. // It also maintains a list of idle persistent sockets. // +// Subclasses must also have an inner class SocketParams which is +// the type for the |params| argument in RequestSocket() and +// RequestSockets() below. class NET_EXPORT ClientSocketPool : public LowerLayeredPool { public: - // Subclasses must also have an inner class SocketParams which is - // the type for the |params| argument in RequestSocket() and - // RequestSockets() below. + // Indicates whether or not a request for a socket should respect the + // SocketPool's global and per-group socket limits. + enum class RespectLimits { DISABLED, ENABLED }; // Requests a connected socket for a group_name. // @@ -96,9 +98,12 @@ class NET_EXPORT ClientSocketPool : public LowerLayeredPool { // client of completion. // // Profiling information for the request is saved to |net_log| if non-NULL. + // + // If |respect_limits| is DISABLED, priority must be HIGHEST. virtual int RequestSocket(const std::string& group_name, const void* params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) = 0; diff --git a/chromium/net/socket/client_socket_pool_base.cc b/chromium/net/socket/client_socket_pool_base.cc index 4a76913b7a4..b240e3db026 100644 --- a/chromium/net/socket/client_socket_pool_base.cc +++ b/chromium/net/socket/client_socket_pool_base.cc @@ -16,6 +16,7 @@ #include "base/strings/string_util.h" #include "base/thread_task_runner_handle.h" #include "base/time/time.h" +#include "base/trace_event/trace_event.h" #include "base/values.h" #include "net/base/net_errors.h" #include "net/log/net_log.h" @@ -53,11 +54,13 @@ bool g_connect_backup_jobs_enabled = true; ConnectJob::ConnectJob(const std::string& group_name, base::TimeDelta timeout_duration, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, Delegate* delegate, const BoundNetLog& net_log) : group_name_(group_name), timeout_duration_(timeout_duration), priority_(priority), + respect_limits_(respect_limits), delegate_(delegate), net_log_(net_log), idle_(true) { @@ -102,6 +105,7 @@ void ConnectJob::SetSocket(scoped_ptr<StreamSocket> socket) { } void ConnectJob::NotifyDelegateOfCompletion(int rv) { + TRACE_EVENT0("net", "ConnectJob::NotifyDelegateOfCompletion"); // The delegate will own |this|. Delegate* delegate = delegate_; delegate_ = NULL; @@ -141,16 +145,16 @@ ClientSocketPoolBaseHelper::Request::Request( ClientSocketHandle* handle, const CompletionCallback& callback, RequestPriority priority, - bool ignore_limits, + ClientSocketPool::RespectLimits respect_limits, Flags flags, const BoundNetLog& net_log) : handle_(handle), callback_(callback), priority_(priority), - ignore_limits_(ignore_limits), + respect_limits_(respect_limits), flags_(flags), net_log_(net_log) { - if (ignore_limits_) + if (respect_limits_ == ClientSocketPool::RespectLimits::DISABLED) DCHECK_EQ(priority_, MAXIMUM_PRIORITY); } @@ -218,6 +222,9 @@ ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair( result(result_in) { } +ClientSocketPoolBaseHelper::CallbackResultPair::CallbackResultPair( + const CallbackResultPair& other) = default; + ClientSocketPoolBaseHelper::CallbackResultPair::~CallbackResultPair() {} bool ClientSocketPoolBaseHelper::IsStalled() const { @@ -378,7 +385,7 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( // Can we make another active socket now? if (!group->HasAvailableSocketSlot(max_sockets_per_group_) && - !request.ignore_limits()) { + request.respect_limits() == ClientSocketPool::RespectLimits::ENABLED) { // TODO(willchan): Consider whether or not we need to close a socket in a // higher layered group. I don't think this makes sense since we would just // reuse that socket then if we needed one and wouldn't make it down to this @@ -388,7 +395,8 @@ int ClientSocketPoolBaseHelper::RequestSocketInternal( return ERR_IO_PENDING; } - if (ReachedMaxSocketsLimit() && !request.ignore_limits()) { + if (ReachedMaxSocketsLimit() && + request.respect_limits() == ClientSocketPool::RespectLimits::ENABLED) { // NOTE(mmenke): Wonder if we really need different code for each case // here. Only reason for them now seems to be preconnects. if (idle_socket_count() > 0) { @@ -1302,8 +1310,8 @@ void ClientSocketPoolBaseHelper::Group::InsertPendingRequest( scoped_ptr<const Request> request) { // This value must be cached before we release |request|. RequestPriority priority = request->priority(); - if (request->ignore_limits()) { - // Put requests with ignore_limits == true (which should have + if (request->respect_limits() == ClientSocketPool::RespectLimits::DISABLED) { + // Put requests with RespectLimits::DISABLED (which should have // priority == MAXIMUM_PRIORITY) ahead of other requests with // MAXIMUM_PRIORITY. DCHECK_EQ(priority, MAXIMUM_PRIORITY); diff --git a/chromium/net/socket/client_socket_pool_base.h b/chromium/net/socket/client_socket_pool_base.h index cd13ce0dbdd..1e57adb65b1 100644 --- a/chromium/net/socket/client_socket_pool_base.h +++ b/chromium/net/socket/client_socket_pool_base.h @@ -81,6 +81,7 @@ class NET_EXPORT_PRIVATE ConnectJob { ConnectJob(const std::string& group_name, base::TimeDelta timeout_duration, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, Delegate* delegate, const BoundNetLog& net_log); virtual ~ConnectJob(); @@ -117,6 +118,9 @@ class NET_EXPORT_PRIVATE ConnectJob { protected: RequestPriority priority() const { return priority_; } + ClientSocketPool::RespectLimits respect_limits() const { + return respect_limits_; + } void SetSocket(scoped_ptr<StreamSocket> socket); StreamSocket* socket() { return socket_.get(); } void NotifyDelegateOfCompletion(int rv); @@ -138,6 +142,7 @@ class NET_EXPORT_PRIVATE ConnectJob { const base::TimeDelta timeout_duration_; // TODO(akalin): Support reprioritization. const RequestPriority priority_; + const ClientSocketPool::RespectLimits respect_limits_; // Timer to abort jobs that take too long. base::OneShotTimer timer_; Delegate* delegate_; @@ -173,7 +178,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper Request(ClientSocketHandle* handle, const CompletionCallback& callback, RequestPriority priority, - bool ignore_limits, + ClientSocketPool::RespectLimits respect_limits, Flags flags, const BoundNetLog& net_log); @@ -182,7 +187,9 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper ClientSocketHandle* handle() const { return handle_; } const CompletionCallback& callback() const { return callback_; } RequestPriority priority() const { return priority_; } - bool ignore_limits() const { return ignore_limits_; } + ClientSocketPool::RespectLimits respect_limits() const { + return respect_limits_; + } Flags flags() const { return flags_; } const BoundNetLog& net_log() const { return net_log_; } @@ -200,7 +207,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper const CompletionCallback callback_; // TODO(akalin): Support reprioritization. const RequestPriority priority_; - const bool ignore_limits_; + const ClientSocketPool::RespectLimits respect_limits_; const Flags flags_; const BoundNetLog net_log_; @@ -406,8 +413,8 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper // Returns the priority of the top of the pending request queue // (which may be less than the maximum priority over the entire - // queue, due to how we prioritize requests with |ignore_limits| - // set over others). + // queue, due to how we prioritize requests with |respect_limits| + // DISABLED over others). RequestPriority TopPendingPriority() const { // NOTE: FirstMax().value()->priority() is not the same as // FirstMax().priority()! @@ -507,6 +514,7 @@ class NET_EXPORT_PRIVATE ClientSocketPoolBaseHelper struct CallbackResultPair { CallbackResultPair(); CallbackResultPair(const CompletionCallback& callback_in, int result_in); + CallbackResultPair(const CallbackResultPair& other); ~CallbackResultPair(); CompletionCallback callback; @@ -680,12 +688,16 @@ class ClientSocketPoolBase { Request(ClientSocketHandle* handle, const CompletionCallback& callback, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, internal::ClientSocketPoolBaseHelper::Flags flags, - bool ignore_limits, const scoped_refptr<SocketParams>& params, const BoundNetLog& net_log) - : internal::ClientSocketPoolBaseHelper::Request( - handle, callback, priority, ignore_limits, flags, net_log), + : internal::ClientSocketPoolBaseHelper::Request(handle, + callback, + priority, + respect_limits, + flags, + net_log), params_(params) {} const scoped_refptr<SocketParams>& params() const { return params_; } @@ -749,14 +761,13 @@ class ClientSocketPoolBase { int RequestSocket(const std::string& group_name, const scoped_refptr<SocketParams>& params, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { - scoped_ptr<const Request> request( - new Request(handle, callback, priority, - internal::ClientSocketPoolBaseHelper::NORMAL, - params->ignore_limits(), - params, net_log)); + scoped_ptr<const Request> request(new Request( + handle, callback, priority, respect_limits, + internal::ClientSocketPoolBaseHelper::NORMAL, params, net_log)); return helper_.RequestSocket(group_name, std::move(request)); } @@ -767,9 +778,10 @@ class ClientSocketPoolBase { const scoped_refptr<SocketParams>& params, int num_sockets, const BoundNetLog& net_log) { - const Request request(NULL /* no handle */, CompletionCallback(), IDLE, + const Request request(nullptr /* no handle */, CompletionCallback(), IDLE, + ClientSocketPool::RespectLimits::ENABLED, internal::ClientSocketPoolBaseHelper::NO_IDLE_SOCKETS, - params->ignore_limits(), params, net_log); + params, net_log); helper_.RequestSockets(group_name, request, num_sockets); } diff --git a/chromium/net/socket/client_socket_pool_base_unittest.cc b/chromium/net/socket/client_socket_pool_base_unittest.cc index a467f49e601..308d4af9248 100644 --- a/chromium/net/socket/client_socket_pool_base_unittest.cc +++ b/chromium/net/socket/client_socket_pool_base_unittest.cc @@ -107,16 +107,11 @@ void TestLoadTimingInfoNotConnected(const ClientSocketHandle& handle) { class TestSocketParams : public base::RefCounted<TestSocketParams> { public: - explicit TestSocketParams(bool ignore_limits) - : ignore_limits_(ignore_limits) {} - - bool ignore_limits() { return ignore_limits_; } + explicit TestSocketParams() {} private: friend class base::RefCounted<TestSocketParams>; ~TestSocketParams() {} - - const bool ignore_limits_; }; typedef ClientSocketPoolBase<TestSocketParams> TestClientSocketPoolBase; @@ -180,7 +175,6 @@ class MockClientSocket : public StreamSocket { void SetSubresourceSpeculation() override {} void SetOmniboxSpeculation() override {} bool WasEverUsed() const override { return was_used_to_convey_data_; } - bool UsingTCPFastOpen() const override { return false; } bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } @@ -278,14 +272,17 @@ class TestConnectJob : public ConnectJob { ConnectJob::Delegate* delegate, MockClientSocketFactory* client_socket_factory, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, request.priority(), delegate, + : ConnectJob(group_name, + timeout_duration, + request.priority(), + request.respect_limits(), + delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), job_type_(job_type), client_socket_factory_(client_socket_factory), load_state_(LOAD_STATE_IDLE), store_additional_error_state_(false), - weak_factory_(this) { - } + weak_factory_(this) {} void Signal() { DoConnect(waiting_success_, true /* async */, false /* recoverable */); @@ -505,13 +502,14 @@ class TestClientSocketPool : public ClientSocketPool { int RequestSocket(const std::string& group_name, const void* params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override { const scoped_refptr<TestSocketParams>* casted_socket_params = static_cast<const scoped_refptr<TestSocketParams>*>(params); return base_.RequestSocket(group_name, *casted_socket_params, priority, - handle, callback, net_log); + respect_limits, handle, callback, net_log); } void RequestSockets(const std::string& group_name, @@ -663,8 +661,7 @@ class TestConnectJobDelegate : public ConnectJob::Delegate { class ClientSocketPoolBaseTest : public testing::Test { protected: - ClientSocketPoolBaseTest() - : params_(new TestSocketParams(false /* ignore_limits */)) { + ClientSocketPoolBaseTest() : params_(new TestSocketParams()) { connect_backup_jobs_enabled_ = internal::ClientSocketPoolBaseHelper::connect_backup_jobs_enabled(); internal::ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(true); @@ -701,16 +698,17 @@ class ClientSocketPoolBaseTest : public testing::Test { connect_job_factory_)); } - int StartRequestWithParams( + int StartRequestWithIgnoreLimits( const std::string& group_name, RequestPriority priority, - const scoped_refptr<TestSocketParams>& params) { - return test_base_.StartRequestUsingPool( - pool_.get(), group_name, priority, params); + ClientSocketPool::RespectLimits respect_limits) { + return test_base_.StartRequestUsingPool(pool_.get(), group_name, priority, + respect_limits, params_); } int StartRequest(const std::string& group_name, RequestPriority priority) { - return StartRequestWithParams(group_name, priority, params_); + return StartRequestWithIgnoreLimits( + group_name, priority, ClientSocketPool::RespectLimits::ENABLED); } int GetOrderOfRequest(size_t index) const { @@ -749,8 +747,8 @@ TEST_F(ClientSocketPoolBaseTest, ConnectJob_NoTimeoutOnSynchronousCompletion) { ClientSocketHandle ignored; TestClientSocketPoolBase::Request request( &ignored, CompletionCallback(), DEFAULT_PRIORITY, - internal::ClientSocketPoolBaseHelper::NORMAL, - false, params_, BoundNetLog()); + ClientSocketPool::RespectLimits::ENABLED, + internal::ClientSocketPoolBaseHelper::NORMAL, params_, BoundNetLog()); scoped_ptr<TestConnectJob> job( new TestConnectJob(TestConnectJob::kMockJob, "a", @@ -769,8 +767,8 @@ TEST_F(ClientSocketPoolBaseTest, ConnectJob_TimedOut) { TestClientSocketPoolBase::Request request( &ignored, CompletionCallback(), DEFAULT_PRIORITY, - internal::ClientSocketPoolBaseHelper::NORMAL, - false, params_, BoundNetLog()); + ClientSocketPool::RespectLimits::ENABLED, + internal::ClientSocketPoolBaseHelper::NORMAL, params_, BoundNetLog()); // Deleted by TestConnectJobDelegate. TestConnectJob* job = new TestConnectJob(TestConnectJob::kMockPendingJob, @@ -812,13 +810,9 @@ TEST_F(ClientSocketPoolBaseTest, BasicSynchronous) { BoundTestNetLog log; TestLoadTimingInfoNotConnected(handle); - EXPECT_EQ(OK, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - log.bound())); + EXPECT_EQ(OK, handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), log.bound())); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); TestLoadTimingInfoConnectedNotReused(handle); @@ -856,12 +850,9 @@ TEST_F(ClientSocketPoolBaseTest, InitConnectionFailure) { info.headers = new HttpResponseHeaders(std::string()); handle.set_ssl_error_response_info(info); EXPECT_EQ(ERR_CONNECTION_FAILED, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - log.bound())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), log.bound())); EXPECT_FALSE(handle.socket()); EXPECT_FALSE(handle.is_ssl_error()); EXPECT_TRUE(handle.ssl_error_response_info().headers.get() == NULL); @@ -1096,23 +1087,17 @@ TEST_F(ClientSocketPoolBaseTest, StallAndThenCancelAndTriggerAvailableSocket) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); ClientSocketHandle handles[4]; for (size_t i = 0; i < arraysize(handles); ++i) { TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handles[i].Init("b", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handles[i].Init("b", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); } // One will be stalled, cancel all the handles now. @@ -1130,23 +1115,20 @@ TEST_F(ClientSocketPoolBaseTest, CancelStalledSocketAtSocketLimit) { ClientSocketHandle handles[kDefaultMaxSockets]; TestCompletionCallback callbacks[kDefaultMaxSockets]; for (int i = 0; i < kDefaultMaxSockets; ++i) { - EXPECT_EQ(OK, handles[i].Init(base::IntToString(i), - params_, - DEFAULT_PRIORITY, - callbacks[i].callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handles[i].Init( + base::IntToString(i), params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callbacks[i].callback(), pool_.get(), BoundNetLog())); } // Force a stalled group. ClientSocketHandle stalled_handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ( + ERR_IO_PENDING, + stalled_handle.Init("foo", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // Cancel the stalled request. stalled_handle.Reset(); @@ -1169,24 +1151,22 @@ TEST_F(ClientSocketPoolBaseTest, CancelPendingSocketAtSocketLimit) { ClientSocketHandle handles[kDefaultMaxSockets]; for (int i = 0; i < kDefaultMaxSockets; ++i) { TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handles[i].Init(base::IntToString(i), - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ( + ERR_IO_PENDING, + handles[i].Init(base::IntToString(i), params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); } // Force a stalled group. connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle stalled_handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ( + ERR_IO_PENDING, + stalled_handle.Init("foo", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // Since it is stalled, it should have no connect jobs. EXPECT_EQ(0, pool_->NumConnectJobsInGroup("foo")); @@ -1225,13 +1205,11 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { ClientSocketHandle handles[kDefaultMaxSockets]; for (int i = 0; i < kDefaultMaxSockets; ++i) { TestCompletionCallback callback; - EXPECT_EQ(OK, handles[i].Init(base::StringPrintf( - "Take 2: %d", i), - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ( + OK, handles[i].Init(base::StringPrintf("Take 2: %d", i), params_, + DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); } EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); @@ -1239,12 +1217,11 @@ TEST_F(ClientSocketPoolBaseTest, WaitForStalledSocketAtSocketLimit) { EXPECT_FALSE(pool_->IsStalled()); // Now we will hit the socket limit. - EXPECT_EQ(ERR_IO_PENDING, stalled_handle.Init("foo", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ( + ERR_IO_PENDING, + stalled_handle.Init("foo", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_TRUE(pool_->IsStalled()); // Dropping out of scope will close all handles and return them to idle. @@ -1267,12 +1244,9 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) { for (int i = 0; i < kDefaultMaxSockets; ++i) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, handle.Init(base::IntToString(i), - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle.Init(base::IntToString(i), params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); } // Flush all the DoReleaseSocket tasks. @@ -1287,12 +1261,9 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketAtSocketLimitDeleteGroup) { // "0" is special here, since it should be the first entry in the sorted map, // which is the one which we would close an idle socket for. We shouldn't // close an idle socket though, since we should reuse the idle socket. - EXPECT_EQ(OK, handle.Init("0", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle.Init("0", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(kDefaultMaxSockets, client_socket_factory_.allocation_count()); EXPECT_EQ(kDefaultMaxSockets - 1, pool_->IdleSocketCount()); @@ -1361,12 +1332,10 @@ TEST_F(ClientSocketPoolBaseTest, CancelRequestClearGroup) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); handle.Reset(); } @@ -1377,23 +1346,18 @@ TEST_F(ClientSocketPoolBaseTest, ConnectCancelConnect) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); handle.Reset(); TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback2.WaitForResult()); EXPECT_FALSE(callback.have_result()); @@ -1457,11 +1421,11 @@ void RequestSocketOnComplete(ClientSocketHandle* handle, handle->socket()->Disconnect(); handle->Reset(); - scoped_refptr<TestSocketParams> params( - new TestSocketParams(false /* ignore_limits */)); + scoped_refptr<TestSocketParams> params(new TestSocketParams()); TestCompletionCallback callback; - int rv = - handle->Init("a", params, LOWEST, nested_callback, pool, BoundNetLog()); + int rv = handle->Init("a", params, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + nested_callback, pool, BoundNetLog()); if (rv != ERR_IO_PENDING) { DCHECK_EQ(TestConnectJob::kMockJob, next_job_type); nested_callback.Run(rv); @@ -1480,7 +1444,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestPendingJobTwice) { ClientSocketHandle handle; TestCompletionCallback second_result_callback; int rv = handle.Init( - "a", params_, DEFAULT_PRIORITY, + "a", params_, DEFAULT_PRIORITY, ClientSocketPool::RespectLimits::ENABLED, base::Bind(&RequestSocketOnComplete, &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockPendingJob, second_result_callback.callback()), @@ -1500,7 +1464,7 @@ TEST_F(ClientSocketPoolBaseTest, RequestPendingJobThenSynchronous) { ClientSocketHandle handle; TestCompletionCallback second_result_callback; int rv = handle.Init( - "a", params_, DEFAULT_PRIORITY, + "a", params_, DEFAULT_PRIORITY, ClientSocketPool::RespectLimits::ENABLED, base::Bind(&RequestSocketOnComplete, &handle, pool_.get(), connect_job_factory_, TestConnectJob::kMockPendingJob, second_result_callback.callback()), @@ -1567,23 +1531,17 @@ TEST_F(ClientSocketPoolBaseTest, CancelActiveRequestThenRequestSocket) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // Cancel the active request. handle.Reset(); - rv = handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog()); + rv = handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback.WaitForResult()); @@ -1636,12 +1594,9 @@ TEST_F(ClientSocketPoolBaseTest, BasicAsynchronous) { ClientSocketHandle handle; TestCompletionCallback callback; BoundTestNetLog log; - int rv = handle.Init("a", - params_, - LOWEST, - callback.callback(), - pool_.get(), - log.bound()); + int rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), log.bound()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); TestLoadTimingInfoNotConnected(handle); @@ -1683,12 +1638,10 @@ TEST_F(ClientSocketPoolBaseTest, HttpResponseInfo info; info.headers = new HttpResponseHeaders(std::string()); handle.set_ssl_error_response_info(info); - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - log.bound())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), log.bound())); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); EXPECT_FALSE(handle.is_ssl_error()); @@ -1735,20 +1688,14 @@ TEST_F(ClientSocketPoolBaseTest, TwoRequestsCancelOne) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); BoundTestNetLog log2; EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); handle.Reset(); @@ -1795,11 +1742,9 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { std::vector<TestSocketRequest*> request_order; size_t completion_count; // unused TestSocketRequest req1(&request_order, &completion_count); - int rv = req1.handle()->Init("a", - params_, - DEFAULT_PRIORITY, - req1.callback(), pool_.get(), - BoundNetLog()); + int rv = req1.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req1.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, req1.WaitForResult()); @@ -1808,20 +1753,14 @@ TEST_F(ClientSocketPoolBaseTest, ReleaseSockets) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); TestSocketRequest req2(&request_order, &completion_count); - rv = req2.handle()->Init("a", - params_, - DEFAULT_PRIORITY, - req2.callback(), - pool_.get(), - BoundNetLog()); + rv = req2.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); TestSocketRequest req3(&request_order, &completion_count); - rv = req3.handle()->Init("a", - params_, - DEFAULT_PRIORITY, - req3.callback(), - pool_.get(), - BoundNetLog()); + rv = req3.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req3.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // Both Requests 2 and 3 are pending. We release socket 1 which should @@ -1855,33 +1794,24 @@ TEST_F(ClientSocketPoolBaseTest, PendingJobCompletionOrder) { std::vector<TestSocketRequest*> request_order; size_t completion_count; // unused TestSocketRequest req1(&request_order, &completion_count); - int rv = req1.handle()->Init("a", - params_, - DEFAULT_PRIORITY, - req1.callback(), - pool_.get(), - BoundNetLog()); + int rv = req1.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req1.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); TestSocketRequest req2(&request_order, &completion_count); - rv = req2.handle()->Init("a", - params_, - DEFAULT_PRIORITY, - req2.callback(), - pool_.get(), - BoundNetLog()); + rv = req2.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // The pending job is sync. connect_job_factory_->set_job_type(TestConnectJob::kMockJob); TestSocketRequest req3(&request_order, &completion_count); - rv = req3.handle()->Init("a", - params_, - DEFAULT_PRIORITY, - req3.callback(), - pool_.get(), - BoundNetLog()); + rv = req3.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req3.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(ERR_CONNECTION_FAILED, req1.WaitForResult()); @@ -1901,12 +1831,9 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateOneRequest) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); @@ -1925,15 +1852,17 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequests) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", params_, DEFAULT_PRIORITY, callback.callback(), - pool_.get(), BoundNetLog()); + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); client_socket_factory_.SetJobLoadState(0, LOAD_STATE_RESOLVING_HOST); ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", params_, DEFAULT_PRIORITY, callback2.callback(), - pool_.get(), BoundNetLog()); + rv = handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); client_socket_factory_.SetJobLoadState(1, LOAD_STATE_RESOLVING_HOST); @@ -1956,14 +1885,16 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateTwoRequestsChangeSecondRequestState) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", params_, DEFAULT_PRIORITY, callback.callback(), - pool_.get(), BoundNetLog()); + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", params_, DEFAULT_PRIORITY, callback2.callback(), - pool_.get(), BoundNetLog()); + rv = handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); client_socket_factory_.SetJobLoadState(1, LOAD_STATE_RESOLVING_HOST); @@ -1984,12 +1915,9 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateGroupLimit) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - MEDIUM, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, MEDIUM, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, handle.GetLoadState()); @@ -1997,12 +1925,9 @@ TEST_F(ClientSocketPoolBaseTest, LoadStateGroupLimit) { // The first request should now be stalled at the socket group limit. ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", - params_, - HIGHEST, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle2.Init("a", params_, HIGHEST, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_WAITING_FOR_AVAILABLE_SOCKET, handle.GetLoadState()); EXPECT_EQ(LOAD_STATE_CONNECTING, handle2.GetLoadState()); @@ -2032,35 +1957,26 @@ TEST_F(ClientSocketPoolBaseTest, LoadStatePoolLimit) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // Request for socket from another pool. ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("b", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle2.Init("b", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // Request another socket from the first pool. Request should stall at the // socket pool limit. ClientSocketHandle handle3; TestCompletionCallback callback3; - rv = handle3.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle3.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // The third handle should remain stalled as the other sockets in its group @@ -2091,8 +2007,9 @@ TEST_F(ClientSocketPoolBaseTest, Recoverable) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, - handle.Init("a", params_, DEFAULT_PRIORITY, callback.callback(), - pool_.get(), BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); } @@ -2105,12 +2022,9 @@ TEST_F(ClientSocketPoolBaseTest, AsyncRecoverable) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); EXPECT_EQ(ERR_PROXY_AUTH_REQUESTED, callback.WaitForResult()); EXPECT_TRUE(handle.is_initialized()); @@ -2125,12 +2039,9 @@ TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateSynchronous) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_CONNECTION_FAILED, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); EXPECT_TRUE(handle.is_ssl_error()); @@ -2145,12 +2056,9 @@ TEST_F(ClientSocketPoolBaseTest, AdditionalErrorStateAsynchronous) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); EXPECT_FALSE(handle.is_initialized()); @@ -2173,12 +2081,9 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerReuse) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - LOWEST, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); ASSERT_EQ(OK, callback.WaitForResult()); @@ -2194,12 +2099,9 @@ TEST_F(ClientSocketPoolBaseTest, DisableCleanupTimerReuse) { // Request a new socket. This should reuse the old socket and complete // synchronously. BoundTestNetLog log; - rv = handle.Init("a", - params_, - LOWEST, - CompletionCallback(), - pool_.get(), - log.bound()); + rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + CompletionCallback(), pool_.get(), log.bound()); ASSERT_EQ(OK, rv); EXPECT_TRUE(handle.is_reused()); TestLoadTimingInfoConnectedReused(handle); @@ -2236,23 +2138,17 @@ TEST_F(ClientSocketPoolBaseTest, MAYBE_DisableCleanupTimerNoReuse) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - LOWEST, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", - params_, - LOWEST, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle2.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle2)); @@ -2281,12 +2177,9 @@ TEST_F(ClientSocketPoolBaseTest, MAYBE_DisableCleanupTimerNoReuse) { // A new socket will be created rather than reusing the idle one. BoundTestNetLog log; TestCompletionCallback callback3; - rv = handle.Init("a", - params_, - LOWEST, - callback3.callback(), - pool_.get(), - log.bound()); + rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), log.bound()); ASSERT_EQ(ERR_IO_PENDING, rv); ASSERT_EQ(OK, callback3.WaitForResult()); EXPECT_FALSE(handle.is_reused()); @@ -2314,23 +2207,17 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - LOWEST, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle)); ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", - params_, - LOWEST, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle2.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(LOAD_STATE_CONNECTING, pool_->GetLoadState("a", &handle2)); @@ -2359,12 +2246,9 @@ TEST_F(ClientSocketPoolBaseTest, CleanupTimedOutIdleSockets) { pool_->CleanupTimedOutIdleSockets(); BoundTestNetLog log; - rv = handle.Init("a", - params_, - LOWEST, - callback.callback(), - pool_.get(), - log.bound()); + rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), log.bound()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_reused()); @@ -2388,42 +2272,30 @@ TEST_F(ClientSocketPoolBaseTest, MultipleReleasingDisconnectedSockets) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init("a", - params_, - LOWEST, - callback.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(OK, rv); ClientSocketHandle handle2; TestCompletionCallback callback2; - rv = handle2.Init("a", - params_, - LOWEST, - callback2.callback(), - pool_.get(), - BoundNetLog()); + rv = handle2.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(OK, rv); ClientSocketHandle handle3; TestCompletionCallback callback3; - rv = handle3.Init("a", - params_, - LOWEST, - callback3.callback(), - pool_.get(), - BoundNetLog()); + rv = handle3.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); ClientSocketHandle handle4; TestCompletionCallback callback4; - rv = handle4.Init("a", - params_, - LOWEST, - callback4.callback(), - pool_.get(), - BoundNetLog()); + rv = handle4.Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback4.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); // Release two disconnected sockets. @@ -2458,37 +2330,29 @@ TEST_F(ClientSocketPoolBaseTest, SocketLimitReleasingSockets) { TestCompletionCallback callback_b[4]; for (int i = 0; i < 2; ++i) { - EXPECT_EQ(OK, handle_a[i].Init("a", - params_, - LOWEST, - callback_a[i].callback(), - pool_.get(), + EXPECT_EQ(OK, handle_a[i].Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback_a[i].callback(), pool_.get(), BoundNetLog())); - EXPECT_EQ(OK, handle_b[i].Init("b", - params_, - LOWEST, - callback_b[i].callback(), - pool_.get(), + EXPECT_EQ(OK, handle_b[i].Init("b", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback_b[i].callback(), pool_.get(), BoundNetLog())); } // Make 4 pending requests, 2 per group. for (int i = 2; i < 4; ++i) { - EXPECT_EQ(ERR_IO_PENDING, - handle_a[i].Init("a", - params_, - LOWEST, - callback_a[i].callback(), - pool_.get(), - BoundNetLog())); - EXPECT_EQ(ERR_IO_PENDING, - handle_b[i].Init("b", - params_, - LOWEST, - callback_b[i].callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ( + ERR_IO_PENDING, + handle_a[i].Init("a", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback_a[i].callback(), pool_.get(), BoundNetLog())); + EXPECT_EQ( + ERR_IO_PENDING, + handle_b[i].Init("b", params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback_b[i].callback(), pool_.get(), BoundNetLog())); } // Release b's socket first. The order is important, because in @@ -2571,10 +2435,10 @@ class TestReleasingSocketRequest : public TestCompletionCallbackBase { if (reset_releasing_handle_) handle_.Reset(); - scoped_refptr<TestSocketParams> con_params( - new TestSocketParams(false /* ignore_limits */)); + scoped_refptr<TestSocketParams> con_params(new TestSocketParams()); EXPECT_EQ(expected_result_, handle2_.Init("a", con_params, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, callback2_.callback(), pool_, BoundNetLog())); } @@ -2602,8 +2466,9 @@ TEST_F(ClientSocketPoolBaseTest, AdditionalErrorSocketsDontUseSlot) { TestConnectJob::kMockPendingAdditionalErrorStateJob); TestReleasingSocketRequest req(pool_.get(), OK, false); EXPECT_EQ(ERR_IO_PENDING, - req.handle()->Init("a", params_, DEFAULT_PRIORITY, req.callback(), - pool_.get(), BoundNetLog())); + req.handle()->Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + req.callback(), pool_.get(), BoundNetLog())); // The next job should complete synchronously connect_job_factory_->set_job_type(TestConnectJob::kMockJob); @@ -2628,12 +2493,10 @@ TEST_F(ClientSocketPoolBaseTest, CallbackThatReleasesPool) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); pool_->FlushWithError(ERR_NETWORK_CHANGED); @@ -2647,12 +2510,10 @@ TEST_F(ClientSocketPoolBaseTest, DoNotReuseSocketAfterFlush) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ClientSocketHandle::UNUSED, handle.reuse_type()); @@ -2661,12 +2522,10 @@ TEST_F(ClientSocketPoolBaseTest, DoNotReuseSocketAfterFlush) { handle.Reset(); base::MessageLoop::current()->RunUntilIdle(); - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); EXPECT_EQ(ClientSocketHandle::UNUSED, handle.reuse_type()); } @@ -2696,12 +2555,9 @@ class ConnectWithinCallback : public TestCompletionCallbackBase { void OnComplete(int result) { SetResult(result); EXPECT_EQ(ERR_IO_PENDING, - handle_.Init(group_name_, - params_, - DEFAULT_PRIORITY, - nested_callback_.callback(), - pool_, - BoundNetLog())); + handle_.Init(group_name_, params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + nested_callback_.callback(), pool_, BoundNetLog())); } const std::string group_name_; @@ -2722,12 +2578,10 @@ TEST_F(ClientSocketPoolBaseTest, AbortAllRequestsOnFlush) { ClientSocketHandle handle; ConnectWithinCallback callback("a", params_, pool_.get()); - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // Second job will be started during the first callback, and will // asynchronously complete with OK. @@ -2749,24 +2603,20 @@ TEST_F(ClientSocketPoolBaseTest, BackupSocketCancelAtMaxSockets) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("bar", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // Start (MaxSockets - 1) connected sockets to reach max sockets. connect_job_factory_->set_job_type(TestConnectJob::kMockJob); ClientSocketHandle handles[kDefaultMaxSockets]; for (int i = 1; i < kDefaultMaxSockets; ++i) { TestCompletionCallback callback; - EXPECT_EQ(OK, handles[i].Init("bar", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, + handles[i].Init("bar", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); } base::MessageLoop::current()->RunUntilIdle(); @@ -2791,12 +2641,10 @@ TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterCancelingAllRequests) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("bar", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); ASSERT_TRUE(pool_->HasGroup("bar")); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("bar")); EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("bar")); @@ -2821,21 +2669,17 @@ TEST_F(ClientSocketPoolBaseTest, CancelBackupSocketAfterFinishingAllRequests) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("bar", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("bar", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle handle2; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle2.Init("bar", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("bar", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); ASSERT_TRUE(pool_->HasGroup("bar")); EXPECT_EQ(2, pool_->NumConnectJobsInGroup("bar")); @@ -2860,12 +2704,9 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingWaitingForConnect) { ClientSocketHandle handle1; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); // No idle sockets, no pending jobs. @@ -2876,12 +2717,9 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingWaitingForConnect) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle2; EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // No idle sockets, and one connecting job. EXPECT_EQ(0, pool_->IdleSocketCount()); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -2918,12 +2756,9 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtGroupCapacity) { ClientSocketHandle handle1; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); // No idle sockets, no pending jobs. @@ -2934,12 +2769,9 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtGroupCapacity) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle2; EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // No idle sockets, and one connecting job. EXPECT_EQ(0, pool_->IdleSocketCount()); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -2978,12 +2810,9 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtStall) { ClientSocketHandle handle1; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); // No idle sockets, no pending jobs. @@ -2994,12 +2823,9 @@ TEST_F(ClientSocketPoolBaseTest, DelayedSocketBindingAtStall) { connect_job_factory_->set_job_type(TestConnectJob::kMockWaitingJob); ClientSocketHandle handle2; EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // No idle sockets, and one connecting job. EXPECT_EQ(0, pool_->IdleSocketCount()); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -3041,12 +2867,9 @@ TEST_F(ClientSocketPoolBaseTest, SynchronouslyProcessOnePendingRequest) { ClientSocketHandle handle1; TestCompletionCallback callback1; EXPECT_EQ(ERR_IO_PENDING, - handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); // Make the second request synchronously fail. This should make the Group @@ -3057,12 +2880,9 @@ TEST_F(ClientSocketPoolBaseTest, SynchronouslyProcessOnePendingRequest) { // It'll be ERR_IO_PENDING now, but the TestConnectJob will synchronously fail // when created. EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -3078,29 +2898,23 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ClientSocketHandle handle2; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); ClientSocketHandle handle3; TestCompletionCallback callback3; - EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a", - params_, - DEFAULT_PRIORITY, - callback3.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle3.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback1.WaitForResult()); EXPECT_EQ(OK, callback2.WaitForResult()); @@ -3114,24 +2928,15 @@ TEST_F(ClientSocketPoolBaseTest, PreferUsedSocketToUnusedSocket) { handle2.Reset(); handle3.Reset(); - EXPECT_EQ(OK, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); - EXPECT_EQ(OK, handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); - EXPECT_EQ(OK, handle3.Init("a", - params_, - DEFAULT_PRIORITY, - callback3.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); + EXPECT_EQ(OK, handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); + EXPECT_EQ(OK, handle3.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), BoundNetLog())); EXPECT_TRUE(handle1.socket()->WasEverUsed()); EXPECT_TRUE(handle2.socket()->WasEverUsed()); @@ -3151,21 +2956,17 @@ TEST_F(ClientSocketPoolBaseTest, RequestSockets) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ClientSocketHandle handle2; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); @@ -3187,12 +2988,10 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsWhenAlreadyHaveAConnectJob) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ASSERT_TRUE(pool_->HasGroup("a")); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); @@ -3207,12 +3006,10 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsWhenAlreadyHaveAConnectJob) { ClientSocketHandle handle2; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(2, pool_->NumConnectJobsInGroup("a")); EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); @@ -3235,30 +3032,24 @@ TEST_F(ClientSocketPoolBaseTest, ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ClientSocketHandle handle2; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); ClientSocketHandle handle3; TestCompletionCallback callback3; - EXPECT_EQ(ERR_IO_PENDING, handle3.Init("a", - params_, - DEFAULT_PRIORITY, - callback3.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle3.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), BoundNetLog())); ASSERT_TRUE(pool_->HasGroup("a")); EXPECT_EQ(3, pool_->NumConnectJobsInGroup("a")); @@ -3335,12 +3126,10 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountIdleSockets) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ASSERT_EQ(OK, callback1.WaitForResult()); handle1.Reset(); @@ -3362,12 +3151,10 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsCountActiveSockets) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ASSERT_EQ(OK, callback1.WaitForResult()); ASSERT_TRUE(pool_->HasGroup("a")); @@ -3439,22 +3226,17 @@ TEST_F(ClientSocketPoolBaseTest, RequestSocketsMultipleTimesDoesNothing) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ASSERT_EQ(OK, callback1.WaitForResult()); ClientSocketHandle handle2; TestCompletionCallback callback2; - int rv = handle2.Init("a", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog()); + int rv = handle2.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog()); if (rv != OK) { EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, callback2.WaitForResult()); @@ -3518,12 +3300,10 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectJobsTakenByNormalRequests) { ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); @@ -3553,12 +3333,9 @@ TEST_F(ClientSocketPoolBaseTest, ConnectedPreconnectJobsHaveNoConnectTimes) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // Make sure the idle socket was used. EXPECT_EQ(0, pool_->IdleSocketCountInGroup("a")); @@ -3581,12 +3358,10 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) { // Set up one idle socket in "a". ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("a", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); ASSERT_EQ(OK, callback1.WaitForResult()); handle1.Reset(); @@ -3595,18 +3370,14 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectClosesIdleSocketRemovesGroup) { // Set up two active sockets in "b". ClientSocketHandle handle2; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle1.Init("b", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); - EXPECT_EQ(ERR_IO_PENDING, handle2.Init("b", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle1.Init("b", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle2.Init("b", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); ASSERT_EQ(OK, callback1.WaitForResult()); ASSERT_EQ(OK, callback2.WaitForResult()); @@ -3688,12 +3459,10 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithBackupJob) { connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); // Timer has started, but the backup connect job shouldn't be created yet. EXPECT_EQ(1, pool_->NumConnectJobsInGroup("a")); EXPECT_EQ(0, pool_->NumUnassignedConnectJobsInGroup("a")); @@ -3727,12 +3496,9 @@ TEST_F(ClientSocketPoolBaseTest, PreconnectWithUnreadData) { connect_job_factory_->set_job_type(TestConnectJob::kMockFailingJob); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); ASSERT_TRUE(pool_->HasGroup("a")); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); @@ -3764,16 +3530,16 @@ class MockLayeredPool : public HigherLayeredPool { } int RequestSocket(TestClientSocketPool* pool) { - scoped_refptr<TestSocketParams> params( - new TestSocketParams(false /* ignore_limits */)); + scoped_refptr<TestSocketParams> params(new TestSocketParams()); return handle_.Init(group_name_, params, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, callback_.callback(), pool, BoundNetLog()); } int RequestSocketWithoutLimits(TestClientSocketPool* pool) { - scoped_refptr<TestSocketParams> params( - new TestSocketParams(true /* ignore_limits */)); + scoped_refptr<TestSocketParams> params(new TestSocketParams()); return handle_.Init(group_name_, params, MAXIMUM_PRIORITY, + ClientSocketPool::RespectLimits::DISABLED, callback_.callback(), pool, BoundNetLog()); } @@ -3836,12 +3602,10 @@ TEST_F(ClientSocketPoolBaseTest, CloseIdleSocketsHeldByLayeredPoolWhenNeeded) { &MockLayeredPool::ReleaseOneConnection)); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); } @@ -3859,12 +3623,9 @@ TEST_F(ClientSocketPoolBaseTest, // has the maximum number of connections already, it's not stalled). ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(OK, handle1.Init("group1", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle1.Init("group1", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); MockLayeredPool mock_layered_pool(pool_.get(), "group2"); EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); @@ -3873,12 +3634,10 @@ TEST_F(ClientSocketPoolBaseTest, &MockLayeredPool::ReleaseOneConnection)); ClientSocketHandle handle; TestCompletionCallback callback2; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("group2", - params_, - DEFAULT_PRIORITY, - callback2.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("group2", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback2.WaitForResult()); } @@ -3896,12 +3655,9 @@ TEST_F(ClientSocketPoolBaseTest, ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(OK, handle1.Init("group1", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle1.Init("group1", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); MockLayeredPool mock_layered_pool(pool_.get(), "group2"); EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); @@ -3913,12 +3669,10 @@ TEST_F(ClientSocketPoolBaseTest, // The third request is made when the socket pool is in a stalled state. ClientSocketHandle handle3; TestCompletionCallback callback3; - EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3", - params_, - DEFAULT_PRIORITY, - callback3.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle3.Init("group3", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), BoundNetLog())); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(callback3.have_result()); @@ -3929,12 +3683,10 @@ TEST_F(ClientSocketPoolBaseTest, mock_layered_pool.set_can_release_connection(true); ClientSocketHandle handle4; TestCompletionCallback callback4; - EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3", - params_, - DEFAULT_PRIORITY, - callback4.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle4.Init("group3", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback4.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback3.WaitForResult()); EXPECT_FALSE(callback4.have_result()); @@ -3961,12 +3713,9 @@ TEST_F(ClientSocketPoolBaseTest, ClientSocketHandle handle1; TestCompletionCallback callback1; - EXPECT_EQ(OK, handle1.Init("group1", - params_, - DEFAULT_PRIORITY, - callback1.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(OK, handle1.Init("group1", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback1.callback(), pool_.get(), BoundNetLog())); MockLayeredPool mock_layered_pool(pool_.get(), "group2"); EXPECT_EQ(OK, mock_layered_pool.RequestSocket(pool_.get())); @@ -3978,12 +3727,10 @@ TEST_F(ClientSocketPoolBaseTest, // The third request is made when the socket pool is in a stalled state. ClientSocketHandle handle3; TestCompletionCallback callback3; - EXPECT_EQ(ERR_IO_PENDING, handle3.Init("group3", - params_, - MEDIUM, - callback3.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle3.Init("group3", params_, MEDIUM, + ClientSocketPool::RespectLimits::ENABLED, + callback3.callback(), pool_.get(), BoundNetLog())); base::RunLoop().RunUntilIdle(); EXPECT_FALSE(callback3.have_result()); @@ -3993,12 +3740,10 @@ TEST_F(ClientSocketPoolBaseTest, mock_layered_pool.set_can_release_connection(true); ClientSocketHandle handle4; TestCompletionCallback callback4; - EXPECT_EQ(ERR_IO_PENDING, handle4.Init("group3", - params_, - HIGHEST, - callback4.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle4.Init("group3", params_, HIGHEST, + ClientSocketPool::RespectLimits::ENABLED, + callback4.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback4.WaitForResult()); EXPECT_FALSE(callback3.have_result()); @@ -4024,36 +3769,38 @@ TEST_F(ClientSocketPoolBaseTest, &MockLayeredPool::ReleaseOneConnection)); ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", - params_, - DEFAULT_PRIORITY, - callback.callback(), - pool_.get(), - BoundNetLog())); + EXPECT_EQ(ERR_IO_PENDING, + handle.Init("a", params_, DEFAULT_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(OK, callback.WaitForResult()); } // Test that when a socket pool and group are at their limits, a request -// with |ignore_limits| triggers creation of a new socket, and gets the socket -// instead of a request with the same priority that was issued earlier, but -// that does not have |ignore_limits| set. +// with RespectLimits::DISABLED triggers creation of a new socket, and gets the +// socket instead of a request with the same priority that was issued earlier, +// but has RespectLimits::ENABLED. TEST_F(ClientSocketPoolBaseTest, IgnoreLimits) { - scoped_refptr<TestSocketParams> params_ignore_limits( - new TestSocketParams(true /* ignore_limits */)); CreatePool(1, 1); // Issue a request to reach the socket pool limit. - EXPECT_EQ(OK, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_)); + EXPECT_EQ( + OK, StartRequestWithIgnoreLimits( + "a", MAXIMUM_PRIORITY, ClientSocketPool::RespectLimits::ENABLED)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, - params_)); + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithIgnoreLimits( + "a", MAXIMUM_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, - params_ignore_limits)); + // Issue a request that ignores the limits, so a new ConnectJob is + // created. + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithIgnoreLimits( + "a", MAXIMUM_PRIORITY, + ClientSocketPool::RespectLimits::DISABLED)); ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); EXPECT_EQ(OK, request(2)->WaitForResult()); @@ -4061,28 +3808,32 @@ TEST_F(ClientSocketPoolBaseTest, IgnoreLimits) { } // Test that when a socket pool and group are at their limits, a ConnectJob -// issued for a request with |ignore_limits| set is not cancelled when a request -// without |ignore_limits| issued to the same group is cancelled. +// issued for a request with RespectLimits::DISABLED is not cancelled when a +// request with RespectLimits::ENABLED issued to the same group is cancelled. TEST_F(ClientSocketPoolBaseTest, IgnoreLimitsCancelOtherJob) { - scoped_refptr<TestSocketParams> params_ignore_limits( - new TestSocketParams(true /* ignore_limits */)); CreatePool(1, 1); // Issue a request to reach the socket pool limit. - EXPECT_EQ(OK, StartRequestWithParams("a", MAXIMUM_PRIORITY, params_)); + EXPECT_EQ( + OK, StartRequestWithIgnoreLimits( + "a", MAXIMUM_PRIORITY, ClientSocketPool::RespectLimits::ENABLED)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); connect_job_factory_->set_job_type(TestConnectJob::kMockPendingJob); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, - params_)); + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithIgnoreLimits( + "a", MAXIMUM_PRIORITY, + ClientSocketPool::RespectLimits::ENABLED)); EXPECT_EQ(0, pool_->NumConnectJobsInGroup("a")); - EXPECT_EQ(ERR_IO_PENDING, StartRequestWithParams("a", MAXIMUM_PRIORITY, - params_ignore_limits)); + // Issue a request with RespectLimits::DISABLED, so a new ConnectJob is + // created. + EXPECT_EQ(ERR_IO_PENDING, StartRequestWithIgnoreLimits( + "a", MAXIMUM_PRIORITY, + ClientSocketPool::RespectLimits::DISABLED)); ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); - // Cancel the pending request without ignore_limits set. The ConnectJob + // Cancel the pending request with RespectLimits::ENABLED. The ConnectJob // should not be cancelled. request(1)->handle()->Reset(); ASSERT_EQ(1, pool_->NumConnectJobsInGroup("a")); diff --git a/chromium/net/socket/client_socket_pool_manager.cc b/chromium/net/socket/client_socket_pool_manager.cc index c053b1f4c04..dee8218294a 100644 --- a/chromium/net/socket/client_socket_pool_manager.cc +++ b/chromium/net/socket/client_socket_pool_manager.cc @@ -14,6 +14,7 @@ #include "net/http/http_stream_factory.h" #include "net/proxy/proxy_info.h" #include "net/socket/client_socket_handle.h" +#include "net/socket/client_socket_pool.h" #include "net/socket/socks_client_socket_pool.h" #include "net/socket/ssl_client_socket_pool.h" #include "net/socket/transport_client_socket_pool.h" @@ -148,7 +149,10 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, connection_group = prefix + connection_group; } - bool ignore_limits = (request_load_flags & LOAD_IGNORE_LIMITS) != 0; + ClientSocketPool::RespectLimits respect_limits = + ClientSocketPool::RespectLimits::ENABLED; + if ((request_load_flags & LOAD_IGNORE_LIMITS) != 0) + respect_limits = ClientSocketPool::RespectLimits::DISABLED; if (!proxy_info.is_direct()) { ProxyServer proxy_server = proxy_info.proxy_server(); proxy_host_port.reset(new HostPortPair(proxy_server.host_port_pair())); @@ -156,7 +160,6 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, new TransportSocketParams( *proxy_host_port, disable_resolver_cache, - ignore_limits, resolution_callback, TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); @@ -175,7 +178,6 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT; proxy_tcp_params = new TransportSocketParams(*proxy_host_port, disable_resolver_cache, - ignore_limits, resolution_callback, combine_connect_and_write); // Set ssl_params, and unset proxy_tcp_params @@ -229,7 +231,6 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT; ssl_tcp_params = new TransportSocketParams(origin_host_port, disable_resolver_cache, - ignore_limits, resolution_callback, combine_connect_and_write); } @@ -250,9 +251,8 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, return OK; } - return socket_handle->Init(connection_group, ssl_params, - request_priority, callback, ssl_pool, - net_log); + return socket_handle->Init(connection_group, ssl_params, request_priority, + respect_limits, callback, ssl_pool, net_log); } // Finally, get the connection started. @@ -267,8 +267,8 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, } return socket_handle->Init(connection_group, http_proxy_params, - request_priority, callback, - pool, net_log); + request_priority, respect_limits, callback, pool, + net_log); } if (proxy_info.is_socks()) { @@ -280,9 +280,8 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, return OK; } - return socket_handle->Init(connection_group, socks_params, - request_priority, callback, pool, - net_log); + return socket_handle->Init(connection_group, socks_params, request_priority, + respect_limits, callback, pool, net_log); } DCHECK(proxy_info.is_direct()); @@ -290,7 +289,6 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, new TransportSocketParams( origin_host_port, disable_resolver_cache, - ignore_limits, resolution_callback, TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT); TransportClientSocketPool* pool = @@ -301,9 +299,8 @@ int InitSocketPoolHelper(ClientSocketPoolManager::SocketGroupType group_type, return OK; } - return socket_handle->Init(connection_group, tcp_params, - request_priority, callback, - pool, net_log); + return socket_handle->Init(connection_group, tcp_params, request_priority, + respect_limits, callback, pool, net_log); } } // namespace diff --git a/chromium/net/socket/client_socket_pool_manager_impl.h b/chromium/net/socket/client_socket_pool_manager_impl.h index 538e507fe6e..e0662f19037 100644 --- a/chromium/net/socket/client_socket_pool_manager_impl.h +++ b/chromium/net/socket/client_socket_pool_manager_impl.h @@ -6,12 +6,13 @@ #define NET_SOCKET_CLIENT_SOCKET_POOL_MANAGER_IMPL_H_ #include <map> +#include <type_traits> + #include "base/compiler_specific.h" #include "base/macros.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/stl_util.h" -#include "base/template_util.h" #include "base/threading/non_thread_safe.h" #include "net/cert/cert_database.h" #include "net/http/http_network_session.h" @@ -39,7 +40,7 @@ template <typename Key, typename Value> class OwnedPoolMap : public std::map<Key, Value> { public: OwnedPoolMap() { - static_assert(base::is_pointer<Value>::value, "value must be a pointer"); + static_assert(std::is_pointer<Value>::value, "value must be a pointer"); } ~OwnedPoolMap() { diff --git a/chromium/net/socket/next_proto.cc b/chromium/net/socket/next_proto.cc index a22418c542e..a3e2e0bd767 100644 --- a/chromium/net/socket/next_proto.cc +++ b/chromium/net/socket/next_proto.cc @@ -6,35 +6,6 @@ namespace net { -NextProtoVector NextProtosDefaults() { - NextProtoVector next_protos; - next_protos.push_back(kProtoHTTP2); - next_protos.push_back(kProtoSPDY31); - next_protos.push_back(kProtoHTTP11); - return next_protos; -} - -NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled, - bool quic_enabled) { - NextProtoVector next_protos; - if (quic_enabled) - next_protos.push_back(kProtoQUIC1SPDY3); - if (spdy_enabled) { - next_protos.push_back(kProtoHTTP2); - next_protos.push_back(kProtoSPDY31); - } - next_protos.push_back(kProtoHTTP11); - return next_protos; -} - -NextProtoVector NextProtosSpdy31() { - NextProtoVector next_protos; - next_protos.push_back(kProtoQUIC1SPDY3); - next_protos.push_back(kProtoSPDY31); - next_protos.push_back(kProtoHTTP11); - return next_protos; -} - bool NextProtoIsSPDY(NextProto next_proto) { return next_proto >= kProtoSPDYMinimumVersion && next_proto <= kProtoSPDYMaximumVersion; diff --git a/chromium/net/socket/next_proto.h b/chromium/net/socket/next_proto.h index 3938d4424ad..734e0dd85fc 100644 --- a/chromium/net/socket/next_proto.h +++ b/chromium/net/socket/next_proto.h @@ -43,16 +43,6 @@ typedef std::vector<NextProto> NextProtoVector; // Convenience functions to create NextProtoVector. -// Default values, which are subject to change over time. -NET_EXPORT NextProtoVector NextProtosDefaults(); - -// Enable SPDY/3.1 and QUIC, but not HTTP/2. -NET_EXPORT NextProtoVector NextProtosSpdy31(); - -// Control SPDY/3.1 and HTTP/2 separately. -NET_EXPORT NextProtoVector NextProtosWithSpdyAndQuic(bool spdy_enabled, - bool quic_enabled); - // Returns true if |next_proto| is a version of SPDY or HTTP/2. bool NextProtoIsSPDY(NextProto next_proto); diff --git a/chromium/net/socket/nss_ssl_util.cc b/chromium/net/socket/nss_ssl_util.cc index ee587571639..6d0064d0b94 100644 --- a/chromium/net/socket/nss_ssl_util.cc +++ b/chromium/net/socket/nss_ssl_util.cc @@ -148,8 +148,9 @@ class NSSSSLInitSingleton { // we prefer AES-GCM, otherwise ChaCha20. The remainder of the cipher suite // preference is inheriented from NSS. */ static const uint16_t chacha_ciphers[] = { - TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305, - TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 0, + TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256, 0, }; static const uint16_t aes_gcm_ciphers[] = { TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, diff --git a/chromium/net/socket/sequenced_socket_data_unittest.cc b/chromium/net/socket/sequenced_socket_data_unittest.cc index e8df5546e74..c27fd615fd4 100644 --- a/chromium/net/socket/sequenced_socket_data_unittest.cc +++ b/chromium/net/socket/sequenced_socket_data_unittest.cc @@ -235,7 +235,6 @@ SequencedSocketDataTest::SequencedSocketDataTest() tcp_params_(new TransportSocketParams( endpoint_, false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), socket_pool_(10, 10, &socket_factory_), @@ -262,7 +261,8 @@ void SequencedSocketDataTest::Initialize(MockRead* reads, EXPECT_EQ(OK, connection_.Init( - endpoint_.ToString(), tcp_params_, LOWEST, CompletionCallback(), + endpoint_.ToString(), tcp_params_, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, CompletionCallback(), reinterpret_cast<TransportClientSocketPool*>(&socket_pool_), BoundNetLog())); sock_ = connection_.socket(); diff --git a/chromium/net/socket/server_socket.cc b/chromium/net/socket/server_socket.cc index 50722a91f32..f2c2383ab58 100644 --- a/chromium/net/socket/server_socket.cc +++ b/chromium/net/socket/server_socket.cc @@ -4,9 +4,9 @@ #include "net/socket/server_socket.h" +#include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_util.h" namespace net { @@ -19,12 +19,12 @@ ServerSocket::~ServerSocket() { int ServerSocket::ListenWithAddressAndPort(const std::string& address_string, uint16_t port, int backlog) { - IPAddressNumber address_number; - if (!ParseIPLiteralToNumber(address_string, &address_number)) { + IPAddress ip_address; + if (!ip_address.AssignFromIPLiteral(address_string)) { return ERR_ADDRESS_INVALID; } - return Listen(IPEndPoint(address_number, port), backlog); + return Listen(IPEndPoint(ip_address, port), backlog); } } // namespace net diff --git a/chromium/net/socket/server_socket.h b/chromium/net/socket/server_socket.h index 41894338e8a..a0794f16dee 100644 --- a/chromium/net/socket/server_socket.h +++ b/chromium/net/socket/server_socket.h @@ -30,8 +30,6 @@ class NET_EXPORT ServerSocket { // Binds the socket with address and port, and starts listening. It expects // a valid IPv4 or IPv6 address. Otherwise, it returns ERR_ADDRESS_INVALID. - // Subclasses may override this function if |address_string| is in a different - // format, for example, unix domain socket path. virtual int ListenWithAddressAndPort(const std::string& address_string, uint16_t port, int backlog); diff --git a/chromium/net/socket/socket_net_log_params.cc b/chromium/net/socket/socket_net_log_params.cc index 37be0a6a34b..347644ac06d 100644 --- a/chromium/net/socket/socket_net_log_params.cc +++ b/chromium/net/socket/socket_net_log_params.cc @@ -10,7 +10,6 @@ #include "base/values.h" #include "net/base/host_port_pair.h" #include "net/base/ip_endpoint.h" -#include "net/base/net_util.h" namespace net { diff --git a/chromium/net/socket/socket_posix.cc b/chromium/net/socket/socket_posix.cc index 96892891d09..ca8b04f09a4 100644 --- a/chromium/net/socket/socket_posix.cc +++ b/chromium/net/socket/socket_posix.cc @@ -13,6 +13,7 @@ #include "base/files/file_util.h" #include "base/logging.h" #include "base/posix/eintr_wrapper.h" +#include "base/trace_event/trace_event.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -348,6 +349,7 @@ void SocketPosix::DetachFromThread() { } void SocketPosix::OnFileCanReadWithoutBlocking(int fd) { + TRACE_EVENT0("net", "SocketPosix::OnFileCanReadWithoutBlocking"); DCHECK(!accept_callback_.is_null() || !read_callback_.is_null()); if (!accept_callback_.is_null()) { AcceptCompleted(); diff --git a/chromium/net/socket/socket_test_util.cc b/chromium/net/socket/socket_test_util.cc index 04f5151d4fe..f6b2e098feb 100644 --- a/chromium/net/socket/socket_test_util.cc +++ b/chromium/net/socket/socket_test_util.cc @@ -20,6 +20,7 @@ #include "net/base/address_family.h" #include "net/base/address_list.h" #include "net/base/auth.h" +#include "net/base/ip_address.h" #include "net/base/load_timing_info.h" #include "net/http/http_network_session.h" #include "net/http/http_request_headers.h" @@ -123,15 +124,11 @@ void DumpMockReadWrite(const MockReadWrite<type>& r) { } // namespace MockConnect::MockConnect() : mode(ASYNC), result(OK) { - IPAddressNumber ip; - CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); - peer_addr = IPEndPoint(ip, 0); + peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0); } MockConnect::MockConnect(IoMode io_mode, int r) : mode(io_mode), result(r) { - IPAddressNumber ip; - CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); - peer_addr = IPEndPoint(ip, 0); + peer_addr = IPEndPoint(IPAddress(192, 0, 2, 33), 0); } MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) : @@ -142,6 +139,8 @@ MockConnect::MockConnect(IoMode io_mode, int r, IPEndPoint addr) : MockConnect::~MockConnect() {} +void SocketDataProvider::OnEnableTCPFastOpenIfSupported() {} + bool SocketDataProvider::IsIdle() const { return true; } @@ -287,13 +286,17 @@ SSLSocketDataProvider::SSLSocketDataProvider(IoMode mode, int result) client_cert_sent(false), cert_request_info(NULL), channel_id_sent(false), - connection_status(0) { + connection_status(0), + token_binding_negotiated(false) { SSLConnectionStatusSetVersion(SSL_CONNECTION_VERSION_TLS1_2, &connection_status); // Set to TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305 SSLConnectionStatusSetCipherSuite(0xcc14, &connection_status); } +SSLSocketDataProvider::SSLSocketDataProvider( + const SSLSocketDataProvider& other) = default; + SSLSocketDataProvider::~SSLSocketDataProvider() { } @@ -311,6 +314,7 @@ SequencedSocketData::SequencedSocketData(MockRead* reads, read_state_(IDLE), write_state_(IDLE), busy_before_sync_reads_(false), + is_using_tcp_fast_open_(false), weak_factory_(this) { // Check that reads and writes have a contiguous set of sequence numbers // starting from 0 and working their way up, with no repeats and skipping @@ -489,6 +493,10 @@ bool SequencedSocketData::AllWriteDataConsumed() const { return helper_.AllWriteDataConsumed(); } +void SequencedSocketData::OnEnableTCPFastOpenIfSupported() { + is_using_tcp_fast_open_ = true; +} + bool SequencedSocketData::IsIdle() const { // If |busy_before_sync_reads_| is not set, always considered idle. If // no reads left, or the next operation is a write, also consider it idle. @@ -584,6 +592,10 @@ void SequencedSocketData::MaybePostReadCompleteTask() { read_state_ = COMPLETING; } +bool SequencedSocketData::IsUsingTCPFastOpen() const { + return is_using_tcp_fast_open_; +} + void SequencedSocketData::MaybePostWriteCompleteTask() { NET_TRACE(1, " ****** ") << " current: " << sequence_number_; // Only trigger the next write to complete if there is already a write pending @@ -616,6 +628,7 @@ void SequencedSocketData::Reset() { sequence_number_ = 0; read_state_ = IDLE; write_state_ = IDLE; + is_using_tcp_fast_open_ = false; weak_factory_.InvalidateWeakPtrs(); } @@ -744,15 +757,11 @@ scoped_ptr<SSLClientSocket> MockClientSocketFactory::CreateSSLClientSocket( void MockClientSocketFactory::ClearSSLSessionCache() { } -const char MockClientSocket::kTlsUnique[] = "MOCK_TLSUNIQ"; - MockClientSocket::MockClientSocket(const BoundNetLog& net_log) : connected_(false), net_log_(net_log), weak_factory_(this) { - IPAddressNumber ip; - CHECK(ParseIPLiteralToNumber("192.0.2.33", &ip)); - peer_addr_ = IPEndPoint(ip, 0); + peer_addr_ = IPEndPoint(IPAddress(192, 0, 2, 33), 0); } int MockClientSocket::SetReceiveBufferSize(int32_t size) { @@ -783,10 +792,7 @@ int MockClientSocket::GetPeerAddress(IPEndPoint* address) const { } int MockClientSocket::GetLocalAddress(IPEndPoint* address) const { - IPAddressNumber ip; - bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); - CHECK(rv); - *address = IPEndPoint(ip, 123); + *address = IPEndPoint(IPAddress(192, 0, 2, 33), 123); return OK; } @@ -816,12 +822,18 @@ int MockClientSocket::ExportKeyingMaterial(const base::StringPiece& label, return OK; } -int MockClientSocket::GetTLSUniqueChannelBinding(std::string* out) { - out->assign(MockClientSocket::kTlsUnique); - return OK; +ChannelIDService* MockClientSocket::GetChannelIDService() const { + NOTREACHED(); + return NULL; } -ChannelIDService* MockClientSocket::GetChannelIDService() const { +Error MockClientSocket::GetSignedEKMForTokenBinding(crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) { + NOTREACHED(); + return ERR_NOT_IMPLEMENTED; +} + +crypto::ECPrivateKey* MockClientSocket::GetChannelIDKey() const { NOTREACHED(); return NULL; } @@ -1013,8 +1025,10 @@ bool MockTCPClientSocket::WasEverUsed() const { return was_used_to_convey_data_; } -bool MockTCPClientSocket::UsingTCPFastOpen() const { - return false; +void MockTCPClientSocket::EnableTCPFastOpenIfSupported() { + EXPECT_FALSE(IsConnected()) << "Can't enable fast open after connect."; + + data_->OnEnableTCPFastOpenIfSupported(); } bool MockTCPClientSocket::WasNpnNegotiated() const { @@ -1185,10 +1199,6 @@ bool MockSSLClientSocket::WasEverUsed() const { return transport_->socket()->WasEverUsed(); } -bool MockSSLClientSocket::UsingTCPFastOpen() const { - return transport_->socket()->UsingTCPFastOpen(); -} - int MockSSLClientSocket::GetPeerAddress(IPEndPoint* address) const { return transport_->socket()->GetPeerAddress(address); } @@ -1199,6 +1209,8 @@ bool MockSSLClientSocket::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->client_cert_sent = data_->client_cert_sent; ssl_info->channel_id_sent = data_->channel_id_sent; ssl_info->connection_status = data_->connection_status; + ssl_info->token_binding_negotiated = data_->token_binding_negotiated; + ssl_info->token_binding_key_param = data_->token_binding_key_param; return true; } @@ -1224,6 +1236,13 @@ ChannelIDService* MockSSLClientSocket::GetChannelIDService() const { return data_->channel_id_service; } +Error MockSSLClientSocket::GetSignedEKMForTokenBinding( + crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) { + out->push_back('A'); + return OK; +} + void MockSSLClientSocket::OnReadComplete(const MockRead& data) { NOTIMPLEMENTED(); } @@ -1330,10 +1349,7 @@ int MockUDPClientSocket::GetPeerAddress(IPEndPoint* address) const { } int MockUDPClientSocket::GetLocalAddress(IPEndPoint* address) const { - IPAddressNumber ip; - bool rv = ParseIPLiteralToNumber("192.0.2.33", &ip); - CHECK(rv); - *address = IPEndPoint(ip, source_port_); + *address = IPEndPoint(IPAddress(192, 0, 2, 33), source_port_); return OK; } @@ -1341,30 +1357,41 @@ const BoundNetLog& MockUDPClientSocket::NetLog() const { return net_log_; } -int MockUDPClientSocket::BindToNetwork( - NetworkChangeNotifier::NetworkHandle network) { - network_ = network; - return OK; -} - -int MockUDPClientSocket::BindToDefaultNetwork() { - network_ = kDefaultNetworkForTests; - return OK; +int MockUDPClientSocket::Connect(const IPEndPoint& address) { + if (!data_) + return ERR_UNEXPECTED; + connected_ = true; + peer_addr_ = address; + return data_->connect_data().result; } -NetworkChangeNotifier::NetworkHandle MockUDPClientSocket::GetBoundNetwork() - const { - return network_; +int MockUDPClientSocket::ConnectUsingNetwork( + NetworkChangeNotifier::NetworkHandle network, + const IPEndPoint& address) { + DCHECK(!connected_); + if (!data_) + return ERR_UNEXPECTED; + network_ = network; + connected_ = true; + peer_addr_ = address; + return data_->connect_data().result; } -int MockUDPClientSocket::Connect(const IPEndPoint& address) { +int MockUDPClientSocket::ConnectUsingDefaultNetwork(const IPEndPoint& address) { + DCHECK(!connected_); if (!data_) return ERR_UNEXPECTED; + network_ = kDefaultNetworkForTests; connected_ = true; peer_addr_ = address; return data_->connect_data().result; } +NetworkChangeNotifier::NetworkHandle MockUDPClientSocket::GetBoundNetwork() + const { + return network_; +} + void MockUDPClientSocket::OnReadComplete(const MockRead& data) { if (!data_) return; @@ -1595,9 +1622,13 @@ MockTransportClientSocketPool::MockTransportClientSocketPool( MockTransportClientSocketPool::~MockTransportClientSocketPool() {} int MockTransportClientSocketPool::RequestSocket( - const std::string& group_name, const void* socket_params, - RequestPriority priority, ClientSocketHandle* handle, - const CompletionCallback& callback, const BoundNetLog& net_log) { + const std::string& group_name, + const void* socket_params, + RequestPriority priority, + RespectLimits respect_limits, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { last_request_priority_ = priority; scoped_ptr<StreamSocket> socket = client_socket_factory_->CreateTransportClientSocket( @@ -1640,12 +1671,16 @@ MockSOCKSClientSocketPool::MockSOCKSClientSocketPool( MockSOCKSClientSocketPool::~MockSOCKSClientSocketPool() {} -int MockSOCKSClientSocketPool::RequestSocket( - const std::string& group_name, const void* socket_params, - RequestPriority priority, ClientSocketHandle* handle, - const CompletionCallback& callback, const BoundNetLog& net_log) { - return transport_pool_->RequestSocket( - group_name, socket_params, priority, handle, callback, net_log); +int MockSOCKSClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + RespectLimits respect_limits, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { + return transport_pool_->RequestSocket(group_name, socket_params, priority, + respect_limits, handle, callback, + net_log); } void MockSOCKSClientSocketPool::CancelRequest( diff --git a/chromium/net/socket/socket_test_util.h b/chromium/net/socket/socket_test_util.h index a3c6df42149..dada520bd0d 100644 --- a/chromium/net/socket/socket_test_util.h +++ b/chromium/net/socket/socket_test_util.h @@ -212,6 +212,8 @@ class SocketDataProvider { virtual bool AllReadDataConsumed() const = 0; virtual bool AllWriteDataConsumed() const = 0; + virtual void OnEnableTCPFastOpenIfSupported(); + // Returns true if the request should be considered idle, for the purposes of // IsConnectedAndIdle. virtual bool IsIdle() const; @@ -347,6 +349,7 @@ class StaticSocketDataProvider : public SocketDataProvider { // to Connect(). struct SSLSocketDataProvider { SSLSocketDataProvider(IoMode mode, int result); + SSLSocketDataProvider(const SSLSocketDataProvider& other); ~SSLSocketDataProvider(); void SetNextProto(NextProto proto); @@ -361,6 +364,8 @@ struct SSLSocketDataProvider { bool channel_id_sent; ChannelIDService* channel_id_service; int connection_status; + bool token_binding_negotiated; + TokenBindingParam token_binding_key_param; }; // Uses the sequence_number field in the mock reads and writes to @@ -390,6 +395,7 @@ class SequencedSocketData : public SocketDataProvider { MockWriteResult OnWrite(const std::string& data) override; bool AllReadDataConsumed() const override; bool AllWriteDataConsumed() const override; + void OnEnableTCPFastOpenIfSupported() override; bool IsIdle() const override; // An ASYNC read event with a return value of ERR_IO_PENDING will cause the @@ -407,6 +413,8 @@ class SequencedSocketData : public SocketDataProvider { void Resume(); void RunUntilPaused(); + bool IsUsingTCPFastOpen() const; + // When true, IsConnectedAndIdle() will return false if the next event in the // sequence is a synchronous. Otherwise, the socket claims to be idle as // long as it's connected. Defaults to false. @@ -442,6 +450,7 @@ class SequencedSocketData : public SocketDataProvider { IoState write_state_; bool busy_before_sync_reads_; + bool is_using_tcp_fast_open_; // Used by RunUntilPaused. NULL at all other times. scoped_ptr<base::RunLoop> run_until_paused_run_loop_; @@ -535,9 +544,6 @@ class MockClientSocketFactory : public ClientSocketFactory { class MockClientSocket : public SSLClientSocket { public: - // Value returned by GetTLSUniqueChannelBinding(). - static const char kTlsUnique[]; - // The BoundNetLog is needed to test LoadTimingInfo, which uses NetLog IDs as // unique socket IDs. explicit MockClientSocket(const BoundNetLog& net_log); @@ -574,9 +580,11 @@ class MockClientSocket : public SSLClientSocket { const base::StringPiece& context, unsigned char* out, unsigned int outlen) override; - int GetTLSUniqueChannelBinding(std::string* out) override; NextProtoStatus GetNextProto(std::string* proto) const override; ChannelIDService* GetChannelIDService() const override; + Error GetSignedEKMForTokenBinding(crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) override; + crypto::ECPrivateKey* GetChannelIDKey() const override; SSLFailureState GetSSLFailureState() const override; protected: @@ -622,7 +630,7 @@ class MockTCPClientSocket : public MockClientSocket, public AsyncSocket { bool IsConnectedAndIdle() const override; int GetPeerAddress(IPEndPoint* address) const override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; + void EnableTCPFastOpenIfSupported() override; bool WasNpnNegotiated() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; void GetConnectionAttempts(ConnectionAttempts* out) const override; @@ -685,12 +693,13 @@ class MockSSLClientSocket : public MockClientSocket, public AsyncSocket { bool IsConnected() const override; bool IsConnectedAndIdle() const override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; int GetPeerAddress(IPEndPoint* address) const override; bool GetSSLInfo(SSLInfo* ssl_info) override; // SSLClientSocket implementation. void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; + Error GetSignedEKMForTokenBinding(crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) override; NextProtoStatus GetNextProto(std::string* proto) const override; // This MockSocket does not implement the manual async IO feature. @@ -737,10 +746,11 @@ class MockUDPClientSocket : public DatagramClientSocket, public AsyncSocket { const BoundNetLog& NetLog() const override; // DatagramClientSocket implementation. - int BindToNetwork(NetworkChangeNotifier::NetworkHandle network) override; - int BindToDefaultNetwork() override; - NetworkChangeNotifier::NetworkHandle GetBoundNetwork() const override; int Connect(const IPEndPoint& address) override; + int ConnectUsingNetwork(NetworkChangeNotifier::NetworkHandle network, + const IPEndPoint& address) override; + int ConnectUsingDefaultNetwork(const IPEndPoint& address) override; + NetworkChangeNotifier::NetworkHandle GetBoundNetwork() const override; // AsyncSocket implementation. void OnReadComplete(const MockRead& data) override; @@ -824,17 +834,15 @@ class ClientSocketPoolTest { PoolType* socket_pool, const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<typename PoolType::SocketParams>& socket_params) { DCHECK(socket_pool); TestSocketRequest* request( new TestSocketRequest(&request_order_, &completion_count_)); requests_.push_back(make_scoped_ptr(request)); - int rv = request->handle()->Init(group_name, - socket_params, - priority, - request->callback(), - socket_pool, - BoundNetLog()); + int rv = request->handle()->Init(group_name, socket_params, priority, + respect_limits, request->callback(), + socket_pool, BoundNetLog()); if (rv != ERR_IO_PENDING) request_order_.push_back(request); return rv; @@ -918,6 +926,7 @@ class MockTransportClientSocketPool : public TransportClientSocketPool { int RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override; @@ -950,6 +959,7 @@ class MockSOCKSClientSocketPool : public SOCKSClientSocketPool { int RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override; diff --git a/chromium/net/socket/socks5_client_socket.cc b/chromium/net/socket/socks5_client_socket.cc index 20baaf2cbe4..8c172b0b514 100644 --- a/chromium/net/socket/socks5_client_socket.cc +++ b/chromium/net/socket/socks5_client_socket.cc @@ -13,7 +13,6 @@ #include "base/sys_byteorder.h" #include "base/trace_event/trace_event.h" #include "net/base/io_buffer.h" -#include "net/base/net_util.h" #include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" @@ -114,14 +113,6 @@ bool SOCKS5ClientSocket::WasEverUsed() const { return was_ever_used_; } -bool SOCKS5ClientSocket::UsingTCPFastOpen() const { - if (transport_.get() && transport_->socket()) { - return transport_->socket()->UsingTCPFastOpen(); - } - NOTREACHED(); - return false; -} - bool SOCKS5ClientSocket::WasNpnNegotiated() const { if (transport_.get() && transport_->socket()) { return transport_->socket()->WasNpnNegotiated(); diff --git a/chromium/net/socket/socks5_client_socket.h b/chromium/net/socket/socks5_client_socket.h index a5438b643ad..4d3d6dbcf1d 100644 --- a/chromium/net/socket/socks5_client_socket.h +++ b/chromium/net/socket/socks5_client_socket.h @@ -53,7 +53,6 @@ class NET_EXPORT_PRIVATE SOCKS5ClientSocket : public StreamSocket { void SetSubresourceSpeculation() override; void SetOmniboxSpeculation() override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; diff --git a/chromium/net/socket/socks_client_socket.cc b/chromium/net/socket/socks_client_socket.cc index 69805bb6f3a..0f6e925596e 100644 --- a/chromium/net/socket/socks_client_socket.cc +++ b/chromium/net/socket/socks_client_socket.cc @@ -11,7 +11,6 @@ #include "base/compiler_specific.h" #include "base/sys_byteorder.h" #include "net/base/io_buffer.h" -#include "net/base/net_util.h" #include "net/log/net_log.h" #include "net/socket/client_socket_handle.h" @@ -143,14 +142,6 @@ bool SOCKSClientSocket::WasEverUsed() const { return was_ever_used_; } -bool SOCKSClientSocket::UsingTCPFastOpen() const { - if (transport_.get() && transport_->socket()) { - return transport_->socket()->UsingTCPFastOpen(); - } - NOTREACHED(); - return false; -} - bool SOCKSClientSocket::WasNpnNegotiated() const { if (transport_.get() && transport_->socket()) { return transport_->socket()->WasNpnNegotiated(); @@ -335,7 +326,8 @@ const std::string SOCKSClientSocket::BuildHandshakeWriteBuffer() const { // failing the connect attempt. CHECK_EQ(ADDRESS_FAMILY_IPV4, endpoint.GetFamily()); CHECK_LE(endpoint.address().size(), sizeof(request.ip)); - memcpy(&request.ip, &endpoint.address()[0], endpoint.address().size()); + memcpy(&request.ip, &endpoint.address().bytes()[0], + endpoint.address().size()); DVLOG(1) << "Resolved Host is : " << endpoint.ToStringWithoutPort(); diff --git a/chromium/net/socket/socks_client_socket.h b/chromium/net/socket/socks_client_socket.h index 8d5accd4200..c01156050a4 100644 --- a/chromium/net/socket/socks_client_socket.h +++ b/chromium/net/socket/socks_client_socket.h @@ -51,7 +51,6 @@ class NET_EXPORT_PRIVATE SOCKSClientSocket : public StreamSocket { void SetSubresourceSpeculation() override; void SetOmniboxSpeculation() override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; diff --git a/chromium/net/socket/socks_client_socket_pool.cc b/chromium/net/socket/socks_client_socket_pool.cc index 543c2c4ad72..d7f97f3ec8b 100644 --- a/chromium/net/socket/socks_client_socket_pool.cc +++ b/chromium/net/socket/socks_client_socket_pool.cc @@ -27,10 +27,6 @@ SOCKSSocketParams::SOCKSSocketParams( : transport_params_(proxy_server), destination_(host_port_pair), socks_v5_(socks_v5) { - if (transport_params_.get()) - ignore_limits_ = transport_params_->ignore_limits(); - else - ignore_limits_ = false; } SOCKSSocketParams::~SOCKSSocketParams() {} @@ -42,20 +38,24 @@ static const int kSOCKSConnectJobTimeoutInSeconds = 30; SOCKSConnectJob::SOCKSConnectJob( const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<SOCKSSocketParams>& socks_params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, HostResolver* host_resolver, Delegate* delegate, NetLog* net_log) - : ConnectJob(group_name, timeout_duration, priority, delegate, + : ConnectJob(group_name, + timeout_duration, + priority, + respect_limits, + delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), socks_params_(socks_params), transport_pool_(transport_pool), resolver_(host_resolver), - callback_(base::Bind(&SOCKSConnectJob::OnIOComplete, - base::Unretained(this))) { -} + callback_( + base::Bind(&SOCKSConnectJob::OnIOComplete, base::Unretained(this))) {} SOCKSConnectJob::~SOCKSConnectJob() { // We don't worry about cancelling the tcp socket since the destructor in @@ -118,12 +118,9 @@ int SOCKSConnectJob::DoLoop(int result) { int SOCKSConnectJob::DoTransportConnect() { next_state_ = STATE_TRANSPORT_CONNECT_COMPLETE; transport_socket_handle_.reset(new ClientSocketHandle()); - return transport_socket_handle_->Init(group_name(), - socks_params_->transport_params(), - priority(), - callback_, - transport_pool_, - net_log()); + return transport_socket_handle_->Init( + group_name(), socks_params_->transport_params(), priority(), + respect_limits(), callback_, transport_pool_, net_log()); } int SOCKSConnectJob::DoTransportConnectComplete(int result) { @@ -174,14 +171,10 @@ SOCKSClientSocketPool::SOCKSConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return scoped_ptr<ConnectJob>(new SOCKSConnectJob(group_name, - request.priority(), - request.params(), - ConnectionTimeout(), - transport_pool_, - host_resolver_, - delegate, - net_log_)); + return scoped_ptr<ConnectJob>(new SOCKSConnectJob( + group_name, request.priority(), request.respect_limits(), + request.params(), ConnectionTimeout(), transport_pool_, host_resolver_, + delegate, net_log_)); } base::TimeDelta @@ -212,15 +205,18 @@ SOCKSClientSocketPool::SOCKSClientSocketPool( SOCKSClientSocketPool::~SOCKSClientSocketPool() { } -int SOCKSClientSocketPool::RequestSocket( - const std::string& group_name, const void* socket_params, - RequestPriority priority, ClientSocketHandle* handle, - const CompletionCallback& callback, const BoundNetLog& net_log) { +int SOCKSClientSocketPool::RequestSocket(const std::string& group_name, + const void* socket_params, + RequestPriority priority, + RespectLimits respect_limits, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { const scoped_refptr<SOCKSSocketParams>* casted_socket_params = static_cast<const scoped_refptr<SOCKSSocketParams>*>(socket_params); return base_.RequestSocket(group_name, *casted_socket_params, priority, - handle, callback, net_log); + respect_limits, handle, callback, net_log); } void SOCKSClientSocketPool::RequestSockets( diff --git a/chromium/net/socket/socks_client_socket_pool.h b/chromium/net/socket/socks_client_socket_pool.h index 8aaf726c353..66d0e53794e 100644 --- a/chromium/net/socket/socks_client_socket_pool.h +++ b/chromium/net/socket/socks_client_socket_pool.h @@ -34,7 +34,6 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams } const HostResolver::RequestInfo& destination() const { return destination_; } bool is_socks_v5() const { return socks_v5_; } - bool ignore_limits() const { return ignore_limits_; } private: friend class base::RefCounted<SOCKSSocketParams>; @@ -45,7 +44,6 @@ class NET_EXPORT_PRIVATE SOCKSSocketParams // This is the HTTP destination. HostResolver::RequestInfo destination_; const bool socks_v5_; - bool ignore_limits_; DISALLOW_COPY_AND_ASSIGN(SOCKSSocketParams); }; @@ -56,6 +54,7 @@ class SOCKSConnectJob : public ConnectJob { public: SOCKSConnectJob(const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<SOCKSSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -121,6 +120,7 @@ class NET_EXPORT_PRIVATE SOCKSClientSocketPool int RequestSocket(const std::string& group_name, const void* connect_params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override; diff --git a/chromium/net/socket/socks_client_socket_pool_unittest.cc b/chromium/net/socket/socks_client_socket_pool_unittest.cc index 587f1fa5fa5..5d16ce4cdb5 100644 --- a/chromium/net/socket/socks_client_socket_pool_unittest.cc +++ b/chromium/net/socket/socks_client_socket_pool_unittest.cc @@ -43,7 +43,7 @@ void TestLoadTimingInfo(const ClientSocketHandle& handle) { scoped_refptr<TransportSocketParams> CreateProxyHostParams() { return new TransportSocketParams( - HostPortPair("proxy", 80), false, false, OnHostResolutionCallback(), + HostPortPair("proxy", 80), false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT); } @@ -102,7 +102,8 @@ class SOCKSClientSocketPoolTest : public testing::Test { int StartRequestV5(const std::string& group_name, RequestPriority priority) { return test_base_.StartRequestUsingPool( - &pool_, group_name, priority, CreateSOCKSv5Params()); + &pool_, group_name, priority, ClientSocketPool::RespectLimits::ENABLED, + CreateSOCKSv5Params()); } int GetOrderOfRequest(size_t index) const { @@ -127,8 +128,9 @@ TEST_F(SOCKSClientSocketPoolTest, Simple) { transport_client_socket_factory_.AddSocketDataProvider(data.data_provider()); ClientSocketHandle handle; - int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), - &pool_, BoundNetLog()); + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, + ClientSocketPool::RespectLimits::ENABLED, + CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); @@ -146,9 +148,9 @@ TEST_F(SOCKSClientSocketPoolTest, SetSocketRequestPriorityOnInit) { data.data_provider()); ClientSocketHandle handle; - EXPECT_EQ(OK, - handle.Init("a", CreateSOCKSv5Params(), priority, - CompletionCallback(), &pool_, BoundNetLog())); + EXPECT_EQ(OK, handle.Init("a", CreateSOCKSv5Params(), priority, + ClientSocketPool::RespectLimits::ENABLED, + CompletionCallback(), &pool_, BoundNetLog())); EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); handle.socket()->Disconnect(); } @@ -167,6 +169,7 @@ TEST_F(SOCKSClientSocketPoolTest, SetResolvePriorityOnInit) { ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, handle.Init("a", CreateSOCKSv4Params(), priority, + ClientSocketPool::RespectLimits::ENABLED, CompletionCallback(), &pool_, BoundNetLog())); EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); EXPECT_EQ(priority, host_resolver_.last_request_priority()); @@ -180,8 +183,9 @@ TEST_F(SOCKSClientSocketPoolTest, Async) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), - &pool_, BoundNetLog()); + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -199,8 +203,9 @@ TEST_F(SOCKSClientSocketPoolTest, TransportConnectError) { transport_client_socket_factory_.AddSocketDataProvider(&socket_data); ClientSocketHandle handle; - int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), - &pool_, BoundNetLog()); + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, + ClientSocketPool::RespectLimits::ENABLED, + CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -213,8 +218,9 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncTransportConnectError) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), - &pool_, BoundNetLog()); + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -235,8 +241,9 @@ TEST_F(SOCKSClientSocketPoolTest, SOCKSConnectError) { ClientSocketHandle handle; EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, CompletionCallback(), - &pool_, BoundNetLog()); + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, + ClientSocketPool::RespectLimits::ENABLED, + CompletionCallback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_SOCKS_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -255,8 +262,9 @@ TEST_F(SOCKSClientSocketPoolTest, AsyncSOCKSConnectError) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(0, transport_socket_pool_.release_count()); - int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, callback.callback(), - &pool_, BoundNetLog()); + int rv = handle.Init("a", CreateSOCKSv5Params(), LOW, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); diff --git a/chromium/net/socket/ssl_client_socket.h b/chromium/net/socket/ssl_client_socket.h index 3a6aa94258f..10affda136a 100644 --- a/chromium/net/socket/ssl_client_socket.h +++ b/chromium/net/socket/ssl_client_socket.h @@ -22,6 +22,10 @@ class FilePath; class SequencedTaskRunner; } +namespace crypto { +class ECPrivateKey; +} + namespace net { class CTPolicyEnforcer; @@ -144,6 +148,16 @@ class NET_EXPORT SSLClientSocket : public SSLSocket { // channel ids are not supported. virtual ChannelIDService* GetChannelIDService() const = 0; + // Signs the EKM value for Token Binding with |*key| and puts it in |*out|. + // Returns a net error code. + virtual Error GetSignedEKMForTokenBinding(crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) = 0; + + // This method is only for debugging crbug.com/548423 and will be removed when + // that bug is closed. This returns the channel ID key that was used when + // establishing the connection (or NULL if no channel ID was used). + virtual crypto::ECPrivateKey* GetChannelIDKey() const = 0; + // Returns the state of the handshake when it failed, or |SSL_FAILURE_NONE| if // the handshake succeeded. This is used to classify causes of the TLS version // fallback. diff --git a/chromium/net/socket/ssl_client_socket_nss.cc b/chromium/net/socket/ssl_client_socket_nss.cc index b15d76174aa..e3d0a36bde4 100644 --- a/chromium/net/socket/ssl_client_socket_nss.cc +++ b/chromium/net/socket/ssl_client_socket_nss.cc @@ -95,6 +95,7 @@ #include "net/cert/cert_verifier.h" #include "net/cert/ct_ev_whitelist.h" #include "net/cert/ct_policy_enforcer.h" +#include "net/cert/ct_policy_status.h" #include "net/cert/ct_verifier.h" #include "net/cert/ct_verify_result.h" #include "net/cert/scoped_nss_types.h" @@ -520,6 +521,10 @@ class SSLClientSocketNSS::Core : public base::RefCountedThreadSafe<Core> { // verified, and may not be called within an NSS callback. void CacheSessionIfNecessary(); + crypto::ECPrivateKey* GetChannelIDKey() const { + return channel_id_key_.get(); + } + private: friend class base::RefCountedThreadSafe<Core>; ~Core(); @@ -964,9 +969,8 @@ int SSLClientSocketNSS::Core::Read(IOBuffer* buf, int buf_len, nss_waiting_read_ = true; bool posted = nss_task_runner_->PostTask( - FROM_HERE, - base::Bind(IgnoreResult(&Core::Read), this, make_scoped_refptr(buf), - buf_len, callback)); + FROM_HERE, base::Bind(IgnoreResult(&Core::Read), this, + base::RetainedRef(buf), buf_len, callback)); if (!posted) { nss_is_closed_ = true; nss_waiting_read_ = false; @@ -1021,9 +1025,8 @@ int SSLClientSocketNSS::Core::Write(IOBuffer* buf, int buf_len, nss_waiting_write_ = true; bool posted = nss_task_runner_->PostTask( - FROM_HERE, - base::Bind(IgnoreResult(&Core::Write), this, make_scoped_refptr(buf), - buf_len, callback)); + FROM_HERE, base::Bind(IgnoreResult(&Core::Write), this, + base::RetainedRef(buf), buf_len, callback)); if (!posted) { nss_is_closed_ = true; nss_waiting_write_ = false; @@ -1531,11 +1534,10 @@ int SSLClientSocketNSS::Core::DoPayloadRead() { pending_read_nss_error_ = 0; if (rv == 0) { - PostOrRunCallback( - FROM_HERE, - base::Bind(&LogByteTransferEvent, weak_net_log_, - NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, - scoped_refptr<IOBuffer>(user_read_buf_))); + PostOrRunCallback(FROM_HERE, + base::Bind(&LogByteTransferEvent, weak_net_log_, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, + base::RetainedRef(user_read_buf_))); } else { PostOrRunCallback( FROM_HERE, @@ -1616,11 +1618,10 @@ int SSLClientSocketNSS::Core::DoPayloadRead() { DCHECK_NE(ERR_IO_PENDING, pending_read_result_); if (rv >= 0) { - PostOrRunCallback( - FROM_HERE, - base::Bind(&LogByteTransferEvent, weak_net_log_, - NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, - scoped_refptr<IOBuffer>(user_read_buf_))); + PostOrRunCallback(FROM_HERE, + base::Bind(&LogByteTransferEvent, weak_net_log_, + NetLog::TYPE_SSL_SOCKET_BYTES_RECEIVED, rv, + base::RetainedRef(user_read_buf_))); } else if (rv != ERR_IO_PENDING) { PostOrRunCallback( FROM_HERE, @@ -1649,11 +1650,10 @@ int SSLClientSocketNSS::Core::DoPayloadWrite() { base::Bind(&Core::OnNSSBufferUpdated, this, new_amount_in_read_buffer)); } if (rv >= 0) { - PostOrRunCallback( - FROM_HERE, - base::Bind(&LogByteTransferEvent, weak_net_log_, - NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, - scoped_refptr<IOBuffer>(user_write_buf_))); + PostOrRunCallback(FROM_HERE, + base::Bind(&LogByteTransferEvent, weak_net_log_, + NetLog::TYPE_SSL_SOCKET_BYTES_SENT, rv, + base::RetainedRef(user_write_buf_))); return rv; } PRErrorCode prerr = PR_GetError(); @@ -1720,9 +1720,8 @@ int SSLClientSocketNSS::Core::BufferRecv() { rv = DoBufferRecv(read_buffer.get(), nb); } else { bool posted = network_task_runner_->PostTask( - FROM_HERE, - base::Bind(IgnoreResult(&Core::DoBufferRecv), this, read_buffer, - nb)); + FROM_HERE, base::Bind(IgnoreResult(&Core::DoBufferRecv), this, + base::RetainedRef(read_buffer), nb)); rv = posted ? ERR_IO_PENDING : ERR_ABORTED; } @@ -1770,9 +1769,8 @@ int SSLClientSocketNSS::Core::BufferSend() { rv = DoBufferSend(send_buffer.get(), len); } else { bool posted = network_task_runner_->PostTask( - FROM_HERE, - base::Bind(IgnoreResult(&Core::DoBufferSend), this, send_buffer, - len)); + FROM_HERE, base::Bind(IgnoreResult(&Core::DoBufferSend), this, + base::RetainedRef(send_buffer), len)); rv = posted ? ERR_IO_PENDING : ERR_ABORTED; } @@ -1978,7 +1976,7 @@ void SSLClientSocketNSS::Core::UpdateServerCert() { // own a reference to the certificate. NetLog::ParametersCallback net_log_callback = base::Bind(&NetLogX509CertificateCallback, - nss_handshake_state_.server_cert); + base::RetainedRef(nss_handshake_state_.server_cert)); PostOrRunCallback( FROM_HERE, base::Bind(&AddLogEventWithCallback, weak_net_log_, @@ -2159,12 +2157,12 @@ int SSLClientSocketNSS::Core::DoBufferRecv(IOBuffer* read_buffer, int len) { int rv = transport_->socket()->Read( read_buffer, len, base::Bind(&Core::BufferRecvComplete, base::Unretained(this), - scoped_refptr<IOBuffer>(read_buffer))); + base::RetainedRef(read_buffer))); if (!OnNSSTaskRunner() && rv != ERR_IO_PENDING) { - nss_task_runner_->PostTask( - FROM_HERE, base::Bind(&Core::BufferRecvComplete, this, - scoped_refptr<IOBuffer>(read_buffer), rv)); + nss_task_runner_->PostTask(FROM_HERE, + base::Bind(&Core::BufferRecvComplete, this, + base::RetainedRef(read_buffer), rv)); return rv; } @@ -2301,7 +2299,7 @@ void SSLClientSocketNSS::Core::BufferRecvComplete( nss_task_runner_->PostTask( FROM_HERE, base::Bind(&Core::BufferRecvComplete, this, - scoped_refptr<IOBuffer>(read_buffer), result)); + base::RetainedRef(read_buffer), result)); return; } @@ -2410,7 +2408,7 @@ bool SSLClientSocketNSS::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->cert = server_cert_verify_result_.verified_cert; ssl_info->unverified_cert = core_->state().server_cert; - AddSCTInfoToSSLInfo(ssl_info); + AddCTInfoToSSLInfo(ssl_info); ssl_info->connection_status = core_->state().ssl_connection_status; @@ -2480,22 +2478,6 @@ int SSLClientSocketNSS::ExportKeyingMaterial(const base::StringPiece& label, return OK; } -int SSLClientSocketNSS::GetTLSUniqueChannelBinding(std::string* out) { - if (!IsConnected()) - return ERR_SOCKET_NOT_CONNECTED; - unsigned char buf[64]; - unsigned int len; - SECStatus result = SSL_GetChannelBinding(nss_fd_, - SSL_CHANNEL_BINDING_TLS_UNIQUE, - buf, &len, arraysize(buf)); - if (result != SECSuccess) { - LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", ""); - return MapNSSError(PORT_GetError()); - } - out->assign(reinterpret_cast<char*>(buf), len); - return OK; -} - SSLClientSocket::NextProtoStatus SSLClientSocketNSS::GetNextProto( std::string* proto) const { *proto = core_->state().next_proto; @@ -2631,14 +2613,6 @@ bool SSLClientSocketNSS::WasEverUsed() const { return core_->WasEverUsed(); } -bool SSLClientSocketNSS::UsingTCPFastOpen() const { - if (transport_.get() && transport_->socket()) { - return transport_->socket()->UsingTCPFastOpen(); - } - NOTREACHED(); - return false; -} - int SSLClientSocketNSS::Read(IOBuffer* buf, int buf_len, const CompletionCallback& callback) { DCHECK(core_.get()); @@ -2678,7 +2652,7 @@ int SSLClientSocketNSS::Init() { EnsureNSSSSLInit(); if (!NSS_IsInitialized()) return ERR_UNEXPECTED; -#if defined(USE_NSS_CERTS) || defined(OS_IOS) +#if defined(USE_NSS_VERIFIER) if (ssl_config_.cert_io_enabled) { // We must call EnsureNSSHttpIOInit() here, on the IO thread, to get the IO // loop by MessageLoopForIO::current(). @@ -3126,22 +3100,38 @@ void SSLClientSocketNSS::VerifyCT() { // TODO(ekasper): wipe stapled_ocsp_response and sct_list_from_tls_extension // from the state after verification is complete, to conserve memory. - if (policy_enforcer_ && - (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV)) { - scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = - SSLConfigService::GetEVCertsWhitelist(); - if (!policy_enforcer_->DoesConformToCTEVPolicy( - server_cert_verify_result_.verified_cert.get(), ev_whitelist.get(), - ct_verify_result_, net_log_)) { - // TODO(eranm): Log via the BoundNetLog, see crbug.com/437766 - VLOG(1) << "EV certificate for " - << server_cert_verify_result_.verified_cert->subject() - .GetDisplayName() - << " does not conform to CT policy, removing EV status."; - server_cert_verify_result_.cert_status |= - CERT_STATUS_CT_COMPLIANCE_FAILED; - server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + ct_verify_result_.ct_policies_applied = (policy_enforcer_ != nullptr); + ct_verify_result_.ev_policy_compliance = + ct::EVPolicyCompliance::EV_POLICY_DOES_NOT_APPLY; + if (policy_enforcer_) { + if ((server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV)) { + scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = + SSLConfigService::GetEVCertsWhitelist(); + ct::EVPolicyCompliance ev_policy_compliance = + policy_enforcer_->DoesConformToCTEVPolicy( + server_cert_verify_result_.verified_cert.get(), + ev_whitelist.get(), ct_verify_result_.verified_scts, net_log_); + ct_verify_result_.ev_policy_compliance = ev_policy_compliance; + if (ev_policy_compliance != + ct::EVPolicyCompliance::EV_POLICY_DOES_NOT_APPLY && + ev_policy_compliance != + ct::EVPolicyCompliance::EV_POLICY_COMPLIES_VIA_WHITELIST && + ev_policy_compliance != + ct::EVPolicyCompliance::EV_POLICY_COMPLIES_VIA_SCTS) { + // TODO(eranm): Log via the BoundNetLog, see crbug.com/437766 + VLOG(1) << "EV certificate for " + << server_cert_verify_result_.verified_cert->subject() + .GetDisplayName() + << " does not conform to CT policy, removing EV status."; + server_cert_verify_result_.cert_status |= + CERT_STATUS_CT_COMPLIANCE_FAILED; + server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + } } + ct_verify_result_.cert_policy_compliance = + policy_enforcer_->DoesConformToCertPolicy( + server_cert_verify_result_.verified_cert.get(), + ct_verify_result_.verified_scts, net_log_); } } @@ -3158,8 +3148,8 @@ bool SSLClientSocketNSS::CalledOnValidThread() const { return valid_thread_id_ == base::PlatformThread::CurrentId(); } -void SSLClientSocketNSS::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { - ssl_info->UpdateSignedCertificateTimestamps(ct_verify_result_); +void SSLClientSocketNSS::AddCTInfoToSSLInfo(SSLInfo* ssl_info) const { + ssl_info->UpdateCertificateTransparencyInfo(ct_verify_result_); } // static @@ -3179,6 +3169,17 @@ ChannelIDService* SSLClientSocketNSS::GetChannelIDService() const { return channel_id_service_; } +Error SSLClientSocketNSS::GetSignedEKMForTokenBinding( + crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) { + NOTREACHED(); + return ERR_NOT_IMPLEMENTED; +} + +crypto::ECPrivateKey* SSLClientSocketNSS::GetChannelIDKey() const { + return core_->GetChannelIDKey(); +} + SSLFailureState SSLClientSocketNSS::GetSSLFailureState() const { if (completed_handshake_) return SSL_FAILURE_NONE; diff --git a/chromium/net/socket/ssl_client_socket_nss.h b/chromium/net/socket/ssl_client_socket_nss.h index 366df1c4ee7..7073290a791 100644 --- a/chromium/net/socket/ssl_client_socket_nss.h +++ b/chromium/net/socket/ssl_client_socket_nss.h @@ -67,7 +67,6 @@ class SSLClientSocketNSS : public SSLClientSocket { const base::StringPiece& context, unsigned char* out, unsigned int outlen) override; - int GetTLSUniqueChannelBinding(std::string* out) override; // StreamSocket implementation. int Connect(const CompletionCallback& callback) override; @@ -80,7 +79,6 @@ class SSLClientSocketNSS : public SSLClientSocket { void SetSubresourceSpeculation() override; void SetOmniboxSpeculation() override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; void GetConnectionAttempts(ConnectionAttempts* out) const override; void ClearConnectionAttempts() override {} @@ -99,6 +97,9 @@ class SSLClientSocketNSS : public SSLClientSocket { // SSLClientSocket implementation. ChannelIDService* GetChannelIDService() const override; + Error GetSignedEKMForTokenBinding(crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) override; + crypto::ECPrivateKey* GetChannelIDKey() const override; SSLFailureState GetSSLFailureState() const override; private: @@ -144,7 +145,7 @@ class SSLClientSocketNSS : public SSLClientSocket { // vetor representing a particular verification state, this method associates // each of the SCTs with the corresponding SCTVerifyStatus as it adds it to // the |ssl_info|.signed_certificate_timestamps list. - void AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const; + void AddCTInfoToSSLInfo(SSLInfo* ssl_info) const; // Move last protocol to first place: SSLConfig::next_protos has protocols in // decreasing order of preference with NPN fallback protocol at the end, but diff --git a/chromium/net/socket/ssl_client_socket_openssl.cc b/chromium/net/socket/ssl_client_socket_openssl.cc index 82ced7cbfea..cfad9deeba2 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.cc +++ b/chromium/net/socket/ssl_client_socket_openssl.cc @@ -26,18 +26,22 @@ #include "base/metrics/histogram_macros.h" #include "base/metrics/sparse_histogram.h" #include "base/profiler/scoped_tracker.h" +#include "base/strings/string_number_conversions.h" #include "base/strings/string_piece.h" #include "base/synchronization/lock.h" #include "base/threading/thread_local.h" +#include "base/trace_event/trace_event.h" #include "base/values.h" +#include "crypto/auto_cbb.h" #include "crypto/ec_private_key.h" #include "crypto/openssl_util.h" #include "crypto/scoped_openssl_types.h" -#include "net/base/ip_address_number.h" +#include "net/base/ip_address.h" #include "net/base/net_errors.h" #include "net/cert/cert_verifier.h" #include "net/cert/ct_ev_whitelist.h" #include "net/cert/ct_policy_enforcer.h" +#include "net/cert/ct_policy_status.h" #include "net/cert/ct_verifier.h" #include "net/cert/x509_certificate_net_log_param.h" #include "net/cert/x509_util_openssl.h" @@ -49,16 +53,13 @@ #include "net/ssl/ssl_failure_state.h" #include "net/ssl/ssl_info.h" #include "net/ssl/ssl_private_key.h" - -#if defined(OS_WIN) -#include "base/win/windows_version.h" -#endif +#include "net/ssl/token_binding.h" #if !defined(OS_NACL) #include "net/ssl/ssl_key_logger.h" #endif -#if defined(USE_NSS_CERTS) || defined(OS_IOS) +#if defined(USE_NSS_VERIFIER) #include "net/cert_net/nss_ocsp.h" #endif @@ -87,71 +88,13 @@ const char kDefaultSupportedNPNProtocol[] = "http/1.1"; const int KDefaultOpenSSLBufferSize = 17 * 1024; // TLS extension number use for Token Binding. -const unsigned int kTbExtNum = 30033; +const unsigned int kTbExtNum = 24; // Token Binding ProtocolVersions supported. const uint8_t kTbProtocolVersionMajor = 0; -const uint8_t kTbProtocolVersionMinor = 3; +const uint8_t kTbProtocolVersionMinor = 5; const uint8_t kTbMinProtocolVersionMajor = 0; -const uint8_t kTbMinProtocolVersionMinor = 2; - -void FreeX509Stack(STACK_OF(X509)* ptr) { - sk_X509_pop_free(ptr, X509_free); -} - -using ScopedX509Stack = crypto::ScopedOpenSSL<STACK_OF(X509), FreeX509Stack>; - -// Used for encoding the |connection_status| field of an SSLInfo object. -int EncodeSSLConnectionStatus(uint16_t cipher_suite, - int compression, - int version) { - return cipher_suite | - ((compression & SSL_CONNECTION_COMPRESSION_MASK) << - SSL_CONNECTION_COMPRESSION_SHIFT) | - ((version & SSL_CONNECTION_VERSION_MASK) << - SSL_CONNECTION_VERSION_SHIFT); -} - -// Returns the net SSL version number (see ssl_connection_status_flags.h) for -// this SSL connection. -int GetNetSSLVersion(SSL* ssl) { - switch (SSL_version(ssl)) { - case TLS1_VERSION: - return SSL_CONNECTION_VERSION_TLS1; - case TLS1_1_VERSION: - return SSL_CONNECTION_VERSION_TLS1_1; - case TLS1_2_VERSION: - return SSL_CONNECTION_VERSION_TLS1_2; - default: - NOTREACHED(); - return SSL_CONNECTION_VERSION_UNKNOWN; - } -} - -ScopedX509 OSCertHandleToOpenSSL( - X509Certificate::OSCertHandle os_handle) { -#if defined(USE_OPENSSL_CERTS) - return ScopedX509(X509Certificate::DupOSCertHandle(os_handle)); -#else // !defined(USE_OPENSSL_CERTS) - std::string der_encoded; - if (!X509Certificate::GetDEREncoded(os_handle, &der_encoded)) - return ScopedX509(); - const uint8_t* bytes = reinterpret_cast<const uint8_t*>(der_encoded.data()); - return ScopedX509(d2i_X509(NULL, &bytes, der_encoded.size())); -#endif // defined(USE_OPENSSL_CERTS) -} - -ScopedX509Stack OSCertHandlesToOpenSSL( - const X509Certificate::OSCertHandles& os_handles) { - ScopedX509Stack stack(sk_X509_new_null()); - for (size_t i = 0; i < os_handles.size(); i++) { - ScopedX509 x509 = OSCertHandleToOpenSSL(os_handles[i]); - if (!x509) - return ScopedX509Stack(); - sk_X509_push(stack.get(), x509.release()); - } - return stack; -} +const uint8_t kTbMinProtocolVersionMinor = 3; bool EVP_MDToPrivateKeyHash(const EVP_MD* md, SSLPrivateKey::Hash* hash) { switch (EVP_MD_type(md)) { @@ -175,18 +118,6 @@ bool EVP_MDToPrivateKeyHash(const EVP_MD* md, SSLPrivateKey::Hash* hash) { } } -class ScopedCBB { - public: - ScopedCBB() { CBB_zero(&cbb_); } - ~ScopedCBB() { CBB_cleanup(&cbb_); } - - CBB* get() { return &cbb_; } - - private: - CBB cbb_; - DISALLOW_COPY_AND_ASSIGN(ScopedCBB); -}; - scoped_ptr<base::Value> NetLogPrivateKeyOperationCallback( SSLPrivateKey::Type type, SSLPrivateKey::Hash hash, @@ -226,6 +157,35 @@ scoped_ptr<base::Value> NetLogPrivateKeyOperationCallback( return std::move(value); } +scoped_ptr<base::Value> NetLogChannelIDLookupCallback( + ChannelIDService* channel_id_service, + NetLogCaptureMode capture_mode) { + ChannelIDStore* store = channel_id_service->GetChannelIDStore(); + scoped_ptr<base::DictionaryValue> dict(new base::DictionaryValue()); + dict->SetBoolean("ephemeral", store->IsEphemeral()); + dict->SetString("service", base::HexEncode(&channel_id_service, + sizeof(channel_id_service))); + dict->SetString("store", base::HexEncode(&store, sizeof(store))); + return std::move(dict); +} + +scoped_ptr<base::Value> NetLogChannelIDLookupCompleteCallback( + crypto::ECPrivateKey* key, + int result, + NetLogCaptureMode capture_mode) { + scoped_ptr<base::DictionaryValue> dict(new base::DictionaryValue()); + dict->SetInteger("net_error", result); + std::string raw_key; + if (result == OK && key && key->ExportRawPublicKey(&raw_key)) { + std::string key_to_log = "redacted"; + if (capture_mode.include_cookies_and_credentials()) { + key_to_log = base::HexEncode(raw_key.data(), raw_key.length()); + } + dict->SetString("key", key_to_log); + } + return std::move(dict); +} + } // namespace class SSLClientSocketOpenSSL::SSLContext { @@ -483,8 +443,8 @@ SSLClientSocketOpenSSL::PeerCertificateChain::AsOSChain() const { intermediates.push_back(sk_X509_value(openssl_chain_.get(), i)); } - return make_scoped_refptr(X509Certificate::CreateFromHandle( - sk_X509_value(openssl_chain_.get(), 0), intermediates)); + return X509Certificate::CreateFromHandle( + sk_X509_value(openssl_chain_.get(), 0), intermediates); #else // DER-encode the chain and convert to a platform certificate handle. std::vector<base::StringPiece> der_chain; @@ -496,7 +456,7 @@ SSLClientSocketOpenSSL::PeerCertificateChain::AsOSChain() const { der_chain.push_back(der); } - return make_scoped_refptr(X509Certificate::CreateFromDERCertChain(der_chain)); + return X509Certificate::CreateFromDERCertChain(der_chain); #endif } @@ -526,6 +486,7 @@ SSLClientSocketOpenSSL::SSLClientSocketOpenSSL( channel_id_service_(context.channel_id_service), tb_was_negotiated_(false), tb_negotiated_param_(TB_PARAM_ECDSAP256), + tb_signed_ekm_map_(10), ssl_(NULL), transport_bio_(NULL), transport_(std::move(transport_socket)), @@ -577,6 +538,43 @@ SSLClientSocketOpenSSL::GetChannelIDService() const { return channel_id_service_; } +Error SSLClientSocketOpenSSL::GetSignedEKMForTokenBinding( + crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) { + // The same key will be used across multiple requests to sign the same value, + // so the signature is cached. + std::string raw_public_key; + if (!key->ExportRawPublicKey(&raw_public_key)) + return ERR_FAILED; + SignedEkmMap::iterator it = tb_signed_ekm_map_.Get(raw_public_key); + if (it != tb_signed_ekm_map_.end()) { + *out = it->second; + return OK; + } + + uint8_t tb_ekm_buf[32]; + static const char kTokenBindingExporterLabel[] = "EXPORTER-Token-Binding"; + if (!SSL_export_keying_material(ssl_, tb_ekm_buf, sizeof(tb_ekm_buf), + kTokenBindingExporterLabel, + strlen(kTokenBindingExporterLabel), nullptr, + 0, false /* no context */)) { + return ERR_FAILED; + } + + if (!SignTokenBindingEkm( + base::StringPiece(reinterpret_cast<char*>(tb_ekm_buf), + sizeof(tb_ekm_buf)), + key, out)) + return ERR_FAILED; + + tb_signed_ekm_map_.Put(raw_public_key, *out); + return OK; +} + +crypto::ECPrivateKey* SSLClientSocketOpenSSL::GetChannelIDKey() const { + return channel_id_key_.get(); +} + SSLFailureState SSLClientSocketOpenSSL::GetSSLFailureState() const { return ssl_failure_state_; } @@ -605,11 +603,6 @@ int SSLClientSocketOpenSSL::ExportKeyingMaterial( return OK; } -int SSLClientSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) { - NOTIMPLEMENTED(); - return ERR_NOT_IMPLEMENTED; -} - int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { // It is an error to create an SSLClientSocket whose context has no // TransportSecurityState. @@ -647,6 +640,8 @@ int SSLClientSocketOpenSSL::Connect(const CompletionCallback& callback) { } void SSLClientSocketOpenSSL::Disconnect() { + crypto::OpenSSLErrStackTracer tracer(FROM_HERE); + if (ssl_) { // Calling SSL_shutdown prevents the session from being marked as // unresumable. @@ -772,14 +767,6 @@ bool SSLClientSocketOpenSSL::WasEverUsed() const { return was_ever_used_; } -bool SSLClientSocketOpenSSL::UsingTCPFastOpen() const { - if (transport_.get() && transport_->socket()) - return transport_->socket()->UsingTCPFastOpen(); - - NOTREACHED(); - return false; -} - bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->Reset(); if (server_cert_chain_->empty()) @@ -799,7 +786,7 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->token_binding_key_param = tb_negotiated_param_; ssl_info->pinning_failure_log = pinning_failure_log_; - AddSCTInfoToSSLInfo(ssl_info); + AddCTInfoToSSLInfo(ssl_info); const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); CHECK(cipher); @@ -807,9 +794,11 @@ bool SSLClientSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { ssl_info->key_exchange_info = SSL_SESSION_get_key_exchange_info(SSL_get_session(ssl_)); - ssl_info->connection_status = EncodeSSLConnectionStatus( - static_cast<uint16_t>(SSL_CIPHER_get_id(cipher)), 0 /* no compression */, - GetNetSSLVersion(ssl_)); + SSLConnectionStatusSetCipherSuite( + static_cast<uint16_t>(SSL_CIPHER_get_id(cipher)), + &ssl_info->connection_status); + SSLConnectionStatusSetVersion(GetNetSSLVersion(ssl_), + &ssl_info->connection_status); if (!SSL_get_secure_renegotiation_support(ssl_)) ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; @@ -888,7 +877,7 @@ int SSLClientSocketOpenSSL::Init() { DCHECK(!ssl_); DCHECK(!transport_bio_); -#if defined(USE_NSS_CERTS) || defined(OS_IOS) +#if defined(USE_NSS_VERIFIER) if (ssl_config_.cert_io_enabled) { // TODO(davidben): Move this out of SSLClientSocket. See // https://crbug.com/539520. @@ -908,8 +897,8 @@ int SSLClientSocketOpenSSL::Init() { // // TODO(rsleevi): Should this code allow hostnames that violate the LDH rule? // See https://crbug.com/496472 and https://crbug.com/496468 for discussion. - IPAddressNumber unused; - if (!ParseIPLiteralToNumber(host_and_port_.host(), &unused) && + IPAddress unused; + if (!unused.AssignFromIPLiteral(host_and_port_.host()) && !SSL_set_tlsext_host_name(ssl_, host_and_port_.host().c_str())) { return ERR_UNEXPECTED; } @@ -972,11 +961,13 @@ int SSLClientSocketOpenSSL::Init() { SSL_set_mode(ssl_, mode.set_mask); SSL_clear_mode(ssl_, mode.clear_mask); - // See SSLConfig::disabled_cipher_suites for description of the suites - // disabled by default. Note that SHA256 and SHA384 only select HMAC-SHA256 - // and HMAC-SHA384 cipher suites, not GCM cipher suites with SHA256 or SHA384 - // as the handshake hash. - std::string command("DEFAULT:!SHA256:-SHA384:!AESGCM+AES256:!aPSK"); + // Use BoringSSL defaults, but disable HMAC-SHA256 and HMAC-SHA384 ciphers + // (note that SHA256 and SHA384 only select legacy CBC ciphers). Also disable + // DHE_RSA_WITH_AES_256_GCM_SHA384. Historically, AES_256_GCM was not + // supported. As DHE is being deprecated, don't add a cipher only to remove it + // immediately. + std::string command( + "DEFAULT:!SHA256:!SHA384:!DHE-RSA-AES256-GCM-SHA384:!aPSK"); if (ssl_config_.require_ecdhe) command.append(":!kRSA:!kDHE"); @@ -1000,14 +991,6 @@ int SSLClientSocketOpenSSL::Init() { } } - // Disable ECDSA cipher suites on platforms that do not support ECDSA - // signed certificates, as servers may use the presence of such - // ciphersuites as a hint to send an ECDSA certificate. -#if defined(OS_WIN) - if (base::win::GetVersion() < base::win::VERSION_VISTA) - command.append(":!ECDSA"); -#endif - int rv = SSL_set_cipher_list(ssl_, command.c_str()); // If this fails (rv = 0) it means there are no ciphers enabled on this SSL. // This will almost certainly result in the socket failing to complete the @@ -1244,7 +1227,9 @@ int SSLClientSocketOpenSSL::DoHandshakeComplete(int result) { } int SSLClientSocketOpenSSL::DoChannelIDLookup() { - net_log_.AddEvent(NetLog::TYPE_SSL_CHANNEL_ID_REQUESTED); + NetLog::ParametersCallback callback = base::Bind( + &NetLogChannelIDLookupCallback, base::Unretained(channel_id_service_)); + net_log_.BeginEvent(NetLog::TYPE_SSL_GET_CHANNEL_ID, callback); GotoState(STATE_CHANNEL_ID_LOOKUP_COMPLETE); return channel_id_service_->GetOrCreateChannelID( host_and_port_.host(), &channel_id_key_, @@ -1254,16 +1239,15 @@ int SSLClientSocketOpenSSL::DoChannelIDLookup() { } int SSLClientSocketOpenSSL::DoChannelIDLookupComplete(int result) { + net_log_.EndEvent(NetLog::TYPE_SSL_GET_CHANNEL_ID, + base::Bind(&NetLogChannelIDLookupCompleteCallback, + channel_id_key_.get(), result)); if (result < 0) return result; - if (!channel_id_key_) { - LOG(ERROR) << "Failed to import Channel ID."; - return ERR_CHANNEL_ID_IMPORT_FAILED; - } - // Hand the key to OpenSSL. Check for error in case OpenSSL rejects the key // type. + DCHECK(channel_id_key_); crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); int rv = SSL_set1_tls_channel_id(ssl_, channel_id_key_->key()); if (!rv) { @@ -1274,7 +1258,6 @@ int SSLClientSocketOpenSSL::DoChannelIDLookupComplete(int result) { // Return to the handshake. channel_id_sent_ = true; - net_log_.AddEvent(NetLog::TYPE_SSL_CHANNEL_ID_PROVIDED); GotoState(STATE_HANDSHAKE); return OK; } @@ -1419,22 +1402,38 @@ void SSLClientSocketOpenSSL::VerifyCT() { server_cert_verify_result_.verified_cert.get(), ocsp_response, sct_list, &ct_verify_result_, net_log_); - if (policy_enforcer_ && - (server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV)) { - scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = - SSLConfigService::GetEVCertsWhitelist(); - if (!policy_enforcer_->DoesConformToCTEVPolicy( - server_cert_verify_result_.verified_cert.get(), ev_whitelist.get(), - ct_verify_result_, net_log_)) { - // TODO(eranm): Log via the BoundNetLog, see crbug.com/437766 - VLOG(1) << "EV certificate for " - << server_cert_verify_result_.verified_cert->subject() - .GetDisplayName() - << " does not conform to CT policy, removing EV status."; - server_cert_verify_result_.cert_status |= - CERT_STATUS_CT_COMPLIANCE_FAILED; - server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + ct_verify_result_.ct_policies_applied = (policy_enforcer_ != nullptr); + ct_verify_result_.ev_policy_compliance = + ct::EVPolicyCompliance::EV_POLICY_DOES_NOT_APPLY; + if (policy_enforcer_) { + if ((server_cert_verify_result_.cert_status & CERT_STATUS_IS_EV)) { + scoped_refptr<ct::EVCertsWhitelist> ev_whitelist = + SSLConfigService::GetEVCertsWhitelist(); + ct::EVPolicyCompliance ev_policy_compliance = + policy_enforcer_->DoesConformToCTEVPolicy( + server_cert_verify_result_.verified_cert.get(), + ev_whitelist.get(), ct_verify_result_.verified_scts, net_log_); + ct_verify_result_.ev_policy_compliance = ev_policy_compliance; + if (ev_policy_compliance != + ct::EVPolicyCompliance::EV_POLICY_DOES_NOT_APPLY && + ev_policy_compliance != + ct::EVPolicyCompliance::EV_POLICY_COMPLIES_VIA_WHITELIST && + ev_policy_compliance != + ct::EVPolicyCompliance::EV_POLICY_COMPLIES_VIA_SCTS) { + // TODO(eranm): Log via the BoundNetLog, see crbug.com/437766 + VLOG(1) << "EV certificate for " + << server_cert_verify_result_.verified_cert->subject() + .GetDisplayName() + << " does not conform to CT policy, removing EV status."; + server_cert_verify_result_.cert_status |= + CERT_STATUS_CT_COMPLIANCE_FAILED; + server_cert_verify_result_.cert_status &= ~CERT_STATUS_IS_EV; + } } + ct_verify_result_.cert_policy_compliance = + policy_enforcer_->DoesConformToCertPolicy( + server_cert_verify_result_.verified_cert.get(), + ct_verify_result_.verified_scts, net_log_); } } @@ -1459,6 +1458,7 @@ void SSLClientSocketOpenSSL::OnSendComplete(int result) { } void SSLClientSocketOpenSSL::OnRecvComplete(int result) { + TRACE_EVENT0("net", "SSLClientSocketOpenSSL::OnRecvComplete"); if (next_handshake_state_ == STATE_HANDSHAKE) { // In handshake phase. OnHandshakeIOComplete(result); @@ -1476,6 +1476,7 @@ void SSLClientSocketOpenSSL::OnRecvComplete(int result) { } int SSLClientSocketOpenSSL::DoHandshakeLoop(int last_io_result) { + TRACE_EVENT0("net", "SSLClientSocketOpenSSL::DoHandshakeLoop"); int rv = last_io_result; do { // Default to STATE_NONE for next state. @@ -2087,8 +2088,8 @@ int SSLClientSocketOpenSSL::NewSessionCallback(SSL_SESSION* session) { return 1; } -void SSLClientSocketOpenSSL::AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const { - ssl_info->UpdateSignedCertificateTimestamps(ct_verify_result_); +void SSLClientSocketOpenSSL::AddCTInfoToSSLInfo(SSLInfo* ssl_info) const { + ssl_info->UpdateCertificateTransparencyInfo(ct_verify_result_); } std::string SSLClientSocketOpenSSL::GetSessionCacheKey() const { @@ -2237,7 +2238,7 @@ int SSLClientSocketOpenSSL::TokenBindingAdd(const uint8_t** out, if (ssl_config_.token_binding_params.empty()) { return 0; } - ScopedCBB output; + crypto::AutoCBB output; CBB parameters_list; if (!CBB_init(output.get(), 7) || !CBB_add_u8(output.get(), kTbProtocolVersionMajor) || diff --git a/chromium/net/socket/ssl_client_socket_openssl.h b/chromium/net/socket/ssl_client_socket_openssl.h index 178daeb3273..628d4dbeedc 100644 --- a/chromium/net/socket/ssl_client_socket_openssl.h +++ b/chromium/net/socket/ssl_client_socket_openssl.h @@ -13,6 +13,8 @@ #include <string> #include <vector> +#include "base/compiler_specific.h" +#include "base/containers/mru_cache.h" #include "base/memory/ref_counted.h" #include "base/memory/scoped_ptr.h" #include "base/memory/weak_ptr.h" @@ -42,6 +44,8 @@ class CTVerifier; class SSLCertRequestInfo; class SSLInfo; +using SignedEkmMap = base::MRUCache<std::string, std::vector<uint8_t>>; + // An SSL client socket implemented with OpenSSL. class SSLClientSocketOpenSSL : public SSLClientSocket { public: @@ -72,6 +76,9 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { void GetSSLCertRequestInfo(SSLCertRequestInfo* cert_request_info) override; NextProtoStatus GetNextProto(std::string* proto) const override; ChannelIDService* GetChannelIDService() const override; + Error GetSignedEKMForTokenBinding(crypto::ECPrivateKey* key, + std::vector<uint8_t>* out) override; + crypto::ECPrivateKey* GetChannelIDKey() const override; SSLFailureState GetSSLFailureState() const override; // SSLSocket implementation. @@ -80,7 +87,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { const base::StringPiece& context, unsigned char* out, unsigned int outlen) override; - int GetTLSUniqueChannelBinding(std::string* out) override; // StreamSocket implementation. int Connect(const CompletionCallback& callback) override; @@ -93,7 +99,6 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { void SetSubresourceSpeculation() override; void SetOmniboxSpeculation() override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; void GetConnectionAttempts(ConnectionAttempts* out) const override; void ClearConnectionAttempts() override {} @@ -191,12 +196,13 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { // Called from the SSL layer whenever a new session is established. int NewSessionCallback(SSL_SESSION* session); - // Adds the SignedCertificateTimestamps from ct_verify_result_ to |ssl_info|. + // Adds the Certificate Transparency info from ct_verify_result_ to + // |ssl_info|. // SCTs are held in three separate vectors in ct_verify_result, each // vetor representing a particular verification state, this method associates // each of the SCTs with the corresponding SCTVerifyStatus as it adds it to // the |ssl_info|.signed_certificate_timestamps list. - void AddSCTInfoToSSLInfo(SSLInfo* ssl_info) const; + void AddCTInfoToSSLInfo(SSLInfo* ssl_info) const; // Returns a unique key string for the SSL session cache for // this socket. @@ -301,6 +307,7 @@ class SSLClientSocketOpenSSL : public SSLClientSocket { ChannelIDService* channel_id_service_; bool tb_was_negotiated_; TokenBindingParam tb_negotiated_param_; + SignedEkmMap tb_signed_ekm_map_; // OpenSSL stuff SSL* ssl_; diff --git a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc b/chromium/net/socket/ssl_client_socket_openssl_unittest.cc deleted file mode 100644 index a1ab91a4754..00000000000 --- a/chromium/net/socket/ssl_client_socket_openssl_unittest.cc +++ /dev/null @@ -1,267 +0,0 @@ -// Copyright (c) 2012 The Chromium Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include "net/socket/ssl_client_socket.h" - -#include <errno.h> -#include <openssl/bio.h> -#include <openssl/bn.h> -#include <openssl/evp.h> -#include <openssl/pem.h> -#include <openssl/rsa.h> -#include <string.h> -#include <utility> - -#include "base/files/file_path.h" -#include "base/files/file_util.h" -#include "base/memory/ref_counted.h" -#include "base/values.h" -#include "crypto/openssl_util.h" -#include "crypto/scoped_openssl_types.h" -#include "net/base/address_list.h" -#include "net/base/io_buffer.h" -#include "net/base/net_errors.h" -#include "net/base/test_completion_callback.h" -#include "net/base/test_data_directory.h" -#include "net/cert/mock_cert_verifier.h" -#include "net/cert/test_root_certs.h" -#include "net/dns/host_resolver.h" -#include "net/http/transport_security_state.h" -#include "net/log/net_log.h" -#include "net/socket/client_socket_factory.h" -#include "net/socket/client_socket_handle.h" -#include "net/socket/socket_test_util.h" -#include "net/socket/tcp_client_socket.h" -#include "net/ssl/openssl_client_key_store.h" -#include "net/ssl/ssl_cert_request_info.h" -#include "net/ssl/ssl_config_service.h" -#include "net/ssl/ssl_platform_key.h" -#include "net/test/cert_test_util.h" -#include "net/test/spawned_test_server/spawned_test_server.h" -#include "testing/gtest/include/gtest/gtest.h" -#include "testing/platform_test.h" - -namespace net { - -namespace { - -// These client auth tests are currently dependent on OpenSSL's struct X509. -#if defined(USE_OPENSSL_CERTS) - -// Loads a PEM-encoded private key file into a scoped EVP_PKEY object. -// |filepath| is the private key file path. -// |*pkey| is reset to the new EVP_PKEY on success, untouched otherwise. -// Returns true on success, false on failure. -bool LoadPrivateKeyOpenSSL( - const base::FilePath& filepath, - crypto::ScopedEVP_PKEY* pkey) { - std::string data; - if (!base::ReadFileToString(filepath, &data)) { - LOG(ERROR) << "Could not read private key file: " - << filepath.value() << ": " << strerror(errno); - return false; - } - crypto::ScopedBIO bio(BIO_new_mem_buf( - const_cast<char*>(reinterpret_cast<const char*>(data.data())), - static_cast<int>(data.size()))); - if (!bio.get()) { - LOG(ERROR) << "Could not allocate BIO for buffer?"; - return false; - } - EVP_PKEY* result = PEM_read_bio_PrivateKey(bio.get(), NULL, NULL, NULL); - if (result == NULL) { - LOG(ERROR) << "Could not decode private key file: " - << filepath.value(); - return false; - } - pkey->reset(result); - return true; -} - -class SSLClientSocketOpenSSLClientAuthTest : public PlatformTest { - public: - SSLClientSocketOpenSSLClientAuthTest() - : socket_factory_(ClientSocketFactory::GetDefaultFactory()), - cert_verifier_(new MockCertVerifier), - transport_security_state_(new TransportSecurityState) { - cert_verifier_->set_default_result(OK); - context_.cert_verifier = cert_verifier_.get(); - context_.transport_security_state = transport_security_state_.get(); - key_store_ = OpenSSLClientKeyStore::GetInstance(); - } - - ~SSLClientSocketOpenSSLClientAuthTest() override { key_store_->Flush(); } - - protected: - scoped_ptr<SSLClientSocket> CreateSSLClientSocket( - scoped_ptr<StreamSocket> transport_socket, - const HostPortPair& host_and_port, - const SSLConfig& ssl_config) { - scoped_ptr<ClientSocketHandle> connection(new ClientSocketHandle); - connection->SetSocket(std::move(transport_socket)); - return socket_factory_->CreateSSLClientSocket( - std::move(connection), host_and_port, ssl_config, context_); - } - - // Connect to a HTTPS test server. - bool ConnectToTestServer(SpawnedTestServer::SSLOptions& ssl_options) { - test_server_.reset(new SpawnedTestServer(SpawnedTestServer::TYPE_HTTPS, - ssl_options, - base::FilePath())); - if (!test_server_->Start()) { - LOG(ERROR) << "Could not start SpawnedTestServer"; - return false; - } - - if (!test_server_->GetAddressList(&addr_)) { - LOG(ERROR) << "Could not get SpawnedTestServer address list"; - return false; - } - - transport_.reset(new TCPClientSocket( - addr_, &log_, NetLog::Source())); - int rv = callback_.GetResult( - transport_->Connect(callback_.callback())); - if (rv != OK) { - LOG(ERROR) << "Could not connect to SpawnedTestServer"; - return false; - } - return true; - } - - // Record a certificate's private key to ensure it can be used - // by the OpenSSL-based SSLClientSocket implementation. - // |ssl_config| provides a client certificate. - // |private_key| must be an EVP_PKEY for the corresponding private key. - // Returns true on success, false on failure. - bool RecordPrivateKey(SSLConfig& ssl_config, - EVP_PKEY* private_key) { - return key_store_->RecordClientCertPrivateKey( - ssl_config.client_cert.get(), private_key); - } - - // Create an SSLClientSocket object and use it to connect to a test - // server, then wait for connection results. This must be called after - // a succesful ConnectToTestServer() call. - // |ssl_config| the SSL configuration to use. - // |result| will retrieve the ::Connect() result value. - // Returns true on succes, false otherwise. Success means that the socket - // could be created and its Connect() was called, not that the connection - // itself was a success. - bool CreateAndConnectSSLClientSocket(const SSLConfig& ssl_config, - int* result) { - sock_ = CreateSSLClientSocket(std::move(transport_), - test_server_->host_port_pair(), ssl_config); - - if (sock_->IsConnected()) { - LOG(ERROR) << "SSL Socket prematurely connected"; - return false; - } - - *result = callback_.GetResult(sock_->Connect(callback_.callback())); - return true; - } - - - // Check that the client certificate was sent. - // Returns true on success. - bool CheckSSLClientSocketSentCert() { - SSLInfo ssl_info; - sock_->GetSSLInfo(&ssl_info); - return ssl_info.client_cert_sent; - } - - ClientSocketFactory* socket_factory_; - scoped_ptr<MockCertVerifier> cert_verifier_; - scoped_ptr<TransportSecurityState> transport_security_state_; - SSLClientSocketContext context_; - OpenSSLClientKeyStore* key_store_; - scoped_ptr<SpawnedTestServer> test_server_; - AddressList addr_; - TestCompletionCallback callback_; - NetLog log_; - scoped_ptr<StreamSocket> transport_; - scoped_ptr<SSLClientSocket> sock_; -}; - -// Connect to a server requesting client authentication, do not send -// any client certificates. It should refuse the connection. -TEST_F(SSLClientSocketOpenSSLClientAuthTest, NoCert) { - SpawnedTestServer::SSLOptions ssl_options; - ssl_options.request_client_certificate = true; - - ASSERT_TRUE(ConnectToTestServer(ssl_options)); - - base::FilePath certs_dir = GetTestCertsDirectory(); - - int rv; - ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); - - EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); - EXPECT_FALSE(sock_->IsConnected()); -} - -// Connect to a server requesting client authentication, and send it -// an empty certificate. It should refuse the connection. -TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendEmptyCert) { - SpawnedTestServer::SSLOptions ssl_options; - ssl_options.request_client_certificate = true; - ssl_options.client_authorities.push_back( - GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem")); - - ASSERT_TRUE(ConnectToTestServer(ssl_options)); - - base::FilePath certs_dir = GetTestCertsDirectory(); - SSLConfig ssl_config; - ssl_config.send_client_cert = true; - ssl_config.client_cert = NULL; - ssl_config.client_private_key = NULL; - - int rv; - ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); - - EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock_->IsConnected()); -} - -// Connect to a server requesting client authentication. Send it a -// matching certificate. It should allow the connection. -TEST_F(SSLClientSocketOpenSSLClientAuthTest, SendGoodCert) { - SpawnedTestServer::SSLOptions ssl_options; - ssl_options.request_client_certificate = true; - ssl_options.client_authorities.push_back( - GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem")); - - ASSERT_TRUE(ConnectToTestServer(ssl_options)); - - base::FilePath certs_dir = GetTestCertsDirectory(); - SSLConfig ssl_config; - ssl_config.send_client_cert = true; - ssl_config.client_cert = ImportCertFromFile(certs_dir, "client_1.pem"); - - // This is required to ensure that signing works with the client - // certificate's private key. - crypto::ScopedEVP_PKEY client_private_key; - ASSERT_TRUE(LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key"), - &client_private_key)); - EXPECT_TRUE(RecordPrivateKey(ssl_config, client_private_key.get())); - - ssl_config.client_private_key = - FetchClientCertPrivateKey(ssl_config.client_cert.get()); - - int rv; - ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); - - EXPECT_EQ(OK, rv); - EXPECT_TRUE(sock_->IsConnected()); - - EXPECT_TRUE(CheckSSLClientSocketSentCert()); - - sock_->Disconnect(); - EXPECT_FALSE(sock_->IsConnected()); -} -#endif // defined(USE_OPENSSL_CERTS) - -} // namespace -} // namespace net diff --git a/chromium/net/socket/ssl_client_socket_pool.cc b/chromium/net/socket/ssl_client_socket_pool.cc index f9a405854b3..2a2865c65f9 100644 --- a/chromium/net/socket/ssl_client_socket_pool.cc +++ b/chromium/net/socket/ssl_client_socket_pool.cc @@ -12,6 +12,7 @@ #include "base/metrics/histogram_macros.h" #include "base/metrics/sparse_histogram.h" #include "base/profiler/scoped_tracker.h" +#include "base/trace_event/trace_event.h" #include "base/values.h" #include "net/base/host_port_pair.h" #include "net/base/net_errors.h" @@ -45,19 +46,11 @@ SSLSocketParams::SSLSocketParams( ssl_config_(ssl_config), privacy_mode_(privacy_mode), load_flags_(load_flags), - expect_spdy_(expect_spdy), - ignore_limits_(false) { - if (direct_params_.get()) { - DCHECK(!socks_proxy_params_.get()); - DCHECK(!http_proxy_params_.get()); - ignore_limits_ = direct_params_->ignore_limits(); - } else if (socks_proxy_params_.get()) { - DCHECK(!http_proxy_params_.get()); - ignore_limits_ = socks_proxy_params_->ignore_limits(); - } else { - DCHECK(http_proxy_params_.get()); - ignore_limits_ = http_proxy_params_->ignore_limits(); - } + expect_spdy_(expect_spdy) { + // Only one set of lower level pool params should be non-NULL. + DCHECK((direct_params_ && !socks_proxy_params_ && !http_proxy_params_) || + (!direct_params_ && socks_proxy_params_ && !http_proxy_params_) || + (!direct_params_ && !socks_proxy_params_ && http_proxy_params_)); } SSLSocketParams::~SSLSocketParams() {} @@ -101,6 +94,7 @@ static const int kSSLHandshakeTimeoutInSeconds = 30; SSLConnectJob::SSLConnectJob(const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<SSLSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -113,6 +107,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, : ConnectJob(group_name, timeout_duration, priority, + respect_limits, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), params_(params), @@ -129,8 +124,7 @@ SSLConnectJob::SSLConnectJob(const std::string& group_name, ? "pm/" + context.ssl_session_cache_shard : context.ssl_session_cache_shard)), callback_( - base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))) { -} + base::Bind(&SSLConnectJob::OnIOComplete, base::Unretained(this))) {} SSLConnectJob::~SSLConnectJob() { } @@ -179,6 +173,7 @@ void SSLConnectJob::OnIOComplete(int result) { } int SSLConnectJob::DoLoop(int result) { + TRACE_EVENT0("net", "SSLConnectJob::DoLoop"); DCHECK_NE(next_state_, STATE_NONE); int rv = result; @@ -232,7 +227,8 @@ int SSLConnectJob::DoTransportConnect() { scoped_refptr<TransportSocketParams> direct_params = params_->GetDirectConnectionParams(); return transport_socket_handle_->Init(group_name(), direct_params, priority(), - callback_, transport_pool_, net_log()); + respect_limits(), callback_, + transport_pool_, net_log()); } int SSLConnectJob::DoTransportConnectComplete(int result) { @@ -252,8 +248,8 @@ int SSLConnectJob::DoSOCKSConnect() { scoped_refptr<SOCKSSocketParams> socks_proxy_params = params_->GetSocksProxyConnectionParams(); return transport_socket_handle_->Init(group_name(), socks_proxy_params, - priority(), callback_, socks_pool_, - net_log()); + priority(), respect_limits(), callback_, + socks_pool_, net_log()); } int SSLConnectJob::DoSOCKSConnectComplete(int result) { @@ -271,8 +267,8 @@ int SSLConnectJob::DoTunnelConnect() { scoped_refptr<HttpProxySocketParams> http_proxy_params = params_->GetHttpProxyConnectionParams(); return transport_socket_handle_->Init(group_name(), http_proxy_params, - priority(), callback_, http_proxy_pool_, - net_log()); + priority(), respect_limits(), callback_, + http_proxy_pool_, net_log()); } int SSLConnectJob::DoTunnelConnectComplete(int result) { @@ -295,6 +291,7 @@ int SSLConnectJob::DoTunnelConnectComplete(int result) { } int SSLConnectJob::DoSSLConnect() { + TRACE_EVENT0("net", "SSLConnectJob::DoSSLConnect"); // TODO(pkasting): Remove ScopedTracker below once crbug.com/462815 is fixed. tracked_objects::ScopedTracker tracking_profile( FROM_HERE_WITH_EXPLICIT_FUNCTION("462815 SSLConnectJob::DoSSLConnect")); @@ -558,17 +555,10 @@ scoped_ptr<ConnectJob> SSLClientSocketPool::SSLConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return scoped_ptr<ConnectJob>(new SSLConnectJob(group_name, - request.priority(), - request.params(), - ConnectionTimeout(), - transport_pool_, - socks_pool_, - http_proxy_pool_, - client_socket_factory_, - context_, - delegate, - net_log_)); + return scoped_ptr<ConnectJob>(new SSLConnectJob( + group_name, request.priority(), request.respect_limits(), + request.params(), ConnectionTimeout(), transport_pool_, socks_pool_, + http_proxy_pool_, client_socket_factory_, context_, delegate, net_log_)); } base::TimeDelta SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() @@ -579,6 +569,7 @@ base::TimeDelta SSLClientSocketPool::SSLConnectJobFactory::ConnectionTimeout() int SSLClientSocketPool::RequestSocket(const std::string& group_name, const void* socket_params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) { @@ -586,7 +577,7 @@ int SSLClientSocketPool::RequestSocket(const std::string& group_name, static_cast<const scoped_refptr<SSLSocketParams>*>(socket_params); return base_.RequestSocket(group_name, *casted_socket_params, priority, - handle, callback, net_log); + respect_limits, handle, callback, net_log); } void SSLClientSocketPool::RequestSockets( diff --git a/chromium/net/socket/ssl_client_socket_pool.h b/chromium/net/socket/ssl_client_socket_pool.h index b015baeb797..d5f480799ef 100644 --- a/chromium/net/socket/ssl_client_socket_pool.h +++ b/chromium/net/socket/ssl_client_socket_pool.h @@ -72,7 +72,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams PrivacyMode privacy_mode() const { return privacy_mode_; } int load_flags() const { return load_flags_; } bool expect_spdy() const { return expect_spdy_; } - bool ignore_limits() const { return ignore_limits_; } private: friend class base::RefCounted<SSLSocketParams>; @@ -86,7 +85,6 @@ class NET_EXPORT_PRIVATE SSLSocketParams const PrivacyMode privacy_mode_; const int load_flags_; const bool expect_spdy_; - bool ignore_limits_; DISALLOW_COPY_AND_ASSIGN(SSLSocketParams); }; @@ -99,6 +97,7 @@ class SSLConnectJob : public ConnectJob { // job. SSLConnectJob(const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<SSLSocketParams>& params, const base::TimeDelta& timeout_duration, TransportClientSocketPool* transport_pool, @@ -205,6 +204,7 @@ class NET_EXPORT_PRIVATE SSLClientSocketPool int RequestSocket(const std::string& group_name, const void* connect_params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override; diff --git a/chromium/net/socket/ssl_client_socket_pool_unittest.cc b/chromium/net/socket/ssl_client_socket_pool_unittest.cc index 2baae896fb0..f83ffd50d13 100644 --- a/chromium/net/socket/ssl_client_socket_pool_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_pool_unittest.cc @@ -87,7 +87,6 @@ class SSLClientSocketPoolTest direct_transport_socket_params_(new TransportSocketParams( HostPortPair("host", 443), false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), transport_socket_pool_(kMaxSockets, @@ -96,7 +95,6 @@ class SSLClientSocketPoolTest proxy_transport_socket_params_(new TransportSocketParams( HostPortPair("proxy", 443), false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), socks_socket_params_( @@ -170,7 +168,6 @@ class SSLClientSocketPoolTest params.http_auth_handler_factory = http_auth_handler_factory_.get(); params.http_server_properties = http_server_properties_.GetWeakPtr(); - params.enable_spdy_compression = false; params.spdy_default_protocol = GetParam(); return new HttpNetworkSession(params); } @@ -217,8 +214,9 @@ TEST_P(SSLClientSocketPoolTest, TCPFail) { false); ClientSocketHandle handle; - int rv = handle.Init("a", params, MEDIUM, CompletionCallback(), pool_.get(), - BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + CompletionCallback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -238,8 +236,9 @@ TEST_P(SSLClientSocketPoolTest, TCPFailAsync) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -265,8 +264,9 @@ TEST_P(SSLClientSocketPoolTest, BasicDirect) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); @@ -291,8 +291,9 @@ TEST_P(SSLClientSocketPoolTest, SetSocketRequestPriorityOnInitDirect) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, handle.Init("a", params, priority, callback.callback(), - pool_.get(), BoundNetLog())); + EXPECT_EQ(OK, handle.Init("a", params, priority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(priority, transport_socket_pool_.last_request_priority()); handle.socket()->Disconnect(); } @@ -310,8 +311,9 @@ TEST_P(SSLClientSocketPoolTest, BasicDirectAsync) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -334,8 +336,9 @@ TEST_P(SSLClientSocketPoolTest, DirectCertError) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -358,8 +361,9 @@ TEST_P(SSLClientSocketPoolTest, DirectSSLError) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -383,8 +387,9 @@ TEST_P(SSLClientSocketPoolTest, DirectWithNPN) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -410,8 +415,9 @@ TEST_P(SSLClientSocketPoolTest, DirectNoSPDY) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -435,8 +441,9 @@ TEST_P(SSLClientSocketPoolTest, DirectGotSPDY) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -466,8 +473,9 @@ TEST_P(SSLClientSocketPoolTest, DirectGotBonusSPDY) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -495,8 +503,9 @@ TEST_P(SSLClientSocketPoolTest, SOCKSFail) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -514,8 +523,9 @@ TEST_P(SSLClientSocketPoolTest, SOCKSFailAsync) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -539,8 +549,9 @@ TEST_P(SSLClientSocketPoolTest, SOCKSBasic) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); @@ -564,8 +575,9 @@ TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitSOCKS) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(), - pool_.get(), BoundNetLog())); + EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority()); } @@ -581,8 +593,9 @@ TEST_P(SSLClientSocketPoolTest, SOCKSBasicAsync) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -606,8 +619,9 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyFail) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_PROXY_CONNECTION_FAILED, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -625,8 +639,9 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyFailAsync) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -662,8 +677,9 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasic) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_TRUE(handle.is_initialized()); EXPECT_TRUE(handle.socket()); @@ -697,8 +713,9 @@ TEST_P(SSLClientSocketPoolTest, SetTransportPriorityOnInitHTTP) { ClientSocketHandle handle; TestCompletionCallback callback; - EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, callback.callback(), - pool_.get(), BoundNetLog())); + EXPECT_EQ(OK, handle.Init("a", params, HIGHEST, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog())); EXPECT_EQ(HIGHEST, transport_socket_pool_.last_request_priority()); } @@ -726,8 +743,9 @@ TEST_P(SSLClientSocketPoolTest, HttpProxyBasicAsync) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -763,8 +781,9 @@ TEST_P(SSLClientSocketPoolTest, NeedProxyAuth) { ClientSocketHandle handle; TestCompletionCallback callback; - int rv = handle.Init( - "a", params, MEDIUM, callback.callback(), pool_.get(), BoundNetLog()); + int rv = + handle.Init("a", params, MEDIUM, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), pool_.get(), BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); diff --git a/chromium/net/socket/ssl_client_socket_unittest.cc b/chromium/net/socket/ssl_client_socket_unittest.cc index 6a9d4654bdf..d5656567984 100644 --- a/chromium/net/socket/ssl_client_socket_unittest.cc +++ b/chromium/net/socket/ssl_client_socket_unittest.cc @@ -7,6 +7,7 @@ #include <utility> #include "base/callback_helpers.h" +#include "base/files/file_util.h" #include "base/location.h" #include "base/macros.h" #include "base/memory/ref_counted.h" @@ -21,6 +22,7 @@ #include "net/base/test_data_directory.h" #include "net/cert/asn1_util.h" #include "net/cert/ct_policy_enforcer.h" +#include "net/cert/ct_policy_status.h" #include "net/cert/ct_verifier.h" #include "net/cert/mock_cert_verifier.h" #include "net/cert/test_root_certs.h" @@ -49,6 +51,17 @@ #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" +#if defined(USE_OPENSSL) +#include <errno.h> +#include <openssl/bio.h> +#include <openssl/evp.h> +#include <openssl/pem.h> +#include <string.h> + +#include "crypto/scoped_openssl_types.h" +#include "net/ssl/test_ssl_private_key.h" +#endif + using testing::_; using testing::Return; using testing::Truly; @@ -90,9 +103,6 @@ class WrappedStreamSocket : public StreamSocket { } void SetOmniboxSpeculation() override { transport_->SetOmniboxSpeculation(); } bool WasEverUsed() const override { return transport_->WasEverUsed(); } - bool UsingTCPFastOpen() const override { - return transport_->UsingTCPFastOpen(); - } bool WasNpnNegotiated() const override { return transport_->WasNpnNegotiated(); } @@ -643,6 +653,7 @@ class FailingChannelIDStore : public ChannelIDStore { void GetAllChannelIDs(const GetChannelIDListCallback& callback) override {} int GetChannelIDCount() override { return 0; } void SetForceKeepSessionState() override {} + bool IsEphemeral() override { return true; } }; // A ChannelIDStore that asynchronously returns an error when asked for a @@ -667,6 +678,7 @@ class AsyncFailingChannelIDStore : public ChannelIDStore { void GetAllChannelIDs(const GetChannelIDListCallback& callback) override {} int GetChannelIDCount() override { return 0; } void SetForceKeepSessionState() override {} + bool IsEphemeral() override { return true; } }; // A mock CTVerifier that records every call to Verify but doesn't verify @@ -684,11 +696,15 @@ class MockCTVerifier : public CTVerifier { // A mock CTPolicyEnforcer that returns a custom verification result. class MockCTPolicyEnforcer : public CTPolicyEnforcer { public: + MOCK_METHOD3(DoesConformToCertPolicy, + ct::CertPolicyCompliance(X509Certificate* cert, + const ct::SCTList&, + const BoundNetLog&)); MOCK_METHOD4(DoesConformToCTEVPolicy, - bool(X509Certificate* cert, - const ct::EVCertsWhitelist*, - const ct::CTVerifyResult&, - const BoundNetLog&)); + ct::EVPolicyCompliance(X509Certificate* cert, + const ct::EVCertsWhitelist*, + const ct::SCTList&, + const BoundNetLog&)); }; class SSLClientSocketTest : public PlatformTest { @@ -2334,8 +2350,12 @@ TEST_F(SSLClientSocketTest, EVCertStatusMaintainedForCompliantCert) { // Emulate compliance of the certificate to the policy. MockCTPolicyEnforcer policy_enforcer; SetCTPolicyEnforcer(&policy_enforcer); + EXPECT_CALL(policy_enforcer, DoesConformToCertPolicy(_, _, _)) + .WillRepeatedly( + Return(ct::CertPolicyCompliance::CERT_POLICY_COMPLIES_VIA_SCTS)); EXPECT_CALL(policy_enforcer, DoesConformToCTEVPolicy(_, _, _, _)) - .WillRepeatedly(Return(true)); + .WillRepeatedly( + Return(ct::EVPolicyCompliance::EV_POLICY_COMPLIES_VIA_SCTS)); int rv; ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); @@ -2366,8 +2386,12 @@ TEST_F(SSLClientSocketTest, EVCertStatusRemovedForNonCompliantCert) { // Emulate non-compliance of the certificate to the policy. MockCTPolicyEnforcer policy_enforcer; SetCTPolicyEnforcer(&policy_enforcer); + EXPECT_CALL(policy_enforcer, DoesConformToCertPolicy(_, _, _)) + .WillRepeatedly( + Return(ct::CertPolicyCompliance::CERT_POLICY_NOT_ENOUGH_SCTS)); EXPECT_CALL(policy_enforcer, DoesConformToCTEVPolicy(_, _, _, _)) - .WillRepeatedly(Return(false)); + .WillRepeatedly( + Return(ct::EVPolicyCompliance::EV_POLICY_NOT_ENOUGH_SCTS)); int rv; ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); @@ -3240,4 +3264,111 @@ TEST_F(SSLClientSocketTest, NPNServerDisabled) { sock_->GetNextProto(&proto)); } +// Client auth is not supported in NSS ports. +#if defined(USE_OPENSSL) + +namespace { + +// Loads a PEM-encoded private key file into a SSLPrivateKey object. +// |filepath| is the private key file path. +// Returns the new SSLPrivateKey. +scoped_refptr<SSLPrivateKey> LoadPrivateKeyOpenSSL( + const base::FilePath& filepath) { + std::string data; + if (!base::ReadFileToString(filepath, &data)) { + LOG(ERROR) << "Could not read private key file: " << filepath.value(); + return nullptr; + } + crypto::ScopedBIO bio(BIO_new_mem_buf(const_cast<char*>(data.data()), + static_cast<int>(data.size()))); + if (!bio) { + LOG(ERROR) << "Could not allocate BIO for buffer?"; + return nullptr; + } + crypto::ScopedEVP_PKEY result( + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr)); + if (!result) { + LOG(ERROR) << "Could not decode private key file: " << filepath.value(); + return nullptr; + } + return WrapOpenSSLPrivateKey(std::move(result)); +} + +} // namespace + +// Connect to a server requesting client authentication, do not send +// any client certificates. It should refuse the connection. +TEST_F(SSLClientSocketTest, NoCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + ASSERT_TRUE(StartTestServer(ssl_options)); + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(SSLConfig(), &rv)); + + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, rv); + EXPECT_FALSE(sock_->IsConnected()); +} + +// Connect to a server requesting client authentication, and send it +// an empty certificate. +TEST_F(SSLClientSocketTest, SendEmptyCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + ssl_options.client_authorities.push_back( + GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem")); + + ASSERT_TRUE(StartTestServer(ssl_options)); + + SSLConfig ssl_config; + ssl_config.send_client_cert = true; + ssl_config.client_cert = nullptr; + ssl_config.client_private_key = nullptr; + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock_->IsConnected()); + + SSLInfo ssl_info; + ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); + EXPECT_FALSE(ssl_info.client_cert_sent); +} + +// Connect to a server requesting client authentication. Send it a +// matching certificate. It should allow the connection. +TEST_F(SSLClientSocketTest, SendGoodCert) { + SpawnedTestServer::SSLOptions ssl_options; + ssl_options.request_client_certificate = true; + ssl_options.client_authorities.push_back( + GetTestClientCertsDirectory().AppendASCII("client_1_ca.pem")); + + ASSERT_TRUE(StartTestServer(ssl_options)); + + base::FilePath certs_dir = GetTestCertsDirectory(); + SSLConfig ssl_config; + ssl_config.send_client_cert = true; + ssl_config.client_cert = ImportCertFromFile(certs_dir, "client_1.pem"); + + // This is required to ensure that signing works with the client + // certificate's private key. + ssl_config.client_private_key = + LoadPrivateKeyOpenSSL(certs_dir.AppendASCII("client_1.key")); + + int rv; + ASSERT_TRUE(CreateAndConnectSSLClientSocket(ssl_config, &rv)); + + EXPECT_EQ(OK, rv); + EXPECT_TRUE(sock_->IsConnected()); + + SSLInfo ssl_info; + ASSERT_TRUE(sock_->GetSSLInfo(&ssl_info)); + EXPECT_TRUE(ssl_info.client_cert_sent); + + sock_->Disconnect(); + EXPECT_FALSE(sock_->IsConnected()); +} +#endif // defined(USE_OPENSSL) + } // namespace net diff --git a/chromium/net/socket/ssl_server_socket.h b/chromium/net/socket/ssl_server_socket.h index bfbe7de9110..479bbc7a4f9 100644 --- a/chromium/net/socket/ssl_server_socket.h +++ b/chromium/net/socket/ssl_server_socket.h @@ -31,6 +31,20 @@ class SSLServerSocket : public SSLSocket { virtual int Handshake(const CompletionCallback& callback) = 0; }; +class SSLServerContext { + public: + virtual ~SSLServerContext(){}; + + // Creates an SSL server socket over an already-connected transport socket. + // The caller must ensure the returned socket does not outlive the server + // context. + // + // The caller starts the SSL server handshake by calling Handshake on the + // returned socket. + virtual scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket) = 0; +}; + // Configures the underlying SSL library for the use of SSL server sockets. // // Due to the requirements of the underlying libraries, this should be called @@ -41,18 +55,14 @@ class SSLServerSocket : public SSLSocket { // omitted. NET_EXPORT void EnableSSLServerSockets(); -// Creates an SSL server socket over an already-connected transport socket. -// The caller must provide the server certificate and private key to use. +// Creates an SSL server socket context where all sockets spawned using this +// context will share the same session cache. // -// The returned SSLServerSocket takes ownership of |socket|. Stubbed versions -// of CreateSSLServerSocket will delete |socket| and return NULL. +// The caller must provide the server certificate and private key to use. // It takes a reference to |certificate|. // The |key| and |ssl_config| parameters are copied. // -// The caller starts the SSL server handshake by calling Handshake on the -// returned socket. -NET_EXPORT scoped_ptr<SSLServerSocket> CreateSSLServerSocket( - scoped_ptr<StreamSocket> socket, +NET_EXPORT scoped_ptr<SSLServerContext> CreateSSLServerContext( X509Certificate* certificate, const crypto::RSAPrivateKey& key, const SSLServerConfig& ssl_config); diff --git a/chromium/net/socket/ssl_server_socket_nss.cc b/chromium/net/socket/ssl_server_socket_nss.cc index 80450fe65fa..8e02909cf54 100644 --- a/chromium/net/socket/ssl_server_socket_nss.cc +++ b/chromium/net/socket/ssl_server_socket_nss.cc @@ -75,29 +75,140 @@ class NSSSSLServerInitSingleton { static base::LazyInstance<NSSSSLServerInitSingleton>::Leaky g_nss_ssl_server_init_singleton = LAZY_INSTANCE_INITIALIZER; -} // namespace - -void EnableSSLServerSockets() { - g_nss_ssl_server_init_singleton.Get(); -} - -scoped_ptr<SSLServerSocket> CreateSSLServerSocket( - scoped_ptr<StreamSocket> socket, - X509Certificate* cert, - const crypto::RSAPrivateKey& key, - const SSLServerConfig& ssl_config) { - DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" - << " called yet!"; - - return scoped_ptr<SSLServerSocket>( - new SSLServerSocketNSS(std::move(socket), cert, key, ssl_config)); -} +class SSLServerSocketNSS : public SSLServerSocket { + public: + // See comments on CreateSSLServerSocket for details of how these + // parameters are used. + SSLServerSocketNSS(scoped_ptr<StreamSocket> socket, + X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config); + ~SSLServerSocketNSS() override; + + // SSLServerSocket interface. + int Handshake(const CompletionCallback& callback) override; + + // SSLSocket interface. + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + + // Socket interface (via StreamSocket). + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + // StreamSocket implementation. + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + int64_t GetTotalReceivedBytes() const override; + + private: + enum State { + STATE_NONE, + STATE_HANDSHAKE, + }; + + int InitializeSSLOptions(); + + void OnSendComplete(int result); + void OnRecvComplete(int result); + void OnHandshakeIOComplete(int result); + + int BufferSend(); + void BufferSendComplete(int result); + int BufferRecv(); + void BufferRecvComplete(int result); + bool DoTransportIO(); + int DoPayloadRead(); + int DoPayloadWrite(); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + int DoHandshake(); + void DoHandshakeCallback(int result); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + static SECStatus OwnAuthCertHandler(void* arg, + PRFileDesc* socket, + PRBool checksig, + PRBool is_server); + static void HandshakeCallback(PRFileDesc* socket, void* arg); + + int Init(); + + // Members used to send and receive buffer. + bool transport_send_busy_; + bool transport_recv_busy_; + + scoped_refptr<IOBuffer> recv_buffer_; + + BoundNetLog net_log_; + + CompletionCallback user_handshake_callback_; + CompletionCallback user_read_callback_; + CompletionCallback user_write_callback_; + + // Used by Read function. + scoped_refptr<IOBuffer> user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr<IOBuffer> user_write_buf_; + int user_write_buf_len_; + + // The NSS SSL state machine + PRFileDesc* nss_fd_; + + // Buffers for the network end of the SSL state machine + memio_Private* nss_bufs_; + + // StreamSocket for sending and receiving data. + scoped_ptr<StreamSocket> transport_socket_; + + // Options for the SSL socket. + SSLServerConfig ssl_server_config_; + + // Certificate for the server. + scoped_refptr<X509Certificate> cert_; + + // Private key used by the server. + scoped_ptr<crypto::RSAPrivateKey> key_; + + State next_handshake_state_; + bool completed_handshake_; + + DISALLOW_COPY_AND_ASSIGN(SSLServerSocketNSS); +}; SSLServerSocketNSS::SSLServerSocketNSS( scoped_ptr<StreamSocket> transport_socket, - scoped_refptr<X509Certificate> cert, + X509Certificate* cert, const crypto::RSAPrivateKey& key, - const SSLServerConfig& ssl_config) + const SSLServerConfig& ssl_server_config) : transport_send_busy_(false), transport_recv_busy_(false), user_read_buf_len_(0), @@ -105,7 +216,7 @@ SSLServerSocketNSS::SSLServerSocketNSS( nss_fd_(NULL), nss_bufs_(NULL), transport_socket_(std::move(transport_socket)), - ssl_config_(ssl_config), + ssl_server_config_(ssl_server_config), cert_(cert), key_(key.Copy()), next_handshake_state_(STATE_NONE), @@ -172,28 +283,13 @@ int SSLServerSocketNSS::ExportKeyingMaterial(const base::StringPiece& label, return OK; } -int SSLServerSocketNSS::GetTLSUniqueChannelBinding(std::string* out) { - if (!IsConnected()) - return ERR_SOCKET_NOT_CONNECTED; - unsigned char buf[64]; - unsigned int len; - SECStatus result = SSL_GetChannelBinding(nss_fd_, - SSL_CHANNEL_BINDING_TLS_UNIQUE, - buf, &len, arraysize(buf)); - if (result != SECSuccess) { - LogFailedNSSFunction(net_log_, "SSL_GetChannelBinding", ""); - return MapNSSError(PORT_GetError()); - } - out->assign(reinterpret_cast<char*>(buf), len); - return OK; -} - int SSLServerSocketNSS::Connect(const CompletionCallback& callback) { NOTIMPLEMENTED(); return ERR_NOT_IMPLEMENTED; } -int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, +int SSLServerSocketNSS::Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) { DCHECK(user_read_callback_.is_null()); DCHECK(user_handshake_callback_.is_null()); @@ -217,7 +313,8 @@ int SSLServerSocketNSS::Read(IOBuffer* buf, int buf_len, return rv; } -int SSLServerSocketNSS::Write(IOBuffer* buf, int buf_len, +int SSLServerSocketNSS::Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) { DCHECK(user_write_callback_.is_null()); DCHECK(!user_write_buf_); @@ -288,10 +385,6 @@ bool SSLServerSocketNSS::WasEverUsed() const { return transport_socket_->WasEverUsed(); } -bool SSLServerSocketNSS::UsingTCPFastOpen() const { - return transport_socket_->UsingTCPFastOpen(); -} - bool SSLServerSocketNSS::WasNpnNegotiated() const { NOTIMPLEMENTED(); return false; @@ -337,7 +430,8 @@ int SSLServerSocketNSS::InitializeSSLOptions() { int rv; - if (ssl_config_.require_client_cert) { + if (ssl_server_config_.client_cert_type == + SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT) { rv = SSL_OptionSet(nss_fd_, SSL_REQUEST_CERTIFICATE, PR_TRUE); if (rv != SECSuccess) { LogFailedNSSFunction(net_log_, "SSL_OptionSet", @@ -359,15 +453,15 @@ int SSLServerSocketNSS::InitializeSSLOptions() { } SSLVersionRange version_range; - version_range.min = ssl_config_.version_min; - version_range.max = ssl_config_.version_max; + version_range.min = ssl_server_config_.version_min; + version_range.max = ssl_server_config_.version_max; rv = SSL_VersionRangeSet(nss_fd_, &version_range); if (rv != SECSuccess) { LogFailedNSSFunction(net_log_, "SSL_VersionRangeSet", ""); return ERR_NO_SSL_VERSIONS_ENABLED; } - if (ssl_config_.require_ecdhe) { + if (ssl_server_config_.require_ecdhe) { const PRUint16* const ssl_ciphers = SSL_GetImplementedCiphers(); const PRUint16 num_ciphers = SSL_GetNumImplementedCiphers(); @@ -384,8 +478,8 @@ int SSLServerSocketNSS::InitializeSSLOptions() { } for (std::vector<uint16_t>::const_iterator it = - ssl_config_.disabled_cipher_suites.begin(); - it != ssl_config_.disabled_cipher_suites.end(); ++it) { + ssl_server_config_.disabled_cipher_suites.begin(); + it != ssl_server_config_.disabled_cipher_suites.end(); ++it) { // This will fail if the specified cipher is not implemented by NSS, but // the failure is harmless. SSL_CipherPrefSet(nss_fd_, *it, PR_FALSE); @@ -581,8 +675,7 @@ int SSLServerSocketNSS::BufferSend(void) { memcpy(send_buffer->data(), buf1, len1); memcpy(send_buffer->data() + len1, buf2, len2); rv = transport_socket_->Write( - send_buffer.get(), - len, + send_buffer.get(), len, base::Bind(&SSLServerSocketNSS::BufferSendComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING) { @@ -613,8 +706,7 @@ int SSLServerSocketNSS::BufferRecv(void) { } else { recv_buffer_ = new IOBuffer(nb); rv = transport_socket_->Read( - recv_buffer_.get(), - nb, + recv_buffer_.get(), nb, base::Bind(&SSLServerSocketNSS::BufferRecvComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING) { @@ -798,7 +890,7 @@ int SSLServerSocketNSS::DoHandshake() { void SSLServerSocketNSS::DoHandshakeCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv); + base::ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv); } void SSLServerSocketNSS::DoReadCallback(int rv) { @@ -807,7 +899,7 @@ void SSLServerSocketNSS::DoReadCallback(int rv) { user_read_buf_ = NULL; user_read_buf_len_ = 0; - ResetAndReturn(&user_read_callback_).Run(rv); + base::ResetAndReturn(&user_read_callback_).Run(rv); } void SSLServerSocketNSS::DoWriteCallback(int rv) { @@ -816,7 +908,7 @@ void SSLServerSocketNSS::DoWriteCallback(int rv) { user_write_buf_ = NULL; user_write_buf_len_ = 0; - ResetAndReturn(&user_write_callback_).Run(rv); + base::ResetAndReturn(&user_write_callback_).Run(rv); } // static @@ -837,8 +929,7 @@ SECStatus SSLServerSocketNSS::OwnAuthCertHandler(void* arg, // static // NSS calls this when handshake is completed. // After the SSL handshake is finished we need to verify the certificate. -void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, - void* arg) { +void SSLServerSocketNSS::HandshakeCallback(PRFileDesc* socket, void* arg) { // TODO(hclam): Implement. } @@ -853,4 +944,39 @@ int SSLServerSocketNSS::Init() { return OK; } +} // namespace + +scoped_ptr<SSLServerContext> CreateSSLServerContext( + X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config) { + return scoped_ptr<SSLServerContext>( + new SSLServerContextNSS(certificate, key, ssl_server_config)); +} + +SSLServerContextNSS::SSLServerContextNSS( + X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config) + : ssl_server_config_(ssl_server_config), + cert_(certificate), + key_(key.Copy()) { + CHECK(key_); +} + +SSLServerContextNSS::~SSLServerContextNSS() {} + +scoped_ptr<SSLServerSocket> SSLServerContextNSS::CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket) { + DCHECK(g_nss_server_sockets_init) << "EnableSSLServerSockets() has not been" + << " called yet!"; + + return scoped_ptr<SSLServerSocket>(new SSLServerSocketNSS( + std::move(socket), cert_.get(), *key_, ssl_server_config_)); +} + +void EnableSSLServerSockets() { + g_nss_ssl_server_init_singleton.Get(); +} + } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_nss.h b/chromium/net/socket/ssl_server_socket_nss.h index 6bdcf112f76..497d461767a 100644 --- a/chromium/net/socket/ssl_server_socket_nss.h +++ b/chromium/net/socket/ssl_server_socket_nss.h @@ -22,135 +22,25 @@ namespace net { -class SSLServerSocketNSS : public SSLServerSocket { +class SSLServerContextNSS : public SSLServerContext { public: - // See comments on CreateSSLServerSocket for details of how these - // parameters are used. - SSLServerSocketNSS(scoped_ptr<StreamSocket> socket, - scoped_refptr<X509Certificate> certificate, - const crypto::RSAPrivateKey& key, - const SSLServerConfig& ssl_config); - ~SSLServerSocketNSS() override; + SSLServerContextNSS(X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config); + ~SSLServerContextNSS() override; - // SSLServerSocket interface. - int Handshake(const CompletionCallback& callback) override; - - // SSLSocket interface. - int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) override; - int GetTLSUniqueChannelBinding(std::string* out) override; - - // Socket interface (via StreamSocket). - int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) override; - int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) override; - int SetReceiveBufferSize(int32_t size) override; - int SetSendBufferSize(int32_t size) override; - - // StreamSocket implementation. - int Connect(const CompletionCallback& callback) override; - void Disconnect() override; - bool IsConnected() const override; - bool IsConnectedAndIdle() const override; - int GetPeerAddress(IPEndPoint* address) const override; - int GetLocalAddress(IPEndPoint* address) const override; - const BoundNetLog& NetLog() const override; - void SetSubresourceSpeculation() override; - void SetOmniboxSpeculation() override; - bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; - bool WasNpnNegotiated() const override; - NextProto GetNegotiatedProtocol() const override; - bool GetSSLInfo(SSLInfo* ssl_info) override; - void GetConnectionAttempts(ConnectionAttempts* out) const override; - void ClearConnectionAttempts() override {} - void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} - int64_t GetTotalReceivedBytes() const override; + scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket) override; private: - enum State { - STATE_NONE, - STATE_HANDSHAKE, - }; - - int InitializeSSLOptions(); - - void OnSendComplete(int result); - void OnRecvComplete(int result); - void OnHandshakeIOComplete(int result); - - int BufferSend(); - void BufferSendComplete(int result); - int BufferRecv(); - void BufferRecvComplete(int result); - bool DoTransportIO(); - int DoPayloadRead(); - int DoPayloadWrite(); - - int DoHandshakeLoop(int last_io_result); - int DoReadLoop(int result); - int DoWriteLoop(int result); - int DoHandshake(); - void DoHandshakeCallback(int result); - void DoReadCallback(int result); - void DoWriteCallback(int result); - - static SECStatus OwnAuthCertHandler(void* arg, - PRFileDesc* socket, - PRBool checksig, - PRBool is_server); - static void HandshakeCallback(PRFileDesc* socket, void* arg); - - int Init(); - - // Members used to send and receive buffer. - bool transport_send_busy_; - bool transport_recv_busy_; - - scoped_refptr<IOBuffer> recv_buffer_; - - BoundNetLog net_log_; - - CompletionCallback user_handshake_callback_; - CompletionCallback user_read_callback_; - CompletionCallback user_write_callback_; - - // Used by Read function. - scoped_refptr<IOBuffer> user_read_buf_; - int user_read_buf_len_; - - // Used by Write function. - scoped_refptr<IOBuffer> user_write_buf_; - int user_write_buf_len_; - - // The NSS SSL state machine - PRFileDesc* nss_fd_; - - // Buffers for the network end of the SSL state machine - memio_Private* nss_bufs_; - - // StreamSocket for sending and receiving data. - scoped_ptr<StreamSocket> transport_socket_; - // Options for the SSL socket. - SSLServerConfig ssl_config_; + SSLServerConfig ssl_server_config_; // Certificate for the server. scoped_refptr<X509Certificate> cert_; // Private key used by the server. scoped_ptr<crypto::RSAPrivateKey> key_; - - State next_handshake_state_; - bool completed_handshake_; - - DISALLOW_COPY_AND_ASSIGN(SSLServerSocketNSS); }; } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_openssl.cc b/chromium/net/socket/ssl_server_socket_openssl.cc index c3869cd865d..74f223131d8 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.cc +++ b/chromium/net/socket/ssl_server_socket_openssl.cc @@ -15,48 +15,176 @@ #include "crypto/rsa_private_key.h" #include "crypto/scoped_openssl_types.h" #include "net/base/net_errors.h" +#include "net/cert/cert_verify_result.h" +#include "net/cert/client_cert_verifier.h" +#include "net/cert/x509_util_openssl.h" #include "net/ssl/openssl_ssl_util.h" -#include "net/ssl/scoped_openssl_types.h" +#include "net/ssl/ssl_connection_status_flags.h" +#include "net/ssl/ssl_info.h" #define GotoState(s) next_handshake_state_ = s namespace net { -void EnableSSLServerSockets() { - // No-op because CreateSSLServerSocket() calls crypto::EnsureOpenSSLInit(). -} +namespace { + +// Creates an X509Certificate out of the concatenation of |cert|, if non-null, +// with |chain|. +scoped_refptr<X509Certificate> CreateX509Certificate(X509* cert, + STACK_OF(X509) * chain) { + std::vector<base::StringPiece> der_chain; + base::StringPiece der_cert; + scoped_refptr<X509Certificate> client_cert; + if (cert) { + if (!x509_util::GetDER(cert, &der_cert)) + return nullptr; + der_chain.push_back(der_cert); + } -scoped_ptr<SSLServerSocket> CreateSSLServerSocket( - scoped_ptr<StreamSocket> socket, - X509Certificate* certificate, - const crypto::RSAPrivateKey& key, - const SSLServerConfig& ssl_config) { - crypto::EnsureOpenSSLInit(); - return scoped_ptr<SSLServerSocket>(new SSLServerSocketOpenSSL( - std::move(socket), certificate, key, ssl_config)); -} + for (size_t i = 0; i < sk_X509_num(chain); ++i) { + X509* x = sk_X509_value(chain, i); + if (!x509_util::GetDER(x, &der_cert)) + return nullptr; + der_chain.push_back(der_cert); + } + + return X509Certificate::CreateFromDERCertChain(der_chain); +} + +class SSLServerSocketOpenSSL : public SSLServerSocket { + public: + // See comments on CreateSSLServerSocket for details of how these + // parameters are used. + SSLServerSocketOpenSSL(scoped_ptr<StreamSocket> socket, SSL* ssl); + ~SSLServerSocketOpenSSL() override; + + // SSLServerSocket interface. + int Handshake(const CompletionCallback& callback) override; + + // SSLSocket interface. + int ExportKeyingMaterial(const base::StringPiece& label, + bool has_context, + const base::StringPiece& context, + unsigned char* out, + unsigned int outlen) override; + + // Socket interface (via StreamSocket). + int Read(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int Write(IOBuffer* buf, + int buf_len, + const CompletionCallback& callback) override; + int SetReceiveBufferSize(int32_t size) override; + int SetSendBufferSize(int32_t size) override; + + // StreamSocket implementation. + int Connect(const CompletionCallback& callback) override; + void Disconnect() override; + bool IsConnected() const override; + bool IsConnectedAndIdle() const override; + int GetPeerAddress(IPEndPoint* address) const override; + int GetLocalAddress(IPEndPoint* address) const override; + const BoundNetLog& NetLog() const override; + void SetSubresourceSpeculation() override; + void SetOmniboxSpeculation() override; + bool WasEverUsed() const override; + bool WasNpnNegotiated() const override; + NextProto GetNegotiatedProtocol() const override; + bool GetSSLInfo(SSLInfo* ssl_info) override; + void GetConnectionAttempts(ConnectionAttempts* out) const override; + void ClearConnectionAttempts() override {} + void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} + int64_t GetTotalReceivedBytes() const override; + static int CertVerifyCallback(X509_STORE_CTX* store_ctx, void* arg); + + private: + enum State { + STATE_NONE, + STATE_HANDSHAKE, + }; + + void OnSendComplete(int result); + void OnRecvComplete(int result); + void OnHandshakeIOComplete(int result); + + int BufferSend(); + void BufferSendComplete(int result); + void TransportWriteComplete(int result); + int BufferRecv(); + void BufferRecvComplete(int result); + int TransportReadComplete(int result); + bool DoTransportIO(); + int DoPayloadRead(); + int DoPayloadWrite(); + + int DoHandshakeLoop(int last_io_result); + int DoReadLoop(int result); + int DoWriteLoop(int result); + int DoHandshake(); + void DoHandshakeCallback(int result); + void DoReadCallback(int result); + void DoWriteCallback(int result); + + int Init(); + void ExtractClientCert(); + + // Members used to send and receive buffer. + bool transport_send_busy_; + bool transport_recv_busy_; + bool transport_recv_eof_; + + scoped_refptr<DrainableIOBuffer> send_buffer_; + scoped_refptr<IOBuffer> recv_buffer_; + + BoundNetLog net_log_; + + CompletionCallback user_handshake_callback_; + CompletionCallback user_read_callback_; + CompletionCallback user_write_callback_; + + // Used by Read function. + scoped_refptr<IOBuffer> user_read_buf_; + int user_read_buf_len_; + + // Used by Write function. + scoped_refptr<IOBuffer> user_write_buf_; + int user_write_buf_len_; + + // Used by TransportWriteComplete() and TransportReadComplete() to signify an + // error writing to the transport socket. A value of OK indicates no error. + int transport_write_error_; + + // OpenSSL stuff + SSL* ssl_; + BIO* transport_bio_; + + // StreamSocket for sending and receiving data. + scoped_ptr<StreamSocket> transport_socket_; + + // Certificate for the client. + scoped_refptr<X509Certificate> client_cert_; + + State next_handshake_state_; + bool completed_handshake_; + + DISALLOW_COPY_AND_ASSIGN(SSLServerSocketOpenSSL); +}; SSLServerSocketOpenSSL::SSLServerSocketOpenSSL( scoped_ptr<StreamSocket> transport_socket, - scoped_refptr<X509Certificate> certificate, - const crypto::RSAPrivateKey& key, - const SSLServerConfig& ssl_config) + SSL* ssl) : transport_send_busy_(false), transport_recv_busy_(false), transport_recv_eof_(false), user_read_buf_len_(0), user_write_buf_len_(0), transport_write_error_(OK), - ssl_(NULL), + ssl_(ssl), transport_bio_(NULL), transport_socket_(std::move(transport_socket)), - ssl_config_(ssl_config), - cert_(certificate), - key_(key.Copy()), next_handshake_state_(STATE_NONE), - completed_handshake_(false) { - CHECK(key_); -} + completed_handshake_(false) {} SSLServerSocketOpenSSL::~SSLServerSocketOpenSSL() { if (ssl_) { @@ -123,12 +251,8 @@ int SSLServerSocketOpenSSL::ExportKeyingMaterial( return OK; } -int SSLServerSocketOpenSSL::GetTLSUniqueChannelBinding(std::string* out) { - NOTIMPLEMENTED(); - return ERR_NOT_IMPLEMENTED; -} - -int SSLServerSocketOpenSSL::Read(IOBuffer* buf, int buf_len, +int SSLServerSocketOpenSSL::Read(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) { DCHECK(user_read_callback_.is_null()); DCHECK(user_handshake_callback_.is_null()); @@ -152,7 +276,8 @@ int SSLServerSocketOpenSSL::Read(IOBuffer* buf, int buf_len, return rv; } -int SSLServerSocketOpenSSL::Write(IOBuffer* buf, int buf_len, +int SSLServerSocketOpenSSL::Write(IOBuffer* buf, + int buf_len, const CompletionCallback& callback) { DCHECK(user_write_callback_.is_null()); DCHECK(!user_write_buf_); @@ -227,10 +352,6 @@ bool SSLServerSocketOpenSSL::WasEverUsed() const { return transport_socket_->WasEverUsed(); } -bool SSLServerSocketOpenSSL::UsingTCPFastOpen() const { - return transport_socket_->UsingTCPFastOpen(); -} - bool SSLServerSocketOpenSSL::WasNpnNegotiated() const { NOTIMPLEMENTED(); return false; @@ -242,8 +363,30 @@ NextProto SSLServerSocketOpenSSL::GetNegotiatedProtocol() const { } bool SSLServerSocketOpenSSL::GetSSLInfo(SSLInfo* ssl_info) { - NOTIMPLEMENTED(); - return false; + ssl_info->Reset(); + if (!completed_handshake_) + return false; + + ssl_info->cert = client_cert_; + + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl_); + CHECK(cipher); + ssl_info->security_bits = SSL_CIPHER_get_bits(cipher, NULL); + + SSLConnectionStatusSetCipherSuite( + static_cast<uint16_t>(SSL_CIPHER_get_id(cipher)), + &ssl_info->connection_status); + SSLConnectionStatusSetVersion(GetNetSSLVersion(ssl_), + &ssl_info->connection_status); + + if (!SSL_get_secure_renegotiation_support(ssl_)) + ssl_info->connection_status |= SSL_CONNECTION_NO_RENEGOTIATION_EXTENSION; + + ssl_info->handshake_type = SSL_session_reused(ssl_) + ? SSLInfo::HANDSHAKE_RESUME + : SSLInfo::HANDSHAKE_FULL; + + return true; } void SSLServerSocketOpenSSL::GetConnectionAttempts( @@ -323,8 +466,7 @@ int SSLServerSocketOpenSSL::BufferSend() { } int rv = transport_socket_->Write( - send_buffer_.get(), - send_buffer_->BytesRemaining(), + send_buffer_.get(), send_buffer_->BytesRemaining(), base::Bind(&SSLServerSocketOpenSSL::BufferSendComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING) { @@ -396,8 +538,7 @@ int SSLServerSocketOpenSSL::BufferRecv() { recv_buffer_ = new IOBuffer(max_write); int rv = transport_socket_->Read( - recv_buffer_.get(), - max_write, + recv_buffer_.get(), max_write, base::Bind(&SSLServerSocketOpenSSL::BufferRecvComplete, base::Unretained(this))); if (rv == ERR_IO_PENDING) { @@ -566,11 +707,28 @@ int SSLServerSocketOpenSSL::DoHandshake() { if (rv == 1) { completed_handshake_ = true; + // The results of SSL_get_peer_certificate() must be explicitly freed. + ScopedX509 cert(SSL_get_peer_certificate(ssl_)); + if (cert) { + // The caller does not take ownership of SSL_get_peer_cert_chain's + // results. + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); + client_cert_ = CreateX509Certificate(cert.get(), chain); + if (!client_cert_.get()) + return ERR_SSL_CLIENT_AUTH_CERT_BAD_FORMAT; + } } else { int ssl_error = SSL_get_error(ssl_, rv); OpenSSLErrorInfo error_info; net_error = MapOpenSSLErrorWithDetails(ssl_error, err_tracer, &error_info); + // This hack is necessary because the mapping of SSL error codes to + // net_errors assumes (correctly for client sockets, but erroneously for + // server sockets) that peer cert verification failure can only occur if + // the cert changed during a renego. crbug.com/570351 + if (net_error == ERR_SSL_SERVER_CERT_CHANGED) + net_error = ERR_BAD_SSL_CLIENT_AUTH_CERT; + // If not done, stay in this state if (net_error == ERR_IO_PENDING) { GotoState(STATE_HANDSHAKE); @@ -588,7 +746,7 @@ int SSLServerSocketOpenSSL::DoHandshake() { void SSLServerSocketOpenSSL::DoHandshakeCallback(int rv) { DCHECK_NE(rv, ERR_IO_PENDING); - ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv); + base::ResetAndReturn(&user_handshake_callback_).Run(rv > OK ? OK : rv); } void SSLServerSocketOpenSSL::DoReadCallback(int rv) { @@ -597,7 +755,7 @@ void SSLServerSocketOpenSSL::DoReadCallback(int rv) { user_read_buf_ = NULL; user_read_buf_len_ = 0; - ResetAndReturn(&user_read_callback_).Run(rv); + base::ResetAndReturn(&user_read_callback_).Run(rv); } void SSLServerSocketOpenSSL::DoWriteCallback(int rv) { @@ -606,21 +764,14 @@ void SSLServerSocketOpenSSL::DoWriteCallback(int rv) { user_write_buf_ = NULL; user_write_buf_len_ = 0; - ResetAndReturn(&user_write_callback_).Run(rv); + base::ResetAndReturn(&user_write_callback_).Run(rv); } int SSLServerSocketOpenSSL::Init() { - DCHECK(!ssl_); DCHECK(!transport_bio_); crypto::OpenSSLErrStackTracer err_tracer(FROM_HERE); - ScopedSSL_CTX ssl_ctx(SSL_CTX_new(SSLv23_server_method())); - - if (ssl_config_.require_client_cert) - SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_PEER, NULL); - - ssl_ = SSL_new(ssl_ctx.get()); if (!ssl_) return ERR_UNEXPECTED; @@ -633,59 +784,122 @@ int SSLServerSocketOpenSSL::Init() { SSL_set_bio(ssl_, ssl_bio, ssl_bio); + return OK; +} + +// static +int SSLServerSocketOpenSSL::CertVerifyCallback(X509_STORE_CTX* store_ctx, + void* arg) { + ClientCertVerifier* verifier = reinterpret_cast<ClientCertVerifier*>(arg); + // If a verifier was not supplied, all certificates are accepted. + if (!verifier) + return 1; + STACK_OF(X509)* chain = store_ctx->untrusted; + scoped_refptr<X509Certificate> client_cert( + CreateX509Certificate(nullptr, chain)); + if (!client_cert.get()) { + X509_STORE_CTX_set_error(store_ctx, X509_V_ERR_CERT_REJECTED); + return 0; + } + // Asynchronous completion of Verify is currently not supported. + // http://crbug.com/347402 + // The API for Verify supports the parts needed for async completion + // but is currently expected to complete synchronously. + scoped_ptr<ClientCertVerifier::Request> ignore_async; + int res = + verifier->Verify(client_cert.get(), CompletionCallback(), &ignore_async); + DCHECK_NE(res, ERR_IO_PENDING); + + if (res != OK) { + X509_STORE_CTX_set_error(store_ctx, X509_V_ERR_CERT_REJECTED); + return 0; + } + return 1; +} + +} // namespace + +scoped_ptr<SSLServerContext> CreateSSLServerContext( + X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config) { + return scoped_ptr<SSLServerContext>( + new SSLServerContextOpenSSL(certificate, key, ssl_server_config)); +} + +SSLServerContextOpenSSL::SSLServerContextOpenSSL( + X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config) + : ssl_server_config_(ssl_server_config), + cert_(certificate), + key_(key.Copy()) { + CHECK(key_); + crypto::EnsureOpenSSLInit(); + ssl_ctx_.reset(SSL_CTX_new(TLS_method())); + SSL_CTX_set_session_cache_mode(ssl_ctx_.get(), SSL_SESS_CACHE_SERVER); + uint8_t session_ctx_id = 0; + SSL_CTX_set_session_id_context(ssl_ctx_.get(), &session_ctx_id, + sizeof(session_ctx_id)); + + int verify_mode = 0; + switch (ssl_server_config_.client_cert_type) { + case SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT: + verify_mode |= SSL_VERIFY_FAIL_IF_NO_PEER_CERT; + // Fall-through + case SSLServerConfig::ClientCertType::OPTIONAL_CLIENT_CERT: + verify_mode |= SSL_VERIFY_PEER; + SSL_CTX_set_verify(ssl_ctx_.get(), verify_mode, nullptr); + SSL_CTX_set_cert_verify_callback( + ssl_ctx_.get(), SSLServerSocketOpenSSL::CertVerifyCallback, + ssl_server_config_.client_cert_verifier); + break; + case SSLServerConfig::ClientCertType::NO_CLIENT_CERT: + break; + } + // Set certificate and private key. DCHECK(cert_->os_cert_handle()); #if defined(USE_OPENSSL_CERTS) - if (SSL_use_certificate(ssl_, cert_->os_cert_handle()) != 1) { - LOG(ERROR) << "Cannot set certificate."; - return ERR_UNEXPECTED; - } + CHECK(SSL_CTX_use_certificate(ssl_ctx_.get(), cert_->os_cert_handle())); #else // Convert OSCertHandle to X509 structure. std::string der_string; - if (!X509Certificate::GetDEREncoded(cert_->os_cert_handle(), &der_string)) - return ERR_UNEXPECTED; + CHECK(X509Certificate::GetDEREncoded(cert_->os_cert_handle(), &der_string)); const unsigned char* der_string_array = reinterpret_cast<const unsigned char*>(der_string.data()); ScopedX509 x509(d2i_X509(NULL, &der_string_array, der_string.length())); - if (!x509) - return ERR_UNEXPECTED; + CHECK(x509); - // On success, SSL_use_certificate acquires a reference to |x509|. - if (SSL_use_certificate(ssl_, x509.get()) != 1) { - LOG(ERROR) << "Cannot set certificate."; - return ERR_UNEXPECTED; - } + // On success, SSL_CTX_use_certificate acquires a reference to |x509|. + CHECK(SSL_CTX_use_certificate(ssl_ctx_.get(), x509.get())); #endif // USE_OPENSSL_CERTS DCHECK(key_->key()); - if (SSL_use_PrivateKey(ssl_, key_->key()) != 1) { - LOG(ERROR) << "Cannot set private key."; - return ERR_UNEXPECTED; - } + CHECK(SSL_CTX_use_PrivateKey(ssl_ctx_.get(), key_->key())); - DCHECK_LT(SSL3_VERSION, ssl_config_.version_min); - DCHECK_LT(SSL3_VERSION, ssl_config_.version_max); - SSL_set_min_version(ssl_, ssl_config_.version_min); - SSL_set_max_version(ssl_, ssl_config_.version_max); + DCHECK_LT(SSL3_VERSION, ssl_server_config_.version_min); + DCHECK_LT(SSL3_VERSION, ssl_server_config_.version_max); + SSL_CTX_set_min_version(ssl_ctx_.get(), ssl_server_config_.version_min); + SSL_CTX_set_max_version(ssl_ctx_.get(), ssl_server_config_.version_max); // OpenSSL defaults some options to on, others to off. To avoid ambiguity, // set everything we care about to an absolute value. SslSetClearMask options; options.ConfigureFlag(SSL_OP_NO_COMPRESSION, true); - SSL_set_options(ssl_, options.set_mask); - SSL_clear_options(ssl_, options.clear_mask); + SSL_CTX_set_options(ssl_ctx_.get(), options.set_mask); + SSL_CTX_clear_options(ssl_ctx_.get(), options.clear_mask); // Same as above, this time for the SSL mode. SslSetClearMask mode; mode.ConfigureFlag(SSL_MODE_RELEASE_BUFFERS, true); - SSL_set_mode(ssl_, mode.set_mask); - SSL_clear_mode(ssl_, mode.clear_mask); + SSL_CTX_set_mode(ssl_ctx_.get(), mode.set_mask); + SSL_CTX_clear_mode(ssl_ctx_.get(), mode.clear_mask); // See SSLServerConfig::disabled_cipher_suites for description of the suites // disabled by default. Note that !SHA256 and !SHA384 only remove HMAC-SHA256 @@ -693,11 +907,11 @@ int SSLServerSocketOpenSSL::Init() { // as the handshake hash. std::string command("DEFAULT:!SHA256:!SHA384:!AESGCM+AES256:!aPSK"); - if (ssl_config_.require_ecdhe) + if (ssl_server_config_.require_ecdhe) command.append(":!kRSA:!kDHE"); // Remove any disabled ciphers. - for (uint16_t id : ssl_config_.disabled_cipher_suites) { + for (uint16_t id : ssl_server_config_.disabled_cipher_suites) { const SSL_CIPHER* cipher = SSL_get_cipher_by_value(id); if (cipher) { command.append(":!"); @@ -705,14 +919,39 @@ int SSLServerSocketOpenSSL::Init() { } } - int rv = SSL_set_cipher_list(ssl_, command.c_str()); + int rv = SSL_CTX_set_cipher_list(ssl_ctx_.get(), command.c_str()); // If this fails (rv = 0) it means there are no ciphers enabled on this SSL. // This will almost certainly result in the socket failing to complete the // handshake at which point the appropriate error is bubbled up to the client. LOG_IF(WARNING, rv != 1) << "SSL_set_cipher_list('" << command << "') returned " << rv; - return OK; + if (ssl_server_config_.client_cert_type != + SSLServerConfig::ClientCertType::NO_CLIENT_CERT && + !ssl_server_config_.cert_authorities_.empty()) { + ScopedX509NameStack stack(sk_X509_NAME_new_null()); + for (const auto& authority : ssl_server_config_.cert_authorities_) { + const uint8_t* name = reinterpret_cast<const uint8_t*>(authority.c_str()); + const uint8_t* name_start = name; + ScopedX509_NAME subj(d2i_X509_NAME(nullptr, &name, authority.length())); + CHECK(subj && name == name_start + authority.length()); + sk_X509_NAME_push(stack.get(), subj.release()); + } + SSL_CTX_set_client_CA_list(ssl_ctx_.get(), stack.release()); + } +} + +SSLServerContextOpenSSL::~SSLServerContextOpenSSL() {} + +scoped_ptr<SSLServerSocket> SSLServerContextOpenSSL::CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket) { + SSL* ssl = SSL_new(ssl_ctx_.get()); + return scoped_ptr<SSLServerSocket>( + new SSLServerSocketOpenSSL(std::move(socket), ssl)); +} + +void EnableSSLServerSockets() { + // No-op because CreateSSLServerSocket() calls crypto::EnsureOpenSSLInit(). } } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_openssl.h b/chromium/net/socket/ssl_server_socket_openssl.h index fd7824970fc..3a9d9c85fc1 100644 --- a/chromium/net/socket/ssl_server_socket_openssl.h +++ b/chromium/net/socket/ssl_server_socket_openssl.h @@ -13,6 +13,7 @@ #include "net/base/io_buffer.h" #include "net/log/net_log.h" #include "net/socket/ssl_server_socket.h" +#include "net/ssl/scoped_openssl_types.h" #include "net/ssl/ssl_server_config.h" // Avoid including misc OpenSSL headers, i.e.: @@ -20,138 +21,33 @@ typedef struct bio_st BIO; // <openssl/ssl.h> typedef struct ssl_st SSL; +typedef struct x509_store_ctx_st X509_STORE_CTX; namespace net { class SSLInfo; -class SSLServerSocketOpenSSL : public SSLServerSocket { +class SSLServerContextOpenSSL : public SSLServerContext { public: - // See comments on CreateSSLServerSocket for details of how these - // parameters are used. - SSLServerSocketOpenSSL(scoped_ptr<StreamSocket> socket, - scoped_refptr<X509Certificate> certificate, - const crypto::RSAPrivateKey& key, - const SSLServerConfig& ssl_config); - ~SSLServerSocketOpenSSL() override; + SSLServerContextOpenSSL(X509Certificate* certificate, + const crypto::RSAPrivateKey& key, + const SSLServerConfig& ssl_server_config); + ~SSLServerContextOpenSSL() override; - // SSLServerSocket interface. - int Handshake(const CompletionCallback& callback) override; - - // SSLSocket interface. - int ExportKeyingMaterial(const base::StringPiece& label, - bool has_context, - const base::StringPiece& context, - unsigned char* out, - unsigned int outlen) override; - int GetTLSUniqueChannelBinding(std::string* out) override; - - // Socket interface (via StreamSocket). - int Read(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) override; - int Write(IOBuffer* buf, - int buf_len, - const CompletionCallback& callback) override; - int SetReceiveBufferSize(int32_t size) override; - int SetSendBufferSize(int32_t size) override; - - // StreamSocket implementation. - int Connect(const CompletionCallback& callback) override; - void Disconnect() override; - bool IsConnected() const override; - bool IsConnectedAndIdle() const override; - int GetPeerAddress(IPEndPoint* address) const override; - int GetLocalAddress(IPEndPoint* address) const override; - const BoundNetLog& NetLog() const override; - void SetSubresourceSpeculation() override; - void SetOmniboxSpeculation() override; - bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; - bool WasNpnNegotiated() const override; - NextProto GetNegotiatedProtocol() const override; - bool GetSSLInfo(SSLInfo* ssl_info) override; - void GetConnectionAttempts(ConnectionAttempts* out) const override; - void ClearConnectionAttempts() override {} - void AddConnectionAttempts(const ConnectionAttempts& attempts) override {} - int64_t GetTotalReceivedBytes() const override; + scoped_ptr<SSLServerSocket> CreateSSLServerSocket( + scoped_ptr<StreamSocket> socket) override; private: - enum State { - STATE_NONE, - STATE_HANDSHAKE, - }; - - void OnSendComplete(int result); - void OnRecvComplete(int result); - void OnHandshakeIOComplete(int result); - - int BufferSend(); - void BufferSendComplete(int result); - void TransportWriteComplete(int result); - int BufferRecv(); - void BufferRecvComplete(int result); - int TransportReadComplete(int result); - bool DoTransportIO(); - int DoPayloadRead(); - int DoPayloadWrite(); - - int DoHandshakeLoop(int last_io_result); - int DoReadLoop(int result); - int DoWriteLoop(int result); - int DoHandshake(); - void DoHandshakeCallback(int result); - void DoReadCallback(int result); - void DoWriteCallback(int result); - - int Init(); - - // Members used to send and receive buffer. - bool transport_send_busy_; - bool transport_recv_busy_; - bool transport_recv_eof_; - - scoped_refptr<DrainableIOBuffer> send_buffer_; - scoped_refptr<IOBuffer> recv_buffer_; - - BoundNetLog net_log_; - - CompletionCallback user_handshake_callback_; - CompletionCallback user_read_callback_; - CompletionCallback user_write_callback_; - - // Used by Read function. - scoped_refptr<IOBuffer> user_read_buf_; - int user_read_buf_len_; - - // Used by Write function. - scoped_refptr<IOBuffer> user_write_buf_; - int user_write_buf_len_; - - // Used by TransportWriteComplete() and TransportReadComplete() to signify an - // error writing to the transport socket. A value of OK indicates no error. - int transport_write_error_; - - // OpenSSL stuff - SSL* ssl_; - BIO* transport_bio_; - - // StreamSocket for sending and receiving data. - scoped_ptr<StreamSocket> transport_socket_; + ScopedSSL_CTX ssl_ctx_; // Options for the SSL socket. - SSLServerConfig ssl_config_; + SSLServerConfig ssl_server_config_; // Certificate for the server. scoped_refptr<X509Certificate> cert_; // Private key used by the server. scoped_ptr<crypto::RSAPrivateKey> key_; - - State next_handshake_state_; - bool completed_handshake_; - - DISALLOW_COPY_AND_ASSIGN(SSLServerSocketOpenSSL); }; } // namespace net diff --git a/chromium/net/socket/ssl_server_socket_unittest.cc b/chromium/net/socket/ssl_server_socket_unittest.cc index ac2d44ec413..7327586d37c 100644 --- a/chromium/net/socket/ssl_server_socket_unittest.cc +++ b/chromium/net/socket/ssl_server_socket_unittest.cc @@ -20,6 +20,7 @@ #include <queue> #include <utility> +#include "base/callback_helpers.h" #include "base/compiler_specific.h" #include "base/files/file_path.h" #include "base/files/file_util.h" @@ -29,17 +30,22 @@ #include "base/message_loop/message_loop.h" #include "base/single_thread_task_runner.h" #include "base/thread_task_runner_handle.h" +#include "build/build_config.h" #include "crypto/nss_util.h" #include "crypto/rsa_private_key.h" +#include "crypto/scoped_openssl_types.h" +#include "crypto/signature_creator.h" #include "net/base/address_list.h" #include "net/base/completion_callback.h" #include "net/base/host_port_pair.h" #include "net/base/io_buffer.h" +#include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/test_data_directory.h" #include "net/cert/cert_status_flags.h" #include "net/cert/mock_cert_verifier.h" +#include "net/cert/mock_client_cert_verifier.h" #include "net/cert/x509_certificate.h" #include "net/http/transport_security_state.h" #include "net/log/net_log.h" @@ -47,18 +53,34 @@ #include "net/socket/socket_test_util.h" #include "net/socket/ssl_client_socket.h" #include "net/socket/stream_socket.h" +#include "net/ssl/scoped_openssl_types.h" +#include "net/ssl/ssl_cert_request_info.h" #include "net/ssl/ssl_cipher_suite_names.h" #include "net/ssl/ssl_connection_status_flags.h" #include "net/ssl/ssl_info.h" +#include "net/ssl/ssl_private_key.h" #include "net/ssl/ssl_server_config.h" +#include "net/ssl/test_ssl_private_key.h" #include "net/test/cert_test_util.h" #include "testing/gtest/include/gtest/gtest.h" #include "testing/platform_test.h" +#if defined(USE_OPENSSL) +#include <openssl/evp.h> +#include <openssl/ssl.h> +#include <openssl/x509.h> +#endif + namespace net { namespace { +const char kClientCertFileName[] = "client_1.pem"; +const char kClientPrivateKeyFileName[] = "client_1.pk8"; +const char kWrongClientCertFileName[] = "client_2.pem"; +const char kWrongClientPrivateKeyFileName[] = "client_2.pk8"; +const char kClientCertCAFileName[] = "client_1_ca.pem"; + class FakeDataChannel { public: FakeDataChannel() @@ -110,11 +132,24 @@ class FakeDataChannel { // asynchronously, which is necessary to reproduce bug 127822. void Close() { closed_ = true; + if (!read_callback_.is_null()) { + base::ThreadTaskRunnerHandle::Get()->PostTask( + FROM_HERE, base::Bind(&FakeDataChannel::DoReadCallback, + weak_factory_.GetWeakPtr())); + } } private: void DoReadCallback() { - if (read_callback_.is_null() || data_.empty()) + if (read_callback_.is_null()) + return; + + if (closed_) { + base::ResetAndReturn(&read_callback_).Run(ERR_CONNECTION_CLOSED); + return; + } + + if (data_.empty()) return; int copied = PropagateData(read_buf_, read_buf_len_); @@ -170,9 +205,7 @@ class FakeSocket : public StreamSocket { public: FakeSocket(FakeDataChannel* incoming_channel, FakeDataChannel* outgoing_channel) - : incoming_(incoming_channel), - outgoing_(outgoing_channel) { - } + : incoming_(incoming_channel), outgoing_(outgoing_channel) {} ~FakeSocket() override {} @@ -208,14 +241,12 @@ class FakeSocket : public StreamSocket { bool IsConnectedAndIdle() const override { return true; } int GetPeerAddress(IPEndPoint* address) const override { - IPAddressNumber ip_address(kIPv4AddressSize); - *address = IPEndPoint(ip_address, 0 /*port*/); + *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/); return OK; } int GetLocalAddress(IPEndPoint* address) const override { - IPAddressNumber ip_address(4); - *address = IPEndPoint(ip_address, 0); + *address = IPEndPoint(IPAddress::IPv4AllZeros(), 0 /*port*/); return OK; } @@ -226,8 +257,6 @@ class FakeSocket : public StreamSocket { bool WasEverUsed() const override { return true; } - bool UsingTCPFastOpen() const override { return false; } - bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } @@ -302,37 +331,20 @@ class SSLServerSocketTest : public PlatformTest { SSLServerSocketTest() : socket_factory_(ClientSocketFactory::GetDefaultFactory()), cert_verifier_(new MockCertVerifier()), - transport_security_state_(new TransportSecurityState) { - cert_verifier_->set_default_result(CERT_STATUS_AUTHORITY_INVALID); - } + client_cert_verifier_(new MockClientCertVerifier()), + transport_security_state_(new TransportSecurityState) {} - protected: - void Initialize() { - scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); - client_connection->SetSocket( - scoped_ptr<StreamSocket>(new FakeSocket(&channel_1_, &channel_2_))); - scoped_ptr<StreamSocket> server_socket( - new FakeSocket(&channel_2_, &channel_1_)); - - base::FilePath certs_dir(GetTestCertsDirectory()); + void SetUp() override { + PlatformTest::SetUp(); - base::FilePath cert_path = certs_dir.AppendASCII("unittest.selfsigned.der"); - std::string cert_der; - ASSERT_TRUE(base::ReadFileToString(cert_path, &cert_der)); - - scoped_refptr<X509Certificate> cert = - X509Certificate::CreateFromBytes(cert_der.data(), cert_der.size()); - - base::FilePath key_path = certs_dir.AppendASCII("unittest.key.bin"); - std::string key_string; - ASSERT_TRUE(base::ReadFileToString(key_path, &key_string)); - std::vector<uint8_t> key_vector( - reinterpret_cast<const uint8_t*>(key_string.data()), - reinterpret_cast<const uint8_t*>(key_string.data() + - key_string.length())); + cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID); + client_cert_verifier_->set_default_result(ERR_CERT_AUTHORITY_INVALID); - scoped_ptr<crypto::RSAPrivateKey> private_key( - crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); + server_cert_ = + ImportCertFromFile(GetTestCertsDirectory(), "unittest.selfsigned.der"); + ASSERT_TRUE(server_cert_); + server_private_key_ = ReadTestKey("unittest.key.bin"); + ASSERT_TRUE(server_private_key_); client_ssl_config_.false_start_enabled = false; client_ssl_config_.channel_id_enabled = false; @@ -340,59 +352,152 @@ class SSLServerSocketTest : public PlatformTest { // Certificate provided by the host doesn't need authority. SSLConfig::CertAndStatus cert_and_status; cert_and_status.cert_status = CERT_STATUS_AUTHORITY_INVALID; - cert_and_status.der_cert = cert_der; + std::string server_cert_der; + ASSERT_TRUE(X509Certificate::GetDEREncoded(server_cert_->os_cert_handle(), + &server_cert_der)); + cert_and_status.der_cert = server_cert_der; client_ssl_config_.allowed_bad_certs.push_back(cert_and_status); + } + + protected: + void CreateContext() { + client_socket_.reset(); + server_socket_.reset(); + channel_1_.reset(); + channel_2_.reset(); + server_context_.reset(); + server_context_ = CreateSSLServerContext( + server_cert_.get(), *server_private_key_, server_ssl_config_); + } + + void CreateSockets() { + client_socket_.reset(); + server_socket_.reset(); + channel_1_.reset(new FakeDataChannel()); + channel_2_.reset(new FakeDataChannel()); + scoped_ptr<ClientSocketHandle> client_connection(new ClientSocketHandle); + client_connection->SetSocket(scoped_ptr<StreamSocket>( + new FakeSocket(channel_1_.get(), channel_2_.get()))); + scoped_ptr<StreamSocket> server_socket( + new FakeSocket(channel_2_.get(), channel_1_.get())); HostPortPair host_and_pair("unittest", 0); SSLClientSocketContext context; context.cert_verifier = cert_verifier_.get(); context.transport_security_state = transport_security_state_.get(); + client_socket_ = socket_factory_->CreateSSLClientSocket( std::move(client_connection), host_and_pair, client_ssl_config_, context); - server_socket_ = CreateSSLServerSocket(std::move(server_socket), cert.get(), - *private_key, server_ssl_config_); + ASSERT_TRUE(client_socket_); + + server_socket_ = + server_context_->CreateSSLServerSocket(std::move(server_socket)); + ASSERT_TRUE(server_socket_); } - FakeDataChannel channel_1_; - FakeDataChannel channel_2_; +#if defined(USE_OPENSSL) + void ConfigureClientCertsForClient(const char* cert_file_name, + const char* private_key_file_name) { + client_ssl_config_.send_client_cert = true; + client_ssl_config_.client_cert = + ImportCertFromFile(GetTestCertsDirectory(), cert_file_name); + ASSERT_TRUE(client_ssl_config_.client_cert); + + scoped_ptr<crypto::RSAPrivateKey> key = ReadTestKey(private_key_file_name); + ASSERT_TRUE(key); + + client_ssl_config_.client_private_key = WrapOpenSSLPrivateKey( + crypto::ScopedEVP_PKEY(EVP_PKEY_up_ref(key->key()))); + } + + void ConfigureClientCertsForServer() { + server_ssl_config_.client_cert_type = + SSLServerConfig::ClientCertType::REQUIRE_CLIENT_CERT; + + ScopedX509NameStack cert_names( + SSL_load_client_CA_file(GetTestCertsDirectory() + .AppendASCII(kClientCertCAFileName) + .MaybeAsASCII() + .c_str())); + ASSERT_TRUE(cert_names); + + for (size_t i = 0; i < sk_X509_NAME_num(cert_names.get()); ++i) { + uint8_t* str = nullptr; + int length = i2d_X509_NAME(sk_X509_NAME_value(cert_names.get(), i), &str); + ASSERT_LT(0, length); + + server_ssl_config_.cert_authorities_.push_back(std::string( + reinterpret_cast<const char*>(str), static_cast<size_t>(length))); + OPENSSL_free(str); + } + + scoped_refptr<X509Certificate> expected_client_cert( + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName)); + ASSERT_TRUE(expected_client_cert); + + client_cert_verifier_->AddResultForCert(expected_client_cert.get(), OK); + + server_ssl_config_.client_cert_verifier = client_cert_verifier_.get(); + } + + scoped_ptr<crypto::RSAPrivateKey> ReadTestKey(const base::StringPiece& name) { + base::FilePath certs_dir(GetTestCertsDirectory()); + base::FilePath key_path = certs_dir.AppendASCII(name); + std::string key_string; + if (!base::ReadFileToString(key_path, &key_string)) + return nullptr; + std::vector<uint8_t> key_vector( + reinterpret_cast<const uint8_t*>(key_string.data()), + reinterpret_cast<const uint8_t*>(key_string.data() + + key_string.length())); + scoped_ptr<crypto::RSAPrivateKey> key( + crypto::RSAPrivateKey::CreateFromPrivateKeyInfo(key_vector)); + return key; + } +#endif + + scoped_ptr<FakeDataChannel> channel_1_; + scoped_ptr<FakeDataChannel> channel_2_; SSLConfig client_ssl_config_; SSLServerConfig server_ssl_config_; scoped_ptr<SSLClientSocket> client_socket_; scoped_ptr<SSLServerSocket> server_socket_; ClientSocketFactory* socket_factory_; scoped_ptr<MockCertVerifier> cert_verifier_; + scoped_ptr<MockClientCertVerifier> client_cert_verifier_; scoped_ptr<TransportSecurityState> transport_security_state_; + scoped_ptr<SSLServerContext> server_context_; + scoped_ptr<crypto::RSAPrivateKey> server_private_key_; + scoped_refptr<X509Certificate> server_cert_; }; // This test only executes creation of client and server sockets. This is to // test that creation of sockets doesn't crash and have minimal code to run // under valgrind in order to help debugging memory problems. TEST_F(SSLServerSocketTest, Initialize) { - Initialize(); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); } // This test executes Connect() on SSLClientSocket and Handshake() on // SSLServerSocket to make sure handshaking between the two sockets is // completed successfully. TEST_F(SSLServerSocketTest, Handshake) { - Initialize(); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); - TestCompletionCallback connect_callback; TestCompletionCallback handshake_callback; - int server_ret = server_socket_->Handshake(handshake_callback.callback()); - EXPECT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); + TestCompletionCallback connect_callback; int client_ret = client_socket_->Connect(connect_callback.callback()); - EXPECT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); - if (client_ret == ERR_IO_PENDING) { - EXPECT_EQ(OK, connect_callback.WaitForResult()); - } - if (server_ret == ERR_IO_PENDING) { - EXPECT_EQ(OK, handshake_callback.WaitForResult()); - } + client_ret = connect_callback.GetResult(client_ret); + server_ret = handshake_callback.GetResult(server_ret); + + ASSERT_EQ(OK, client_ret); + ASSERT_EQ(OK, server_ret); // Make sure the cert status is expected. SSLInfo ssl_info; @@ -412,16 +517,363 @@ TEST_F(SSLServerSocketTest, Handshake) { EXPECT_TRUE(is_aead); } -TEST_F(SSLServerSocketTest, DataTransfer) { - Initialize(); +// NSS ports don't support client certificates and have a global session cache. +#if defined(USE_OPENSSL) + +// This test makes sure the session cache is working. +TEST_F(SSLServerSocketTest, HandshakeCached) { + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); TestCompletionCallback connect_callback; + int client_ret = client_socket_->Connect(connect_callback.callback()); + + client_ret = connect_callback.GetResult(client_ret); + server_ret = handshake_callback.GetResult(server_ret); + + ASSERT_EQ(OK, client_ret); + ASSERT_EQ(OK, server_ret); + + // Make sure the cert status is expected. + SSLInfo ssl_info; + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info)); + EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL); + SSLInfo ssl_server_info; + ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info)); + EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL); + + // Make sure the second connection is cached. + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + TestCompletionCallback handshake_callback2; + int server_ret2 = server_socket_->Handshake(handshake_callback2.callback()); + + TestCompletionCallback connect_callback2; + int client_ret2 = client_socket_->Connect(connect_callback2.callback()); + + client_ret2 = connect_callback2.GetResult(client_ret2); + server_ret2 = handshake_callback2.GetResult(server_ret2); + + ASSERT_EQ(OK, client_ret2); + ASSERT_EQ(OK, server_ret2); + + // Make sure the cert status is expected. + SSLInfo ssl_info2; + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2)); + EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME); + SSLInfo ssl_server_info2; + ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2)); + EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME); +} + +// This test makes sure the session cache separates out by server context. +TEST_F(SSLServerSocketTest, HandshakeCachedContextSwitch) { + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + int client_ret = client_socket_->Connect(connect_callback.callback()); + + client_ret = connect_callback.GetResult(client_ret); + server_ret = handshake_callback.GetResult(server_ret); + + ASSERT_EQ(OK, client_ret); + ASSERT_EQ(OK, server_ret); + + // Make sure the cert status is expected. + SSLInfo ssl_info; + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info)); + EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL); + SSLInfo ssl_server_info; + ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info)); + EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL); + + // Make sure the second connection is NOT cached when using a new context. + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback2; + int server_ret2 = server_socket_->Handshake(handshake_callback2.callback()); + + TestCompletionCallback connect_callback2; + int client_ret2 = client_socket_->Connect(connect_callback2.callback()); + + client_ret2 = connect_callback2.GetResult(client_ret2); + server_ret2 = handshake_callback2.GetResult(server_ret2); + + ASSERT_EQ(OK, client_ret2); + ASSERT_EQ(OK, server_ret2); + + // Make sure the cert status is expected. + SSLInfo ssl_info2; + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2)); + EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_FULL); + SSLInfo ssl_server_info2; + ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2)); + EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_FULL); +} + +// This test executes Connect() on SSLClientSocket and Handshake() on +// SSLServerSocket to make sure handshaking between the two sockets is +// completed successfully, using client certificate. +TEST_F(SSLServerSocketTest, HandshakeWithClientCert) { + scoped_refptr<X509Certificate> client_cert = + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient( + kClientCertFileName, kClientPrivateKeyFileName)); + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer()); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + int client_ret = client_socket_->Connect(connect_callback.callback()); + + client_ret = connect_callback.GetResult(client_ret); + server_ret = handshake_callback.GetResult(server_ret); + + ASSERT_EQ(OK, client_ret); + ASSERT_EQ(OK, server_ret); + + // Make sure the cert status is expected. + SSLInfo ssl_info; + client_socket_->GetSSLInfo(&ssl_info); + EXPECT_EQ(CERT_STATUS_AUTHORITY_INVALID, ssl_info.cert_status); + server_socket_->GetSSLInfo(&ssl_info); + ASSERT_TRUE(ssl_info.cert.get()); + EXPECT_TRUE(client_cert->Equals(ssl_info.cert.get())); +} + +// This test executes Connect() on SSLClientSocket and Handshake() twice on +// SSLServerSocket to make sure handshaking between the two sockets is +// completed successfully, using client certificate. The second connection is +// expected to succeed through the session cache. +TEST_F(SSLServerSocketTest, HandshakeWithClientCertCached) { + scoped_refptr<X509Certificate> client_cert = + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient( + kClientCertFileName, kClientPrivateKeyFileName)); + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer()); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + int client_ret = client_socket_->Connect(connect_callback.callback()); + + client_ret = connect_callback.GetResult(client_ret); + server_ret = handshake_callback.GetResult(server_ret); + + ASSERT_EQ(OK, client_ret); + ASSERT_EQ(OK, server_ret); + + // Make sure the cert status is expected. + SSLInfo ssl_info; + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info)); + EXPECT_EQ(ssl_info.handshake_type, SSLInfo::HANDSHAKE_FULL); + SSLInfo ssl_server_info; + ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info)); + ASSERT_TRUE(ssl_server_info.cert.get()); + EXPECT_TRUE(client_cert->Equals(ssl_server_info.cert.get())); + EXPECT_EQ(ssl_server_info.handshake_type, SSLInfo::HANDSHAKE_FULL); + server_socket_->Disconnect(); + client_socket_->Disconnect(); + + // Create the connection again. + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + TestCompletionCallback handshake_callback2; + int server_ret2 = server_socket_->Handshake(handshake_callback2.callback()); + + TestCompletionCallback connect_callback2; + int client_ret2 = client_socket_->Connect(connect_callback2.callback()); + + client_ret2 = connect_callback2.GetResult(client_ret2); + server_ret2 = handshake_callback2.GetResult(server_ret2); + + ASSERT_EQ(OK, client_ret2); + ASSERT_EQ(OK, server_ret2); + + // Make sure the cert status is expected. + SSLInfo ssl_info2; + ASSERT_TRUE(client_socket_->GetSSLInfo(&ssl_info2)); + EXPECT_EQ(ssl_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME); + SSLInfo ssl_server_info2; + ASSERT_TRUE(server_socket_->GetSSLInfo(&ssl_server_info2)); + ASSERT_TRUE(ssl_server_info2.cert.get()); + EXPECT_TRUE(client_cert->Equals(ssl_server_info2.cert.get())); + EXPECT_EQ(ssl_server_info2.handshake_type, SSLInfo::HANDSHAKE_RESUME); +} + +TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSupplied) { + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer()); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + // Use the default setting for the client socket, which is to not send + // a client certificate. This will cause the client to receive an + // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the + // requested cert_authorities from the CertificateRequest sent by the + // server. + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, + connect_callback.GetResult( + client_socket_->Connect(connect_callback.callback()))); + + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); + client_socket_->GetSSLCertRequestInfo(request_info.get()); + + // Check that the authority name that arrived in the CertificateRequest + // handshake message is as expected. + scoped_refptr<X509Certificate> client_cert = + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); + ASSERT_TRUE(client_cert); + EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities)); + + client_socket_->Disconnect(); + + EXPECT_EQ(ERR_FAILED, handshake_callback.GetResult(server_ret)); +} + +TEST_F(SSLServerSocketTest, HandshakeWithClientCertRequiredNotSuppliedCached) { + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer()); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + // Use the default setting for the client socket, which is to not send + // a client certificate. This will cause the client to receive an + // ERR_SSL_CLIENT_AUTH_CERT_NEEDED error, and allow for inspecting the + // requested cert_authorities from the CertificateRequest sent by the + // server. + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, + connect_callback.GetResult( + client_socket_->Connect(connect_callback.callback()))); + + scoped_refptr<SSLCertRequestInfo> request_info = new SSLCertRequestInfo(); + client_socket_->GetSSLCertRequestInfo(request_info.get()); + + // Check that the authority name that arrived in the CertificateRequest + // handshake message is as expected. + scoped_refptr<X509Certificate> client_cert = + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); + ASSERT_TRUE(client_cert); + EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info->cert_authorities)); + + client_socket_->Disconnect(); + + EXPECT_EQ(ERR_FAILED, handshake_callback.GetResult(server_ret)); + server_socket_->Disconnect(); + + // Below, check that the cache didn't store the result of a failed handshake. + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + TestCompletionCallback handshake_callback2; + int server_ret2 = server_socket_->Handshake(handshake_callback2.callback()); + + TestCompletionCallback connect_callback2; + EXPECT_EQ(ERR_SSL_CLIENT_AUTH_CERT_NEEDED, + connect_callback2.GetResult( + client_socket_->Connect(connect_callback2.callback()))); + + scoped_refptr<SSLCertRequestInfo> request_info2 = new SSLCertRequestInfo(); + client_socket_->GetSSLCertRequestInfo(request_info2.get()); + + // Check that the authority name that arrived in the CertificateRequest + // handshake message is as expected. + EXPECT_TRUE(client_cert->IsIssuedByEncoded(request_info2->cert_authorities)); + + client_socket_->Disconnect(); + + EXPECT_EQ(ERR_FAILED, handshake_callback2.GetResult(server_ret2)); +} + +TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSupplied) { + scoped_refptr<X509Certificate> client_cert = + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); + ASSERT_TRUE(client_cert); + + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient( + kWrongClientCertFileName, kWrongClientPrivateKeyFileName)); + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer()); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + int client_ret = client_socket_->Connect(connect_callback.callback()); + + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, + connect_callback.GetResult(client_ret)); + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, + handshake_callback.GetResult(server_ret)); +} + +TEST_F(SSLServerSocketTest, HandshakeWithWrongClientCertSuppliedCached) { + scoped_refptr<X509Certificate> client_cert = + ImportCertFromFile(GetTestCertsDirectory(), kClientCertFileName); + ASSERT_TRUE(client_cert); + + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForClient( + kWrongClientCertFileName, kWrongClientPrivateKeyFileName)); + ASSERT_NO_FATAL_FAILURE(ConfigureClientCertsForServer()); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + + TestCompletionCallback handshake_callback; + int server_ret = server_socket_->Handshake(handshake_callback.callback()); + + TestCompletionCallback connect_callback; + int client_ret = client_socket_->Connect(connect_callback.callback()); + + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, + connect_callback.GetResult(client_ret)); + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, + handshake_callback.GetResult(server_ret)); + + client_socket_->Disconnect(); + server_socket_->Disconnect(); + + // Below, check that the cache didn't store the result of a failed handshake. + ASSERT_NO_FATAL_FAILURE(CreateSockets()); + TestCompletionCallback handshake_callback2; + int server_ret2 = server_socket_->Handshake(handshake_callback2.callback()); + + TestCompletionCallback connect_callback2; + int client_ret2 = client_socket_->Connect(connect_callback2.callback()); + + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, + connect_callback2.GetResult(client_ret2)); + EXPECT_EQ(ERR_BAD_SSL_CLIENT_AUTH_CERT, + handshake_callback2.GetResult(server_ret2)); +} +#endif // defined(USE_OPENSSL) + +TEST_F(SSLServerSocketTest, DataTransfer) { + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); // Establish connection. + TestCompletionCallback connect_callback; int client_ret = client_socket_->Connect(connect_callback.callback()); ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); + TestCompletionCallback handshake_callback; int server_ret = server_socket_->Handshake(handshake_callback.callback()); ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); @@ -439,8 +891,8 @@ TEST_F(SSLServerSocketTest, DataTransfer) { // Write then read. TestCompletionCallback write_callback; TestCompletionCallback read_callback; - server_ret = server_socket_->Write( - write_buf.get(), write_buf->size(), write_callback.callback()); + server_ret = server_socket_->Write(write_buf.get(), write_buf->size(), + write_callback.callback()); EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); client_ret = client_socket_->Read( read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); @@ -469,8 +921,8 @@ TEST_F(SSLServerSocketTest, DataTransfer) { server_ret = server_socket_->Read( read_buf.get(), read_buf->BytesRemaining(), read_callback.callback()); EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); - client_ret = client_socket_->Write( - write_buf.get(), write_buf->size(), write_callback.callback()); + client_ret = client_socket_->Write(write_buf.get(), write_buf->size(), + write_callback.callback()); EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); server_ret = read_callback.GetResult(server_ret); @@ -497,15 +949,15 @@ TEST_F(SSLServerSocketTest, DataTransfer) { // the client's Write() call should not cause an infinite loop. // NOTE: this is a test for SSLClientSocket rather than SSLServerSocket. TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { - Initialize(); - - TestCompletionCallback connect_callback; - TestCompletionCallback handshake_callback; + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); // Establish connection. + TestCompletionCallback connect_callback; int client_ret = client_socket_->Connect(connect_callback.callback()); ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); + TestCompletionCallback handshake_callback; int server_ret = server_socket_->Handshake(handshake_callback.callback()); ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); @@ -521,9 +973,8 @@ TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { // socket won't return ERR_IO_PENDING. This ensures that the client // will call Read() on the transport socket again. TestCompletionCallback write_callback; - - server_ret = server_socket_->Write( - write_buf.get(), write_buf->size(), write_callback.callback()); + server_ret = server_socket_->Write(write_buf.get(), write_buf->size(), + write_callback.callback()); EXPECT_TRUE(server_ret > 0 || server_ret == ERR_IO_PENDING); server_ret = write_callback.GetResult(server_ret); @@ -532,8 +983,8 @@ TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { server_socket_->Disconnect(); // The client writes some data. This should not cause an infinite loop. - client_ret = client_socket_->Write( - write_buf.get(), write_buf->size(), write_callback.callback()); + client_ret = client_socket_->Write(write_buf.get(), write_buf->size(), + write_callback.callback()); EXPECT_TRUE(client_ret > 0 || client_ret == ERR_IO_PENDING); client_ret = write_callback.GetResult(client_ret); @@ -549,14 +1000,14 @@ TEST_F(SSLServerSocketTest, ClientWriteAfterServerClose) { // after connecting them, and verifies that the results match. // This test will fail if False Start is enabled (see crbug.com/90208). TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { - Initialize(); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); TestCompletionCallback connect_callback; - TestCompletionCallback handshake_callback; - int client_ret = client_socket_->Connect(connect_callback.callback()); ASSERT_TRUE(client_ret == OK || client_ret == ERR_IO_PENDING); + TestCompletionCallback handshake_callback; int server_ret = server_socket_->Handshake(handshake_callback.callback()); ASSERT_TRUE(server_ret == OK || server_ret == ERR_IO_PENDING); @@ -571,23 +1022,20 @@ TEST_F(SSLServerSocketTest, ExportKeyingMaterial) { const char kKeyingLabel[] = "EXPERIMENTAL-server-socket-test"; const char kKeyingContext[] = ""; unsigned char server_out[kKeyingMaterialSize]; - int rv = server_socket_->ExportKeyingMaterial(kKeyingLabel, - false, kKeyingContext, - server_out, sizeof(server_out)); + int rv = server_socket_->ExportKeyingMaterial( + kKeyingLabel, false, kKeyingContext, server_out, sizeof(server_out)); ASSERT_EQ(OK, rv); unsigned char client_out[kKeyingMaterialSize]; - rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, - false, kKeyingContext, + rv = client_socket_->ExportKeyingMaterial(kKeyingLabel, false, kKeyingContext, client_out, sizeof(client_out)); ASSERT_EQ(OK, rv); EXPECT_EQ(0, memcmp(server_out, client_out, sizeof(server_out))); const char kKeyingLabelBad[] = "EXPERIMENTAL-server-socket-test-bad"; unsigned char client_bad[kKeyingMaterialSize]; - rv = client_socket_->ExportKeyingMaterial(kKeyingLabelBad, - false, kKeyingContext, - client_bad, sizeof(client_bad)); + rv = client_socket_->ExportKeyingMaterial( + kKeyingLabelBad, false, kKeyingContext, client_bad, sizeof(client_bad)); ASSERT_EQ(rv, OK); EXPECT_NE(0, memcmp(server_out, client_bad, sizeof(server_out))); } @@ -613,12 +1061,13 @@ TEST_F(SSLServerSocketTest, RequireEcdheFlag) { // Require ECDHE on the server. server_ssl_config_.require_ecdhe = true; - Initialize(); + ASSERT_NO_FATAL_FAILURE(CreateContext()); + ASSERT_NO_FATAL_FAILURE(CreateSockets()); TestCompletionCallback connect_callback; - TestCompletionCallback handshake_callback; - int client_ret = client_socket_->Connect(connect_callback.callback()); + + TestCompletionCallback handshake_callback; int server_ret = server_socket_->Handshake(handshake_callback.callback()); client_ret = connect_callback.GetResult(client_ret); diff --git a/chromium/net/socket/ssl_socket.h b/chromium/net/socket/ssl_socket.h index 5f60e800b5b..bd813c35fea 100644 --- a/chromium/net/socket/ssl_socket.h +++ b/chromium/net/socket/ssl_socket.h @@ -26,9 +26,6 @@ public: const base::StringPiece& context, unsigned char* out, unsigned int outlen) = 0; - - // Stores the the tls-unique channel binding (see RFC 5929) in |*out|. - virtual int GetTLSUniqueChannelBinding(std::string* out) = 0; }; } // namespace net diff --git a/chromium/net/socket/stream_socket.h b/chromium/net/socket/stream_socket.h index e2f9d3cd790..98efcdafb9b 100644 --- a/chromium/net/socket/stream_socket.h +++ b/chromium/net/socket/stream_socket.h @@ -79,11 +79,6 @@ class NET_EXPORT_PRIVATE StreamSocket : public Socket { // Write() methods had been called, not the underlying transport's. virtual bool WasEverUsed() const = 0; - // TODO(jri): Clean up -- remove this method. - // Returns true if the underlying transport socket is using TCP FastOpen. - // TCP FastOpen is an experiment with sending data in the TCP SYN packet. - virtual bool UsingTCPFastOpen() const = 0; - // TODO(jri): Clean up -- rename to a more general EnableAutoConnectOnWrite. // Enables use of TCP FastOpen for the underlying transport socket. virtual void EnableTCPFastOpenIfSupported() {} diff --git a/chromium/net/socket/tcp_client_socket.cc b/chromium/net/socket/tcp_client_socket.cc index 700fa1c1655..56238ba0d12 100644 --- a/chromium/net/socket/tcp_client_socket.cc +++ b/chromium/net/socket/tcp_client_socket.cc @@ -14,7 +14,6 @@ #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_util.h" namespace net { @@ -240,10 +239,6 @@ bool TCPClientSocket::WasEverUsed() const { return use_history_.was_used_to_convey_data(); } -bool TCPClientSocket::UsingTCPFastOpen() const { - return socket_->UsingTCPFastOpen(); -} - void TCPClientSocket::EnableTCPFastOpenIfSupported() { socket_->EnableTCPFastOpenIfSupported(); } diff --git a/chromium/net/socket/tcp_client_socket.h b/chromium/net/socket/tcp_client_socket.h index 73ee62bfeb2..ae6083fed6c 100644 --- a/chromium/net/socket/tcp_client_socket.h +++ b/chromium/net/socket/tcp_client_socket.h @@ -51,7 +51,6 @@ class NET_EXPORT TCPClientSocket : public StreamSocket { void SetSubresourceSpeculation() override; void SetOmniboxSpeculation() override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; void EnableTCPFastOpenIfSupported() override; bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; diff --git a/chromium/net/socket/tcp_client_socket_unittest.cc b/chromium/net/socket/tcp_client_socket_unittest.cc index ce0c53559f8..1c39719f59d 100644 --- a/chromium/net/socket/tcp_client_socket_unittest.cc +++ b/chromium/net/socket/tcp_client_socket_unittest.cc @@ -8,9 +8,9 @@ #include "net/socket/tcp_client_socket.h" +#include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" -#include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/socket/tcp_server_socket.h" #include "testing/gtest/include/gtest/gtest.h" @@ -22,8 +22,7 @@ namespace { // Try binding a socket to loopback interface and verify that we can // still connect to a server on the same interface. TEST(TCPClientSocketTest, BindLoopbackToLoopback) { - IPAddressNumber lo_address; - ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address)); + IPAddress lo_address = IPAddress::IPv4Localhost(); TCPServerSocket server(NULL, NetLog::Source()); ASSERT_EQ(OK, server.Listen(IPEndPoint(lo_address, 0), 1)); @@ -60,14 +59,11 @@ TEST(TCPClientSocketTest, BindLoopbackToLoopback) { // Try to bind socket to the loopback interface and connect to an // external address, verify that connection fails. TEST(TCPClientSocketTest, BindLoopbackToExternal) { - IPAddressNumber external_ip; - ASSERT_TRUE(ParseIPLiteralToNumber("72.14.213.105", &external_ip)); + IPAddress external_ip(72, 14, 213, 105); TCPClientSocket socket(AddressList::CreateFromIPAddress(external_ip, 80), NULL, NetLog::Source()); - IPAddressNumber lo_address; - ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &lo_address)); - EXPECT_EQ(OK, socket.Bind(IPEndPoint(lo_address, 0))); + EXPECT_EQ(OK, socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0))); TestCompletionCallback connect_callback; int result = socket.Connect(connect_callback.callback()); @@ -82,10 +78,9 @@ TEST(TCPClientSocketTest, BindLoopbackToExternal) { // Bind a socket to the IPv4 loopback interface and try to connect to // the IPv6 loopback interface, verify that connection fails. TEST(TCPClientSocketTest, BindLoopbackToIPv6) { - IPAddressNumber ipv6_lo_ip; - ASSERT_TRUE(ParseIPLiteralToNumber("::1", &ipv6_lo_ip)); TCPServerSocket server(NULL, NetLog::Source()); - int listen_result = server.Listen(IPEndPoint(ipv6_lo_ip, 0), 1); + int listen_result = + server.Listen(IPEndPoint(IPAddress::IPv6Localhost(), 0), 1); if (listen_result != OK) { LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is disabled." " Skipping the test"; @@ -96,9 +91,7 @@ TEST(TCPClientSocketTest, BindLoopbackToIPv6) { ASSERT_EQ(OK, server.GetLocalAddress(&server_address)); TCPClientSocket socket(AddressList(server_address), NULL, NetLog::Source()); - IPAddressNumber ipv4_lo_ip; - ASSERT_TRUE(ParseIPLiteralToNumber("127.0.0.1", &ipv4_lo_ip)); - EXPECT_EQ(OK, socket.Bind(IPEndPoint(ipv4_lo_ip, 0))); + EXPECT_EQ(OK, socket.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0))); TestCompletionCallback connect_callback; int result = socket.Connect(connect_callback.callback()); diff --git a/chromium/net/socket/tcp_server_socket.h b/chromium/net/socket/tcp_server_socket.h index ea2021c921e..7611ff911b9 100644 --- a/chromium/net/socket/tcp_server_socket.h +++ b/chromium/net/socket/tcp_server_socket.h @@ -5,7 +5,6 @@ #ifndef NET_SOCKET_TCP_SERVER_SOCKET_H_ #define NET_SOCKET_TCP_SERVER_SOCKET_H_ -#include "base/compiler_specific.h" #include "base/macros.h" #include "base/memory/scoped_ptr.h" #include "net/base/ip_endpoint.h" @@ -16,7 +15,7 @@ namespace net { -class NET_EXPORT_PRIVATE TCPServerSocket : public ServerSocket { +class NET_EXPORT TCPServerSocket : public ServerSocket { public: TCPServerSocket(NetLog* net_log, const NetLog::Source& source); ~TCPServerSocket() override; diff --git a/chromium/net/socket/tcp_server_socket_unittest.cc b/chromium/net/socket/tcp_server_socket_unittest.cc index 133c8073188..651cd6964ad 100644 --- a/chromium/net/socket/tcp_server_socket_unittest.cc +++ b/chromium/net/socket/tcp_server_socket_unittest.cc @@ -12,6 +12,7 @@ #include "base/memory/scoped_ptr.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" +#include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" @@ -31,16 +32,14 @@ class TCPServerSocketTest : public PlatformTest { } void SetUpIPv4() { - IPEndPoint address; - ParseAddress("127.0.0.1", 0, &address); + IPEndPoint address(IPAddress::IPv4Localhost(), 0); ASSERT_EQ(OK, socket_.Listen(address, kListenBacklog)); ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); } void SetUpIPv6(bool* success) { *success = false; - IPEndPoint address; - ParseAddress("::1", 0, &address); + IPEndPoint address(IPAddress::IPv6Localhost(), 0); if (socket_.Listen(address, kListenBacklog) != 0) { LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " "disabled. Skipping the test"; @@ -50,16 +49,6 @@ class TCPServerSocketTest : public PlatformTest { *success = true; } - void ParseAddress(const std::string& ip_str, - uint16_t port, - IPEndPoint* address) { - IPAddressNumber ip_number; - bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); - if (!rv) - return; - *address = IPEndPoint(ip_number, port); - } - static IPEndPoint GetPeerAddress(StreamSocket* socket) { IPEndPoint address; EXPECT_EQ(OK, socket->GetPeerAddress(&address)); diff --git a/chromium/net/socket/tcp_socket.cc b/chromium/net/socket/tcp_socket.cc new file mode 100644 index 00000000000..ad5250620d0 --- /dev/null +++ b/chromium/net/socket/tcp_socket.cc @@ -0,0 +1,29 @@ +// Copyright 2016 The Chromium Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +#include "net/socket/tcp_socket.h" + +#include "build/build_config.h" + +#if defined(OS_POSIX) +#include <sys/socket.h> +#include <netinet/in.h> +#include <netinet/tcp.h> +#elif defined(OS_WIN) +#include <winsock2.h> +#endif + +namespace net { + +bool SetTCPNoDelay(SocketDescriptor socket, bool no_delay) { +#if defined(OS_POSIX) + int on = no_delay ? 1 : 0; +#elif defined(OS_WIN) + BOOL on = no_delay ? TRUE : FALSE; +#endif + return setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast<const char*>(&on), sizeof(on)) == 0; +} + +} // namespace net diff --git a/chromium/net/socket/tcp_socket.h b/chromium/net/socket/tcp_socket.h index 58797bb0851..319a6b261d0 100644 --- a/chromium/net/socket/tcp_socket.h +++ b/chromium/net/socket/tcp_socket.h @@ -7,6 +7,7 @@ #include "build/build_config.h" #include "net/base/net_export.h" +#include "net/socket/socket_descriptor.h" #if defined(OS_WIN) #include "net/socket/tcp_socket_win.h" @@ -39,6 +40,39 @@ bool IsTCPFastOpenUserEnabled(); // Not thread safe. Must be called during initialization/startup only. NET_EXPORT void CheckSupportAndMaybeEnableTCPFastOpen(bool user_enabled); +// This function enables/disables buffering in the kernel. By default, on Linux, +// TCP sockets will wait up to 200ms for more data to complete a packet before +// transmitting. After calling this function, the kernel will not wait. See +// TCP_NODELAY in `man 7 tcp`. +// +// For Windows: +// +// The Nagle implementation on Windows is governed by RFC 896. The idea +// behind Nagle is to reduce small packets on the network. When Nagle is +// enabled, if a partial packet has been sent, the TCP stack will disallow +// further *partial* packets until an ACK has been received from the other +// side. Good applications should always strive to send as much data as +// possible and avoid partial-packet sends. However, in most real world +// applications, there are edge cases where this does not happen, and two +// partial packets may be sent back to back. For a browser, it is NEVER +// a benefit to delay for an RTT before the second packet is sent. +// +// As a practical example in Chromium today, consider the case of a small +// POST. I have verified this: +// Client writes 649 bytes of header (partial packet #1) +// Client writes 50 bytes of POST data (partial packet #2) +// In the above example, with Nagle, a RTT delay is inserted between these +// two sends due to nagle. RTTs can easily be 100ms or more. The best +// fix is to make sure that for POSTing data, we write as much data as +// possible and minimize partial packets. We will fix that. But disabling +// Nagle also ensure we don't run into this delay in other edge cases. +// See also: +// http://technet.microsoft.com/en-us/library/bb726981.aspx +// +// This function returns true if it succeeds to set the TCP_NODELAY option, +// otherwise returns false. +NET_EXPORT_PRIVATE bool SetTCPNoDelay(SocketDescriptor socket, bool no_delay); + } // namespace net #endif // NET_SOCKET_TCP_SOCKET_H_ diff --git a/chromium/net/socket/tcp_socket_posix.cc b/chromium/net/socket/tcp_socket_posix.cc index 0546e6075a5..95fe6b5bd01 100644 --- a/chromium/net/socket/tcp_socket_posix.cc +++ b/chromium/net/socket/tcp_socket_posix.cc @@ -44,16 +44,6 @@ bool g_tcp_fastopen_user_enabled = false; // True if TCP FastOpen connect-with-write has failed at least once. bool g_tcp_fastopen_has_failed = false; -// SetTCPNoDelay turns on/off buffering in the kernel. By default, TCP sockets -// will wait up to 200ms for more data to complete a packet before transmitting. -// After calling this function, the kernel will not wait. See TCP_NODELAY in -// `man 7 tcp`. -bool SetTCPNoDelay(int fd, bool no_delay) { - int on = no_delay ? 1 : 0; - int error = setsockopt(fd, IPPROTO_TCP, TCP_NODELAY, &on, sizeof(on)); - return error == 0; -} - // SetTCPKeepAlive sets SO_KEEPALIVE. bool SetTCPKeepAlive(int fd, bool enable, int delay) { // Enabling TCP keepalives is the same on all platforms. @@ -443,10 +433,6 @@ void TCPSocketPosix::Close() { tcp_fastopen_status_ = TCP_FASTOPEN_STATUS_UNKNOWN; } -bool TCPSocketPosix::UsingTCPFastOpen() const { - return use_tcp_fastopen_; -} - void TCPSocketPosix::EnableTCPFastOpenIfSupported() { if (!IsTCPFastOpenSupported()) return; @@ -455,7 +441,7 @@ void TCPSocketPosix::EnableTCPFastOpenIfSupported() { // This check conservatively avoids middleboxes that may blackhole // TCP FastOpen SYN+Data packets; on such a failure, subsequent sockets // should not use TCP FastOpen. - if(!g_tcp_fastopen_has_failed) + if (!g_tcp_fastopen_has_failed) use_tcp_fastopen_ = true; else tcp_fastopen_status_ = TCP_FASTOPEN_PREVIOUSLY_FAILED; diff --git a/chromium/net/socket/tcp_socket_posix.h b/chromium/net/socket/tcp_socket_posix.h index 8d53caf567f..98849e45a19 100644 --- a/chromium/net/socket/tcp_socket_posix.h +++ b/chromium/net/socket/tcp_socket_posix.h @@ -73,8 +73,6 @@ class NET_EXPORT TCPSocketPosix { void Close(); - // Setter/Getter methods for TCP FastOpen socket option. - bool UsingTCPFastOpen() const; void EnableTCPFastOpenIfSupported(); bool IsValid() const; diff --git a/chromium/net/socket/tcp_socket_unittest.cc b/chromium/net/socket/tcp_socket_unittest.cc index ade3ee7fd8d..15f93551673 100644 --- a/chromium/net/socket/tcp_socket_unittest.cc +++ b/chromium/net/socket/tcp_socket_unittest.cc @@ -14,6 +14,7 @@ #include "net/base/address_list.h" #include "net/base/io_buffer.h" #include "net/base/ip_endpoint.h" +#include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/sockaddr_storage.h" #include "net/base/test_completion_callback.h" @@ -32,22 +33,17 @@ class TCPSocketTest : public PlatformTest { } void SetUpListenIPv4() { - IPEndPoint address; - ParseAddress("127.0.0.1", 0, &address); - ASSERT_EQ(OK, socket_.Open(ADDRESS_FAMILY_IPV4)); - ASSERT_EQ(OK, socket_.Bind(address)); + ASSERT_EQ(OK, socket_.Bind(IPEndPoint(IPAddress::IPv4Localhost(), 0))); ASSERT_EQ(OK, socket_.Listen(kListenBacklog)); ASSERT_EQ(OK, socket_.GetLocalAddress(&local_address_)); } void SetUpListenIPv6(bool* success) { *success = false; - IPEndPoint address; - ParseAddress("::1", 0, &address); if (socket_.Open(ADDRESS_FAMILY_IPV6) != OK || - socket_.Bind(address) != OK || + socket_.Bind(IPEndPoint(IPAddress::IPv6Localhost(), 0)) != OK || socket_.Listen(kListenBacklog) != OK) { LOG(ERROR) << "Failed to listen on ::1 - probably because IPv6 is " "disabled. Skipping the test"; @@ -57,16 +53,6 @@ class TCPSocketTest : public PlatformTest { *success = true; } - void ParseAddress(const std::string& ip_str, - uint16_t port, - IPEndPoint* address) { - IPAddressNumber ip_number; - bool rv = ParseIPLiteralToNumber(ip_str, &ip_number); - if (!rv) - return; - *address = IPEndPoint(ip_number, port); - } - void TestAcceptAsync() { TestCompletionCallback accept_callback; scoped_ptr<TCPSocket> accepted_socket; @@ -138,8 +124,7 @@ TEST_F(TCPSocketTest, AcceptForAdoptedListenSocket) { SOCKET existing_socket = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); ASSERT_EQ(OK, socket_.AdoptListenSocket(existing_socket)); - IPEndPoint address; - ParseAddress("127.0.0.1", 0, &address); + IPEndPoint address(IPAddress::IPv4Localhost(), 0); SockaddrStorage storage; ASSERT_TRUE(address.ToSockAddr(storage.addr, &storage.addr_len)); ASSERT_EQ(0, bind(existing_socket, storage.addr, storage.addr_len)); diff --git a/chromium/net/socket/tcp_socket_win.cc b/chromium/net/socket/tcp_socket_win.cc index a9998e2c0c6..0d11a0d7443 100644 --- a/chromium/net/socket/tcp_socket_win.cc +++ b/chromium/net/socket/tcp_socket_win.cc @@ -8,12 +8,13 @@ #include <errno.h> #include <mstcpip.h> +#include <utility> + #include "base/callback_helpers.h" #include "base/files/file_util.h" #include "base/logging.h" #include "base/macros.h" #include "base/profiler/scoped_tracker.h" -#include "base/win/windows_version.h" #include "net/base/address_list.h" #include "net/base/connection_type_histograms.h" #include "net/base/io_buffer.h" @@ -50,36 +51,6 @@ int SetSocketSendBufferSize(SOCKET socket, int32_t size) { } // Disable Nagle. -// The Nagle implementation on windows is governed by RFC 896. The idea -// behind Nagle is to reduce small packets on the network. When Nagle is -// enabled, if a partial packet has been sent, the TCP stack will disallow -// further *partial* packets until an ACK has been received from the other -// side. Good applications should always strive to send as much data as -// possible and avoid partial-packet sends. However, in most real world -// applications, there are edge cases where this does not happen, and two -// partial packets may be sent back to back. For a browser, it is NEVER -// a benefit to delay for an RTT before the second packet is sent. -// -// As a practical example in Chromium today, consider the case of a small -// POST. I have verified this: -// Client writes 649 bytes of header (partial packet #1) -// Client writes 50 bytes of POST data (partial packet #2) -// In the above example, with Nagle, a RTT delay is inserted between these -// two sends due to nagle. RTTs can easily be 100ms or more. The best -// fix is to make sure that for POSTing data, we write as much data as -// possible and minimize partial packets. We will fix that. But disabling -// Nagle also ensure we don't run into this delay in other edge cases. -// See also: -// http://technet.microsoft.com/en-us/library/bb726981.aspx -bool DisableNagle(SOCKET socket, bool disable) { - BOOL val = disable ? TRUE : FALSE; - int rv = setsockopt(socket, IPPROTO_TCP, TCP_NODELAY, - reinterpret_cast<const char*>(&val), - sizeof(val)); - DCHECK(!rv) << "Could not disable nagle"; - return rv == 0; -} - // Enable TCP Keep-Alive to prevent NAT routers from timing out TCP // connections. See http://crbug.com/27400 for details. bool SetTCPKeepAlive(SOCKET socket, BOOL enable, int delay_secs) { @@ -580,23 +551,7 @@ int TCPSocketWin::SetDefaultOptionsForServer() { } void TCPSocketWin::SetDefaultOptionsForClient() { - // Increase the socket buffer sizes from the default sizes for WinXP. In - // performance testing, there is substantial benefit by increasing from 8KB - // to 64KB. - // See also: - // http://support.microsoft.com/kb/823764/EN-US - // On Vista, if we manually set these sizes, Vista turns off its receive - // window auto-tuning feature. - // http://blogs.msdn.com/wndp/archive/2006/05/05/Winhec-blog-tcpip-2.aspx - // Since Vista's auto-tune is better than any static value we can could set, - // only change these on pre-vista machines. - if (base::win::GetVersion() < base::win::VERSION_VISTA) { - const int32_t kSocketBufferSize = 64 * 1024; - SetSocketReceiveBufferSize(socket_, kSocketBufferSize); - SetSocketSendBufferSize(socket_, kSocketBufferSize); - } - - DisableNagle(socket_, true); + SetTCPNoDelay(socket_, /*no_delay=*/true); SetTCPKeepAlive(socket_, true, kTCPKeepAliveSeconds); } @@ -641,7 +596,7 @@ bool TCPSocketWin::SetKeepAlive(bool enable, int delay) { } bool TCPSocketWin::SetNoDelay(bool no_delay) { - return DisableNagle(socket_, no_delay); + return SetTCPNoDelay(socket_, no_delay); } void TCPSocketWin::Close() { @@ -751,7 +706,7 @@ int TCPSocketWin::AcceptInternal(scoped_ptr<TCPSocketWin>* socket, net_log_.EndEventWithNetErrorCode(NetLog::TYPE_TCP_ACCEPT, adopt_result); return adopt_result; } - *socket = tcp_socket.Pass(); + *socket = std::move(tcp_socket); *address = ip_end_point; net_log_.EndEvent(NetLog::TYPE_TCP_ACCEPT, CreateNetLogIPEndPointCallback(&ip_end_point)); diff --git a/chromium/net/socket/tcp_socket_win.h b/chromium/net/socket/tcp_socket_win.h index f24ffb801b7..1786af11ada 100644 --- a/chromium/net/socket/tcp_socket_win.h +++ b/chromium/net/socket/tcp_socket_win.h @@ -81,9 +81,7 @@ class NET_EXPORT TCPSocketWin : NON_EXPORTED_BASE(public base::NonThreadSafe), void Close(); - // Setter/Getter methods for TCP FastOpen socket option. - // NOOPs since TCP FastOpen is not implemented in Windows. - bool UsingTCPFastOpen() const { return false; } + // NOOP since TCP FastOpen is not implemented in Windows. void EnableTCPFastOpenIfSupported() {} bool IsValid() const { return socket_ != INVALID_SOCKET; } diff --git a/chromium/net/socket/transport_client_socket_pool.cc b/chromium/net/socket/transport_client_socket_pool.cc index 0949193ff2e..dbc701a9386 100644 --- a/chromium/net/socket/transport_client_socket_pool.cc +++ b/chromium/net/socket/transport_client_socket_pool.cc @@ -16,6 +16,7 @@ #include "base/strings/string_util.h" #include "base/synchronization/lock.h" #include "base/time/time.h" +#include "base/trace_event/trace_event.h" #include "base/values.h" #include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" @@ -61,11 +62,9 @@ static base::LazyInstance<base::TimeTicks>::Leaky TransportSocketParams::TransportSocketParams( const HostPortPair& host_port_pair, bool disable_resolver_cache, - bool ignore_limits, const OnHostResolutionCallback& host_resolution_callback, CombineConnectAndWritePolicy combine_connect_and_write_if_supported) : destination_(host_port_pair), - ignore_limits_(ignore_limits), host_resolution_callback_(host_resolution_callback), combine_connect_and_write_(combine_connect_and_write_if_supported) { if (disable_resolver_cache) @@ -117,6 +116,7 @@ int TransportConnectJobHelper::DoResolveHost(RequestPriority priority, int TransportConnectJobHelper::DoResolveHostComplete( int result, const BoundNetLog& net_log) { + TRACE_EVENT0("net", "TransportConnectJobHelper::DoResolveHostComplete"); connect_timing_->dns_end = base::TimeTicks::Now(); // Overwrite connection start time, since for connections that do not go // through proxies, |connect_start| should not include dns lookup time. @@ -196,6 +196,7 @@ base::TimeDelta TransportConnectJobHelper::HistogramDuration( TransportConnectJob::TransportConnectJob( const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, ClientSocketFactory* client_socket_factory, @@ -205,6 +206,7 @@ TransportConnectJob::TransportConnectJob( : ConnectJob(group_name, timeout_duration, priority, + respect_limits, delegate, BoundNetLog::Make(net_log, NetLog::SOURCE_CONNECT_JOB)), helper_(params, client_socket_factory, host_resolver, &connect_timing_), @@ -487,15 +489,10 @@ TransportClientSocketPool::TransportConnectJobFactory::NewConnectJob( const std::string& group_name, const PoolBase::Request& request, ConnectJob::Delegate* delegate) const { - return scoped_ptr<ConnectJob>( - new TransportConnectJob(group_name, - request.priority(), - request.params(), - ConnectionTimeout(), - client_socket_factory_, - host_resolver_, - delegate, - net_log_)); + return scoped_ptr<ConnectJob>(new TransportConnectJob( + group_name, request.priority(), request.respect_limits(), + request.params(), ConnectionTimeout(), client_socket_factory_, + host_resolver_, delegate, net_log_)); } base::TimeDelta @@ -523,20 +520,20 @@ TransportClientSocketPool::TransportClientSocketPool( TransportClientSocketPool::~TransportClientSocketPool() {} -int TransportClientSocketPool::RequestSocket( - const std::string& group_name, - const void* params, - RequestPriority priority, - ClientSocketHandle* handle, - const CompletionCallback& callback, - const BoundNetLog& net_log) { +int TransportClientSocketPool::RequestSocket(const std::string& group_name, + const void* params, + RequestPriority priority, + RespectLimits respect_limits, + ClientSocketHandle* handle, + const CompletionCallback& callback, + const BoundNetLog& net_log) { const scoped_refptr<TransportSocketParams>* casted_params = static_cast<const scoped_refptr<TransportSocketParams>*>(params); NetLogTcpClientSocketPoolRequestedSocket(net_log, casted_params); - return base_.RequestSocket(group_name, *casted_params, priority, handle, - callback, net_log); + return base_.RequestSocket(group_name, *casted_params, priority, + respect_limits, handle, callback, net_log); } void TransportClientSocketPool::NetLogTcpClientSocketPoolRequestedSocket( diff --git a/chromium/net/socket/transport_client_socket_pool.h b/chromium/net/socket/transport_client_socket_pool.h index 07a67b515a5..085fbb91e86 100644 --- a/chromium/net/socket/transport_client_socket_pool.h +++ b/chromium/net/socket/transport_client_socket_pool.h @@ -50,12 +50,10 @@ class NET_EXPORT_PRIVATE TransportSocketParams TransportSocketParams( const HostPortPair& host_port_pair, bool disable_resolver_cache, - bool ignore_limits, const OnHostResolutionCallback& host_resolution_callback, CombineConnectAndWritePolicy combine_connect_and_write); const HostResolver::RequestInfo& destination() const { return destination_; } - bool ignore_limits() const { return ignore_limits_; } const OnHostResolutionCallback& host_resolution_callback() const { return host_resolution_callback_; } @@ -69,7 +67,6 @@ class NET_EXPORT_PRIVATE TransportSocketParams ~TransportSocketParams(); HostResolver::RequestInfo destination_; - bool ignore_limits_; const OnHostResolutionCallback host_resolution_callback_; CombineConnectAndWritePolicy combine_connect_and_write_; @@ -158,6 +155,7 @@ class NET_EXPORT_PRIVATE TransportConnectJob : public ConnectJob { public: TransportConnectJob(const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, ClientSocketFactory* client_socket_factory, @@ -241,6 +239,7 @@ class NET_EXPORT_PRIVATE TransportClientSocketPool : public ClientSocketPool { int RequestSocket(const std::string& group_name, const void* resolve_info, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override; diff --git a/chromium/net/socket/transport_client_socket_pool_test_util.cc b/chromium/net/socket/transport_client_socket_pool_test_util.cc index 0140c255e18..190b1c57579 100644 --- a/chromium/net/socket/transport_client_socket_pool_test_util.cc +++ b/chromium/net/socket/transport_client_socket_pool_test_util.cc @@ -15,10 +15,10 @@ #include "base/run_loop.h" #include "base/single_thread_task_runner.h" #include "base/thread_task_runner_handle.h" +#include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/load_timing_info.h" #include "net/base/load_timing_info_test_util.h" -#include "net/base/net_util.h" #include "net/socket/client_socket_handle.h" #include "net/socket/ssl_client_socket.h" #include "net/udp/datagram_client_socket.h" @@ -28,10 +28,10 @@ namespace net { namespace { -IPAddressNumber ParseIP(const std::string& ip) { - IPAddressNumber number; - CHECK(ParseIPLiteralToNumber(ip, &number)); - return number; +IPAddress ParseIP(const std::string& ip) { + IPAddress address; + CHECK(address.AssignFromIPLiteral(ip)); + return address; } // A StreamSocket which connects synchronously and successfully. @@ -40,8 +40,7 @@ class MockConnectClientSocket : public StreamSocket { MockConnectClientSocket(const AddressList& addrlist, net::NetLog* net_log) : connected_(false), addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - use_tcp_fastopen_(false) {} + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {} // StreamSocket implementation. int Connect(const CompletionCallback& callback) override { @@ -70,8 +69,7 @@ class MockConnectClientSocket : public StreamSocket { void SetSubresourceSpeculation() override {} void SetOmniboxSpeculation() override {} bool WasEverUsed() const override { return false; } - void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; } - bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; } + void EnableTCPFastOpenIfSupported() override {} bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } @@ -103,7 +101,6 @@ class MockConnectClientSocket : public StreamSocket { bool connected_; const AddressList addrlist_; BoundNetLog net_log_; - bool use_tcp_fastopen_; DISALLOW_COPY_AND_ASSIGN(MockConnectClientSocket); }; @@ -112,8 +109,7 @@ class MockFailingClientSocket : public StreamSocket { public: MockFailingClientSocket(const AddressList& addrlist, net::NetLog* net_log) : addrlist_(addrlist), - net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - use_tcp_fastopen_(false) {} + net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)) {} // StreamSocket implementation. int Connect(const CompletionCallback& callback) override { @@ -135,8 +131,7 @@ class MockFailingClientSocket : public StreamSocket { void SetSubresourceSpeculation() override {} void SetOmniboxSpeculation() override {} bool WasEverUsed() const override { return false; } - void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; } - bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; } + void EnableTCPFastOpenIfSupported() override {} bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } @@ -170,7 +165,6 @@ class MockFailingClientSocket : public StreamSocket { private: const AddressList addrlist_; BoundNetLog net_log_; - bool use_tcp_fastopen_; DISALLOW_COPY_AND_ASSIGN(MockFailingClientSocket); }; @@ -186,7 +180,6 @@ class MockTriggerableClientSocket : public StreamSocket { is_connected_(false), addrlist_(addrlist), net_log_(BoundNetLog::Make(net_log, NetLog::SOURCE_SOCKET)), - use_tcp_fastopen_(false), weak_factory_(this) {} // Call this method to get a closure which will trigger the connect callback @@ -264,8 +257,7 @@ class MockTriggerableClientSocket : public StreamSocket { void SetSubresourceSpeculation() override {} void SetOmniboxSpeculation() override {} bool WasEverUsed() const override { return false; } - void EnableTCPFastOpenIfSupported() override { use_tcp_fastopen_ = true; } - bool UsingTCPFastOpen() const override { return use_tcp_fastopen_; } + void EnableTCPFastOpenIfSupported() override {} bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } bool GetSSLInfo(SSLInfo* ssl_info) override { return false; } @@ -308,7 +300,6 @@ class MockTriggerableClientSocket : public StreamSocket { const AddressList addrlist_; BoundNetLog net_log_; CompletionCallback callback_; - bool use_tcp_fastopen_; ConnectionAttempts connection_attempts_; base::WeakPtrFactory<MockTriggerableClientSocket> weak_factory_; diff --git a/chromium/net/socket/transport_client_socket_pool_unittest.cc b/chromium/net/socket/transport_client_socket_pool_unittest.cc index dab0d076398..c8702f70fb2 100644 --- a/chromium/net/socket/transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/transport_client_socket_pool_unittest.cc @@ -10,11 +10,11 @@ #include "base/macros.h" #include "base/message_loop/message_loop.h" #include "base/threading/platform_thread.h" +#include "net/base/ip_address.h" #include "net/base/ip_endpoint.h" #include "net/base/load_timing_info.h" #include "net/base/load_timing_info_test_util.h" #include "net/base/net_errors.h" -#include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" #include "net/log/test_net_log.h" @@ -43,7 +43,6 @@ class TransportClientSocketPoolTest : public testing::Test { new TransportSocketParams( HostPortPair("www.google.com", 80), false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), host_resolver_(new MockHostResolver), @@ -61,18 +60,18 @@ class TransportClientSocketPoolTest : public testing::Test { } scoped_refptr<TransportSocketParams> CreateParamsForTCPFastOpen() { - return new TransportSocketParams(HostPortPair("www.google.com", 80), - false, false, OnHostResolutionCallback(), - TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED); + return new TransportSocketParams( + HostPortPair("www.google.com", 80), false, OnHostResolutionCallback(), + TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DESIRED); } int StartRequest(const std::string& group_name, RequestPriority priority) { scoped_refptr<TransportSocketParams> params(new TransportSocketParams( - HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback(), + HostPortPair("www.google.com", 80), false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); return test_base_.StartRequestUsingPool( - &pool_, group_name, priority, params); + &pool_, group_name, priority, ClientSocketPool::RespectLimits::ENABLED, + params); } int GetOrderOfRequest(size_t index) { @@ -105,15 +104,13 @@ class TransportClientSocketPoolTest : public testing::Test { }; TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) { - IPAddressNumber ip_number; - ASSERT_TRUE(ParseIPLiteralToNumber("192.168.1.1", &ip_number)); - IPEndPoint addrlist_v4_1(ip_number, 80); - ASSERT_TRUE(ParseIPLiteralToNumber("192.168.1.2", &ip_number)); - IPEndPoint addrlist_v4_2(ip_number, 80); - ASSERT_TRUE(ParseIPLiteralToNumber("2001:4860:b006::64", &ip_number)); - IPEndPoint addrlist_v6_1(ip_number, 80); - ASSERT_TRUE(ParseIPLiteralToNumber("2001:4860:b006::66", &ip_number)); - IPEndPoint addrlist_v6_2(ip_number, 80); + IPEndPoint addrlist_v4_1(IPAddress(192, 168, 1, 1), 80); + IPEndPoint addrlist_v4_2(IPAddress(192, 168, 1, 2), 80); + IPAddress ip_address; + ASSERT_TRUE(ip_address.AssignFromIPLiteral("2001:4860:b006::64")); + IPEndPoint addrlist_v6_1(ip_address, 80); + ASSERT_TRUE(ip_address.AssignFromIPLiteral("2001:4860:b006::66")); + IPEndPoint addrlist_v6_2(ip_address, 80); AddressList addrlist; @@ -178,8 +175,9 @@ TEST(TransportConnectJobTest, MakeAddrListStartWithIPv4) { TEST_F(TransportClientSocketPoolTest, Basic) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -199,8 +197,9 @@ TEST_F(TransportClientSocketPoolTest, SetResolvePriorityOnInit) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", params_, priority, callback.callback(), &pool_, - BoundNetLog())); + handle.Init("a", params_, priority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(priority, host_resolver_->last_request_priority()); } } @@ -211,11 +210,12 @@ TEST_F(TransportClientSocketPoolTest, InitHostResolutionFailure) { ClientSocketHandle handle; HostPortPair host_port_pair("unresolvable.host.name", 80); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - host_port_pair, false, false, OnHostResolutionCallback(), + host_port_pair, false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", dest, kDefaultPriority, callback.callback(), - &pool_, BoundNetLog())); + handle.Init("a", dest, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult()); ASSERT_EQ(1u, handle.connection_attempts().size()); EXPECT_TRUE(handle.connection_attempts()[0].endpoint.address().empty()); @@ -228,8 +228,9 @@ TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", params_, kDefaultPriority, callback.callback(), - &pool_, BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); ASSERT_EQ(1u, handle.connection_attempts().size()); EXPECT_EQ("127.0.0.1:80", @@ -239,8 +240,9 @@ TEST_F(TransportClientSocketPoolTest, InitConnectionFailure) { // Make the host resolutions complete synchronously this time. host_resolver_->set_synchronous_mode(true); EXPECT_EQ(ERR_CONNECTION_FAILED, - handle.Init("a", params_, kDefaultPriority, callback.callback(), - &pool_, BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); ASSERT_EQ(1u, handle.connection_attempts().size()); EXPECT_EQ("127.0.0.1:80", handle.connection_attempts()[0].endpoint.ToString()); @@ -350,8 +352,9 @@ TEST_F(TransportClientSocketPoolTest, CancelRequestClearGroup) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", params_, kDefaultPriority, callback.callback(), - &pool_, BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); handle.Reset(); } @@ -362,11 +365,13 @@ TEST_F(TransportClientSocketPoolTest, TwoRequestsCancelOne) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", params_, kDefaultPriority, callback.callback(), - &pool_, BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", params_, kDefaultPriority, callback2.callback(), - &pool_, BoundNetLog())); + handle2.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), &pool_, BoundNetLog())); handle.Reset(); @@ -380,15 +385,17 @@ TEST_F(TransportClientSocketPoolTest, ConnectCancelConnect) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", params_, kDefaultPriority, callback.callback(), - &pool_, BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); handle.Reset(); TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", params_, kDefaultPriority, callback2.callback(), - &pool_, BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), &pool_, BoundNetLog())); host_resolver_->set_synchronous_mode(true); // At this point, handle has two ConnectingSockets out for it. Due to the @@ -497,11 +504,11 @@ class RequestSocketCallback : public TestCompletionCallbackBase { } within_callback_ = true; scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback(), + HostPortPair("www.google.com", 80), false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); - int rv = handle_->Init("a", dest, LOWEST, callback(), pool_, - BoundNetLog()); + int rv = handle_->Init("a", dest, LOWEST, + ClientSocketPool::RespectLimits::ENABLED, + callback(), pool_, BoundNetLog()); EXPECT_EQ(OK, rv); } } @@ -518,11 +525,11 @@ TEST_F(TransportClientSocketPoolTest, RequestTwice) { ClientSocketHandle handle; RequestSocketCallback callback(&handle, &pool_); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback(), + HostPortPair("www.google.com", 80), false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); - int rv = handle.Init("a", dest, LOWEST, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("a", dest, LOWEST, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); // The callback is going to request "www.google.com". We want it to complete @@ -584,8 +591,9 @@ TEST_F(TransportClientSocketPoolTest, FailingActiveRequestWithPendingRequests) { TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -602,8 +610,8 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { // Now we should have 1 idle socket. EXPECT_EQ(1, pool_.IdleSocketCount()); - rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + rv = handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(OK, rv); EXPECT_EQ(0, pool_.IdleSocketCount()); TestLoadTimingInfoConnectedReused(handle); @@ -612,8 +620,9 @@ TEST_F(TransportClientSocketPoolTest, IdleSocketLoadTiming) { TEST_F(TransportClientSocketPoolTest, ResetIdleSocketsOnIPAddressChange) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -668,8 +677,9 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketConnect) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("b", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -710,8 +720,9 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketCancel) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("c", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("c", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -756,8 +767,9 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterStall) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("b", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -806,8 +818,9 @@ TEST_F(TransportClientSocketPoolTest, BackupSocketFailAfterDelay) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("b", params_, LOW, callback.callback(), &pool_, - BoundNetLog()); + int rv = + handle.Init("b", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -863,8 +876,9 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv4FinishesFirst) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -915,8 +929,9 @@ TEST_F(TransportClientSocketPoolTest, IPv6FallbackSocketIPv6FinishesFirst) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -957,8 +972,9 @@ TEST_F(TransportClientSocketPoolTest, IPv6NoIPv4AddressesToFallbackTo) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -990,8 +1006,9 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init("a", params_, LOW, callback.callback(), &pool, - BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -1009,15 +1026,16 @@ TEST_F(TransportClientSocketPoolTest, IPv4HasNoFallback) { // Test that if TCP FastOpen is enabled, it is set on the socket // when we have only an IPv4 address. TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv4WithNoFallback) { + SequencedSocketData socket_data(nullptr, 0, nullptr, 0); + MockClientSocketFactory factory; + factory.AddSocketDataProvider(&socket_data); // Create a pool without backup jobs. ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, host_resolver_.get(), - &client_socket_factory_, + &factory, NULL); - client_socket_factory_.set_default_client_socket_type( - MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); // Resolve an AddressList with only IPv4 addresses. host_resolver_->rules()->AddIPLiteralRule("*", "1.1.1.1", std::string()); @@ -1025,20 +1043,24 @@ TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv4WithNoFallback) { ClientSocketHandle handle; // Enable TCP FastOpen in TransportSocketParams. scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); - handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(OK, callback.WaitForResult()); - EXPECT_TRUE(handle.socket()->UsingTCPFastOpen()); + EXPECT_TRUE(socket_data.IsUsingTCPFastOpen()); } // Test that if TCP FastOpen is enabled, it is set on the socket // when we have only IPv6 addresses. TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv6WithNoFallback) { + SequencedSocketData socket_data(nullptr, 0, nullptr, 0); + MockClientSocketFactory factory; + factory.AddSocketDataProvider(&socket_data); // Create a pool without backup jobs. ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, host_resolver_.get(), - &client_socket_factory_, + &factory, NULL); client_socket_factory_.set_default_client_socket_type( MockTransportClientSocketFactory::MOCK_DELAYED_CLIENT_SOCKET); @@ -1050,31 +1072,31 @@ TEST_F(TransportClientSocketPoolTest, TCPFastOpenOnIPv6WithNoFallback) { ClientSocketHandle handle; // Enable TCP FastOpen in TransportSocketParams. scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); - handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(OK, callback.WaitForResult()); - EXPECT_TRUE(handle.socket()->UsingTCPFastOpen()); + EXPECT_TRUE(socket_data.IsUsingTCPFastOpen()); } // Test that if TCP FastOpen is enabled, it does not do anything when there // is a IPv6 address with fallback to an IPv4 address. This test tests the case // when the IPv6 connect fails and the IPv4 one succeeds. TEST_F(TransportClientSocketPoolTest, - NoTCPFastOpenOnIPv6FailureWithIPv4Fallback) { + NoTCPFastOpenOnIPv6FailureWithIPv4Fallback) { + SequencedSocketData socket_data_1(nullptr, 0, nullptr, 0); + socket_data_1.set_connect_data(MockConnect(SYNCHRONOUS, ERR_IO_PENDING)); + SequencedSocketData socket_data_2(nullptr, 0, nullptr, 0); + MockClientSocketFactory factory; + factory.AddSocketDataProvider(&socket_data_1); + factory.AddSocketDataProvider(&socket_data_2); // Create a pool without backup jobs. ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, host_resolver_.get(), - &client_socket_factory_, + &factory, NULL); - MockTransportClientSocketFactory::ClientSocketType case_types[] = { - // This is the IPv6 socket. - MockTransportClientSocketFactory::MOCK_STALLED_CLIENT_SOCKET, - // This is the IPv4 socket. - MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET - }; - client_socket_factory_.set_client_socket_types(case_types, 2); // Resolve an AddressList with a IPv6 address first and then a IPv4 address. host_resolver_->rules() ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string()); @@ -1083,37 +1105,33 @@ TEST_F(TransportClientSocketPoolTest, ClientSocketHandle handle; // Enable TCP FastOpen in TransportSocketParams. scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); - handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(OK, callback.WaitForResult()); // Verify that the socket used is connected to the fallback IPv4 address. IPEndPoint endpoint; - handle.socket()->GetLocalAddress(&endpoint); + handle.socket()->GetPeerAddress(&endpoint); EXPECT_EQ(kIPv4AddressSize, endpoint.address().size()); - EXPECT_EQ(2, client_socket_factory_.allocation_count()); // Verify that TCP FastOpen was not turned on for the socket. - EXPECT_FALSE(handle.socket()->UsingTCPFastOpen()); + EXPECT_FALSE(socket_data_1.IsUsingTCPFastOpen()); } // Test that if TCP FastOpen is enabled, it does not do anything when there // is a IPv6 address with fallback to an IPv4 address. This test tests the case // when the IPv6 connect succeeds. TEST_F(TransportClientSocketPoolTest, - NoTCPFastOpenOnIPv6SuccessWithIPv4Fallback) { + NoTCPFastOpenOnIPv6SuccessWithIPv4Fallback) { + SequencedSocketData socket_data(nullptr, 0, nullptr, 0); + MockClientSocketFactory factory; + factory.AddSocketDataProvider(&socket_data); // Create a pool without backup jobs. ClientSocketPoolBaseHelper::set_connect_backup_jobs_enabled(false); TransportClientSocketPool pool(kMaxSockets, kMaxSocketsPerGroup, host_resolver_.get(), - &client_socket_factory_, + &factory, NULL); - MockTransportClientSocketFactory::ClientSocketType case_types[] = { - // This is the IPv6 socket. - MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET, - // This is the IPv4 socket. - MockTransportClientSocketFactory::MOCK_PENDING_CLIENT_SOCKET - }; - client_socket_factory_.set_client_socket_types(case_types, 2); // Resolve an AddressList with a IPv6 address first and then a IPv4 address. host_resolver_->rules() ->AddIPLiteralRule("*", "2:abcd::3:4:ff,2.2.2.2", std::string()); @@ -1122,15 +1140,15 @@ TEST_F(TransportClientSocketPoolTest, ClientSocketHandle handle; // Enable TCP FastOpen in TransportSocketParams. scoped_refptr<TransportSocketParams> params = CreateParamsForTCPFastOpen(); - handle.Init("a", params, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(OK, callback.WaitForResult()); - // Verify that the socket used is connected to the IPv6 address. IPEndPoint endpoint; - handle.socket()->GetLocalAddress(&endpoint); + handle.socket()->GetPeerAddress(&endpoint); + // Verify that the socket used is connected to the IPv6 address. EXPECT_EQ(kIPv6AddressSize, endpoint.address().size()); - EXPECT_EQ(1, client_socket_factory_.allocation_count()); // Verify that TCP FastOpen was not turned on for the socket. - EXPECT_FALSE(handle.socket()->UsingTCPFastOpen()); + EXPECT_FALSE(socket_data.IsUsingTCPFastOpen()); } } // namespace diff --git a/chromium/net/socket/transport_client_socket_unittest.cc b/chromium/net/socket/transport_client_socket_unittest.cc index ef5291d52d5..f9b095ca925 100644 --- a/chromium/net/socket/transport_client_socket_unittest.cc +++ b/chromium/net/socket/transport_client_socket_unittest.cc @@ -10,6 +10,7 @@ #include "base/run_loop.h" #include "net/base/address_list.h" #include "net/base/io_buffer.h" +#include "net/base/ip_address.h" #include "net/base/net_errors.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" @@ -97,9 +98,7 @@ void TransportClientSocketTest::SetUp() { // Open a server socket on an ephemeral port. listen_sock_.reset(new TCPServerSocket(NULL, NetLog::Source())); - IPAddressNumber address; - ParseIPLiteralToNumber("127.0.0.1", &address); - IPEndPoint local_address(address, 0); + IPEndPoint local_address(IPAddress::IPv4Localhost(), 0); ASSERT_EQ(OK, listen_sock_->Listen(local_address, 1)); // Get the server's address (including the actual port number). ASSERT_EQ(OK, listen_sock_->GetLocalAddress(&local_address)); diff --git a/chromium/net/socket/unix_domain_client_socket_posix.cc b/chromium/net/socket/unix_domain_client_socket_posix.cc index 564bef80d71..792cfac9bd2 100644 --- a/chromium/net/socket/unix_domain_client_socket_posix.cc +++ b/chromium/net/socket/unix_domain_client_socket_posix.cc @@ -9,8 +9,6 @@ #include <utility> #include "base/logging.h" -#include "base/posix/eintr_wrapper.h" -#include "net/base/ip_endpoint.h" #include "net/base/net_errors.h" #include "net/base/sockaddr_storage.h" #include "net/socket/socket_posix.h" @@ -34,8 +32,10 @@ UnixDomainClientSocket::~UnixDomainClientSocket() { bool UnixDomainClientSocket::FillAddress(const std::string& socket_path, bool use_abstract_namespace, SockaddrStorage* address) { - struct sockaddr_un* socket_addr = - reinterpret_cast<struct sockaddr_un*>(address->addr); + // Caller should provide a non-empty path for the socket address. + if (socket_path.empty()) + return false; + size_t path_max = address->addr_len - offsetof(struct sockaddr_un, sun_path); // Non abstract namespace pathname should be null-terminated. Abstract // namespace pathname must start with '\0'. So, the size is always greater @@ -44,6 +44,8 @@ bool UnixDomainClientSocket::FillAddress(const std::string& socket_path, if (path_size > path_max) return false; + struct sockaddr_un* socket_addr = + reinterpret_cast<struct sockaddr_un*>(address->addr); memset(socket_addr, 0, address->addr_len); socket_addr->sun_family = AF_UNIX; address->addr_len = path_size + offsetof(struct sockaddr_un, sun_path); @@ -68,9 +70,6 @@ bool UnixDomainClientSocket::FillAddress(const std::string& socket_path, int UnixDomainClientSocket::Connect(const CompletionCallback& callback) { DCHECK(!socket_); - if (socket_path_.empty()) - return ERR_ADDRESS_INVALID; - SockaddrStorage address; if (!FillAddress(socket_path_, use_abstract_namespace_, &address)) return ERR_ADDRESS_INVALID; @@ -132,10 +131,6 @@ bool UnixDomainClientSocket::WasEverUsed() const { return true; // We don't care. } -bool UnixDomainClientSocket::UsingTCPFastOpen() const { - return false; -} - bool UnixDomainClientSocket::WasNpnNegotiated() const { return false; } diff --git a/chromium/net/socket/unix_domain_client_socket_posix.h b/chromium/net/socket/unix_domain_client_socket_posix.h index 260b3b9be42..596aa096c53 100644 --- a/chromium/net/socket/unix_domain_client_socket_posix.h +++ b/chromium/net/socket/unix_domain_client_socket_posix.h @@ -52,7 +52,6 @@ class NET_EXPORT UnixDomainClientSocket : public StreamSocket { void SetSubresourceSpeculation() override; void SetOmniboxSpeculation() override; bool WasEverUsed() const override; - bool UsingTCPFastOpen() const override; bool WasNpnNegotiated() const override; NextProto GetNegotiatedProtocol() const override; bool GetSSLInfo(SSLInfo* ssl_info) override; diff --git a/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc index 35830759630..b20678b420a 100644 --- a/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_client_socket_posix_unittest.cc @@ -131,7 +131,7 @@ TEST_F(UnixDomainClientSocketTest, Connect) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), kUseAbstractNamespace); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; @@ -157,7 +157,7 @@ TEST_F(UnixDomainClientSocketTest, ConnectWithSocketDescriptor) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), kUseAbstractNamespace); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); SocketDescriptor accepted_socket_fd = kInvalidSocket; TestCompletionCallback accept_callback; @@ -209,7 +209,7 @@ TEST_F(UnixDomainClientSocketTest, ConnectWithAbstractNamespace) { #if defined(OS_ANDROID) || defined(OS_LINUX) UnixDomainServerSocket server_socket(CreateAuthCallback(true), kUseAbstractNamespace); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; @@ -255,7 +255,7 @@ TEST_F(UnixDomainClientSocketTest, TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; EXPECT_EQ(ERR_IO_PENDING, @@ -288,7 +288,7 @@ TEST_F(UnixDomainClientSocketTest, DisconnectFromClient) { TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; EXPECT_EQ(ERR_IO_PENDING, @@ -321,7 +321,7 @@ TEST_F(UnixDomainClientSocketTest, DisconnectFromServer) { TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; EXPECT_EQ(ERR_IO_PENDING, @@ -390,7 +390,7 @@ TEST_F(UnixDomainClientSocketTest, ReadAfterWrite) { TEST_F(UnixDomainClientSocketTest, ReadBeforeWrite) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), false); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; EXPECT_EQ(ERR_IO_PENDING, diff --git a/chromium/net/socket/unix_domain_server_socket_posix.cc b/chromium/net/socket/unix_domain_server_socket_posix.cc index 161893a4324..b6ba35c6362 100644 --- a/chromium/net/socket/unix_domain_server_socket_posix.cc +++ b/chromium/net/socket/unix_domain_server_socket_posix.cc @@ -68,13 +68,19 @@ int UnixDomainServerSocket::Listen(const IPEndPoint& address, int backlog) { } int UnixDomainServerSocket::ListenWithAddressAndPort( - const std::string& unix_domain_path, - uint16_t port_unused, + const std::string& address_string, + uint16_t port, int backlog) { + NOTIMPLEMENTED(); + return ERR_NOT_IMPLEMENTED; +} + +int UnixDomainServerSocket::BindAndListen(const std::string& socket_path, + int backlog) { DCHECK(!listen_socket_); SockaddrStorage address; - if (!UnixDomainClientSocket::FillAddress(unix_domain_path, + if (!UnixDomainClientSocket::FillAddress(socket_path, use_abstract_namespace_, &address)) { return ERR_ADDRESS_INVALID; @@ -90,7 +96,7 @@ int UnixDomainServerSocket::ListenWithAddressAndPort( DCHECK_NE(ERR_IO_PENDING, rv); if (rv != OK) { PLOG(ERROR) - << "Could not bind unix domain socket to " << unix_domain_path + << "Could not bind unix domain socket to " << socket_path << (use_abstract_namespace_ ? " (with abstract namespace)" : ""); return rv; } diff --git a/chromium/net/socket/unix_domain_server_socket_posix.h b/chromium/net/socket/unix_domain_server_socket_posix.h index 55d2708b323..7395f058ce5 100644 --- a/chromium/net/socket/unix_domain_server_socket_posix.h +++ b/chromium/net/socket/unix_domain_server_socket_posix.h @@ -52,13 +52,17 @@ class NET_EXPORT UnixDomainServerSocket : public ServerSocket { // ServerSocket implementation. int Listen(const IPEndPoint& address, int backlog) override; - int ListenWithAddressAndPort(const std::string& unix_domain_path, - uint16_t port_unused, + int ListenWithAddressAndPort(const std::string& address_string, + uint16_t port, int backlog) override; int GetLocalAddress(IPEndPoint* address) const override; int Accept(scoped_ptr<StreamSocket>* socket, const CompletionCallback& callback) override; + // Creates a server socket, binds it to the specified |socket_path| and + // starts listening for incoming connections with the specified |backlog|. + int BindAndListen(const std::string& socket_path, int backlog); + // Accepts an incoming connection on |listen_socket_|, but passes back // a raw SocketDescriptor instead of a StreamSocket. int AcceptSocketDescriptor(SocketDescriptor* socket_descriptor, @@ -88,4 +92,4 @@ class NET_EXPORT UnixDomainServerSocket : public ServerSocket { } // namespace net -#endif // NET_SOCKET_UNIX_DOMAIN_SOCKET_POSIX_H_ +#endif // NET_SOCKET_UNIX_DOMAIN_SERVER_SOCKET_POSIX_H_ diff --git a/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc b/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc index bdf1efa29c4..be472c0775c 100644 --- a/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc +++ b/chromium/net/socket/unix_domain_server_socket_posix_unittest.cc @@ -55,7 +55,7 @@ TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPath) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), kUseAbstractNamespace); EXPECT_EQ(ERR_FILE_NOT_FOUND, - server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); + server_socket.BindAndListen(kInvalidSocketPath, /*backlog=*/1)); } TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPathWithAbstractNamespace) { @@ -63,11 +63,10 @@ TEST_F(UnixDomainServerSocketTest, ListenWithInvalidPathWithAbstractNamespace) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), kUseAbstractNamespace); #if defined(OS_ANDROID) || defined(OS_LINUX) - EXPECT_EQ(OK, - server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(kInvalidSocketPath, /*backlog=*/1)); #else EXPECT_EQ(ERR_ADDRESS_INVALID, - server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); + server_socket.BindAndListen(kInvalidSocketPath, /*backlog=*/1)); #endif } @@ -76,8 +75,8 @@ TEST_F(UnixDomainServerSocketTest, ListenAgainAfterFailureWithInvalidPath) { UnixDomainServerSocket server_socket(CreateAuthCallback(true), kUseAbstractNamespace); EXPECT_EQ(ERR_FILE_NOT_FOUND, - server_socket.ListenWithAddressAndPort(kInvalidSocketPath, 0, 1)); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + server_socket.BindAndListen(kInvalidSocketPath, /*backlog=*/1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); } TEST_F(UnixDomainServerSocketTest, AcceptWithForbiddenUser) { @@ -85,7 +84,7 @@ TEST_F(UnixDomainServerSocketTest, AcceptWithForbiddenUser) { UnixDomainServerSocket server_socket(CreateAuthCallback(false), kUseAbstractNamespace); - EXPECT_EQ(OK, server_socket.ListenWithAddressAndPort(socket_path_, 0, 1)); + EXPECT_EQ(OK, server_socket.BindAndListen(socket_path_, /*backlog=*/1)); scoped_ptr<StreamSocket> accepted_socket; TestCompletionCallback accept_callback; @@ -119,6 +118,21 @@ TEST_F(UnixDomainServerSocketTest, AcceptWithForbiddenUser) { EXPECT_FALSE(accepted_socket); } +TEST_F(UnixDomainServerSocketTest, UnimplementedMethodsFail) { + const bool kUseAbstractNamespace = false; + UnixDomainServerSocket server_socket(CreateAuthCallback(true), + kUseAbstractNamespace); + + IPEndPoint ep; + EXPECT_EQ(ERR_NOT_IMPLEMENTED, server_socket.Listen(ep, 0)); + EXPECT_EQ(ERR_NOT_IMPLEMENTED, + server_socket.ListenWithAddressAndPort(kInvalidSocketPath, + 0, + /*backlog=*/1)); + + EXPECT_EQ(ERR_ADDRESS_INVALID, server_socket.GetLocalAddress(&ep)); +} + // Normal cases including read/write are tested by UnixDomainClientSocketTest. } // namespace diff --git a/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc index 6808d3733fc..640d8b3bd5d 100644 --- a/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc +++ b/chromium/net/socket/websocket_endpoint_lock_manager_unittest.cc @@ -9,6 +9,7 @@ #include "base/message_loop/message_loop.h" #include "base/run_loop.h" #include "base/time/time.h" +#include "net/base/ip_address.h" #include "net/base/net_errors.h" #include "net/socket/next_proto.h" #include "net/socket/socket_test_util.h" @@ -48,8 +49,6 @@ class FakeStreamSocket : public StreamSocket { bool WasEverUsed() const override { return false; } - bool UsingTCPFastOpen() const override { return false; } - bool WasNpnNegotiated() const override { return false; } NextProto GetNegotiatedProtocol() const override { return kProtoUnknown; } @@ -138,9 +137,7 @@ class WebSocketEndpointLockManagerTest : public ::testing::Test { WebSocketEndpointLockManager* instance() const { return instance_; } IPEndPoint DummyEndpoint() { - IPAddressNumber ip_address_number; - CHECK(ParseIPLiteralToNumber("127.0.0.1", &ip_address_number)); - return IPEndPoint(ip_address_number, 80); + return IPEndPoint(IPAddress::IPv4Localhost(), 80); } void UnlockDummyEndpoint(int times) { diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.cc b/chromium/net/socket/websocket_transport_client_socket_pool.cc index 73d8599b0e9..e4f0883acce 100644 --- a/chromium/net/socket/websocket_transport_client_socket_pool.cc +++ b/chromium/net/socket/websocket_transport_client_socket_pool.cc @@ -39,6 +39,7 @@ const int kTransportConnectJobTimeoutInSeconds = 240; // 4 minutes. WebSocketTransportConnectJob::WebSocketTransportConnectJob( const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<TransportSocketParams>& params, TimeDelta timeout_duration, const CompletionCallback& callback, @@ -51,6 +52,7 @@ WebSocketTransportConnectJob::WebSocketTransportConnectJob( : ConnectJob(group_name, timeout_duration, priority, + respect_limits, delegate, BoundNetLog::Make(pool_net_log, NetLog::SOURCE_CONNECT_JOB)), helper_(params, client_socket_factory, host_resolver, &connect_timing_), @@ -272,6 +274,7 @@ int WebSocketTransportClientSocketPool::RequestSocket( const std::string& group_name, const void* params, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& request_net_log) { @@ -286,7 +289,8 @@ int WebSocketTransportClientSocketPool::RequestSocket( request_net_log.BeginEvent(NetLog::TYPE_SOCKET_POOL); - if (ReachedMaxSocketsLimit() && !casted_params->ignore_limits()) { + if (ReachedMaxSocketsLimit() && + respect_limits == ClientSocketPool::RespectLimits::ENABLED) { request_net_log.AddEvent(NetLog::TYPE_SOCKET_POOL_STALLED_MAX_SOCKETS); // TODO(ricea): Use emplace_back when C++11 becomes allowed. StalledRequest request( @@ -306,17 +310,10 @@ int WebSocketTransportClientSocketPool::RequestSocket( } scoped_ptr<WebSocketTransportConnectJob> connect_job( - new WebSocketTransportConnectJob(group_name, - priority, - casted_params, - ConnectionTimeout(), - callback, - client_socket_factory_, - host_resolver_, - handle, - &connect_job_delegate_, - pool_net_log_, - request_net_log)); + new WebSocketTransportConnectJob( + group_name, priority, respect_limits, casted_params, + ConnectionTimeout(), callback, client_socket_factory_, host_resolver_, + handle, &connect_job_delegate_, pool_net_log_, request_net_log)); int rv = connect_job->Connect(); // Regardless of the outcome of |connect_job|, it will always be bound to @@ -589,12 +586,11 @@ void WebSocketTransportClientSocketPool::ActivateStalledRequest() { StalledRequest request(stalled_request_queue_.front()); stalled_request_queue_.pop_front(); stalled_request_map_.erase(request.handle); - int rv = RequestSocket("ignored", - &request.params, - request.priority, - request.handle, - request.callback, - request.net_log); + int rv = RequestSocket("ignored", &request.params, request.priority, + // Stalled requests can't have |respect_limits| + // DISABLED. + RespectLimits::ENABLED, request.handle, + request.callback, request.net_log); // ActivateStalledRequest() never returns synchronously, so it is never // called re-entrantly. if (rv != ERR_IO_PENDING) @@ -638,6 +634,9 @@ WebSocketTransportClientSocketPool::StalledRequest::StalledRequest( callback(callback), net_log(net_log) {} +WebSocketTransportClientSocketPool::StalledRequest::StalledRequest( + const StalledRequest& other) = default; + WebSocketTransportClientSocketPool::StalledRequest::~StalledRequest() {} } // namespace net diff --git a/chromium/net/socket/websocket_transport_client_socket_pool.h b/chromium/net/socket/websocket_transport_client_socket_pool.h index c0f1351dfe8..23b777457fa 100644 --- a/chromium/net/socket/websocket_transport_client_socket_pool.h +++ b/chromium/net/socket/websocket_transport_client_socket_pool.h @@ -43,6 +43,7 @@ class NET_EXPORT_PRIVATE WebSocketTransportConnectJob : public ConnectJob { WebSocketTransportConnectJob( const std::string& group_name, RequestPriority priority, + ClientSocketPool::RespectLimits respect_limits, const scoped_refptr<TransportSocketParams>& params, base::TimeDelta timeout_duration, const CompletionCallback& callback, @@ -134,6 +135,7 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool int RequestSocket(const std::string& group_name, const void* resolve_info, RequestPriority priority, + RespectLimits respect_limits, ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log) override; @@ -183,6 +185,7 @@ class NET_EXPORT_PRIVATE WebSocketTransportClientSocketPool ClientSocketHandle* handle, const CompletionCallback& callback, const BoundNetLog& net_log); + StalledRequest(const StalledRequest& other); ~StalledRequest(); const scoped_refptr<TransportSocketParams> params; const RequestPriority priority; diff --git a/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc index 96aa84a9921..bbc9de80018 100644 --- a/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc +++ b/chromium/net/socket/websocket_transport_client_socket_pool_unittest.cc @@ -21,7 +21,6 @@ #include "net/base/load_timing_info.h" #include "net/base/load_timing_info_test_util.h" #include "net/base/net_errors.h" -#include "net/base/net_util.h" #include "net/base/test_completion_callback.h" #include "net/dns/mock_host_resolver.h" #include "net/log/test_net_log.h" @@ -55,7 +54,6 @@ class WebSocketTransportClientSocketPoolTest : public ::testing::Test { : params_(new TransportSocketParams( HostPortPair("www.google.com", 80), false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)), host_resolver_(new MockHostResolver), @@ -81,11 +79,11 @@ class WebSocketTransportClientSocketPoolTest : public ::testing::Test { new TransportSocketParams( HostPortPair("www.google.com", 80), false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); return test_base_.StartRequestUsingPool( - &pool_, group_name, priority, params); + &pool_, group_name, priority, ClientSocketPool::RespectLimits::ENABLED, + params); } int GetOrderOfRequest(size_t index) { @@ -122,8 +120,9 @@ class WebSocketTransportClientSocketPoolTest : public ::testing::Test { TEST_F(WebSocketTransportClientSocketPoolTest, Basic) { TestCompletionCallback callback; ClientSocketHandle handle; - int rv = handle.Init( - "a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); + int rv = + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -142,12 +141,9 @@ TEST_F(WebSocketTransportClientSocketPoolTest, SetResolvePriorityOnInit) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - priority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, priority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(priority, host_resolver_->last_request_priority()); } } @@ -158,15 +154,12 @@ TEST_F(WebSocketTransportClientSocketPoolTest, InitHostResolutionFailure) { ClientSocketHandle handle; HostPortPair host_port_pair("unresolvable.host.name", 80); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - host_port_pair, false, false, OnHostResolutionCallback(), + host_port_pair, false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - dest, - kDefaultPriority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", dest, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_NAME_NOT_RESOLVED, callback.WaitForResult()); } @@ -176,23 +169,17 @@ TEST_F(WebSocketTransportClientSocketPoolTest, InitConnectionFailure) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - kDefaultPriority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); // Make the host resolutions complete synchronously this time. host_resolver_->set_synchronous_mode(true); EXPECT_EQ(ERR_CONNECTION_FAILED, - handle.Init("a", - params_, - kDefaultPriority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); } TEST_F(WebSocketTransportClientSocketPoolTest, PendingRequestsFinishFifo) { @@ -267,12 +254,9 @@ TEST_F(WebSocketTransportClientSocketPoolTest, CancelRequestClearGroup) { TestCompletionCallback callback; ClientSocketHandle handle; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - kDefaultPriority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); handle.Reset(); } @@ -283,19 +267,13 @@ TEST_F(WebSocketTransportClientSocketPoolTest, TwoRequestsCancelOne) { TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - kDefaultPriority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); EXPECT_EQ(ERR_IO_PENDING, - handle2.Init("a", - params_, - kDefaultPriority, - callback2.callback(), - &pool_, - BoundNetLog())); + handle2.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), &pool_, BoundNetLog())); handle.Reset(); @@ -309,23 +287,17 @@ TEST_F(WebSocketTransportClientSocketPoolTest, ConnectCancelConnect) { ClientSocketHandle handle; TestCompletionCallback callback; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - kDefaultPriority, - callback.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog())); handle.Reset(); TestCompletionCallback callback2; EXPECT_EQ(ERR_IO_PENDING, - handle.Init("a", - params_, - kDefaultPriority, - callback2.callback(), - &pool_, - BoundNetLog())); + handle.Init("a", params_, kDefaultPriority, + ClientSocketPool::RespectLimits::ENABLED, + callback2.callback(), &pool_, BoundNetLog())); host_resolver_->set_synchronous_mode(true); // At this point, handle has two ConnectingSockets out for it. Due to the @@ -395,11 +367,11 @@ void RequestSocketOnComplete(ClientSocketHandle* handle, handle->Reset(); scoped_refptr<TransportSocketParams> dest(new TransportSocketParams( - HostPortPair("www.google.com", 80), false, false, - OnHostResolutionCallback(), + HostPortPair("www.google.com", 80), false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); int rv = - handle->Init("a", dest, LOWEST, nested_callback, pool, BoundNetLog()); + handle->Init("a", dest, LOWEST, ClientSocketPool::RespectLimits::ENABLED, + nested_callback, pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); if (ERR_IO_PENDING != rv) nested_callback.Run(rv); @@ -414,14 +386,14 @@ TEST_F(WebSocketTransportClientSocketPoolTest, RequestTwice) { new TransportSocketParams( HostPortPair("www.google.com", 80), false, - false, OnHostResolutionCallback(), TransportSocketParams::COMBINE_CONNECT_AND_WRITE_DEFAULT)); TestCompletionCallback second_result_callback; - int rv = handle.Init("a", dest, LOWEST, - base::Bind(&RequestSocketOnComplete, &handle, &pool_, - second_result_callback.callback()), - &pool_, BoundNetLog()); + int rv = + handle.Init("a", dest, LOWEST, ClientSocketPool::RespectLimits::ENABLED, + base::Bind(&RequestSocketOnComplete, &handle, &pool_, + second_result_callback.callback()), + &pool_, BoundNetLog()); ASSERT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(OK, second_result_callback.WaitForResult()); @@ -492,8 +464,9 @@ TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleReset) { TEST_F(WebSocketTransportClientSocketPoolTest, LockReleasedOnHandleDelete) { TestCompletionCallback callback; scoped_ptr<ClientSocketHandle> handle(new ClientSocketHandle); - int rv = handle->Init( - "a", params_, LOW, callback.callback(), &pool_, BoundNetLog()); + int rv = + handle->Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool_, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(ERR_IO_PENDING, StartRequest("a", kDefaultPriority)); @@ -560,7 +533,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -602,7 +576,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -634,7 +609,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -664,7 +640,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv4HasNoFallback) { TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.is_initialized()); EXPECT_FALSE(handle.socket()); @@ -705,7 +682,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6InstantFail) { TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(OK, rv); ASSERT_TRUE(handle.socket()); @@ -741,7 +719,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, IPv6RapidFail) { TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_FALSE(handle.socket()); @@ -777,7 +756,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, FirstSuccessWins) { TestCompletionCallback callback; ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); ASSERT_FALSE(handle.socket()); @@ -829,7 +809,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, LastFailureWins) { ClientSocketHandle handle; base::TimeTicks start(base::TimeTicks::Now()); int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(ERR_CONNECTION_FAILED, callback.WaitForResult()); @@ -866,7 +847,8 @@ TEST_F(WebSocketTransportClientSocketPoolTest, DISABLED_OverallTimeoutApplies) { ClientSocketHandle handle; int rv = - handle.Init("a", params_, LOW, callback.callback(), &pool, BoundNetLog()); + handle.Init("a", params_, LOW, ClientSocketPool::RespectLimits::ENABLED, + callback.callback(), &pool, BoundNetLog()); EXPECT_EQ(ERR_IO_PENDING, rv); EXPECT_EQ(ERR_TIMED_OUT, callback.WaitForResult()); |